import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from PIL import Image

def to_img(x):
    """Map a tensor from the generator's Tanh range [-1, 1] into [0, 1]."""
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(0, 1)

# Hyperparameters and model dimensions.
num_epochs = 50
batch_size = 32
width_orig = 64  # side length (pixels) of the images fed to / produced by the nets
width_resize = int(width_orig / 4)  # spatial size after two stride-2 convs (64/4 = 16)
img_channels = 3  # RGB
hidden_size = 1024  # width of the fully connected hidden layers in both nets
latent_size = 64  # dimensionality of the generator's input noise vector
use_cuda = False

# Data pipeline: resize to width_orig, convert to tensor, then normalize each
# channel to [-1, 1] to match the generator's Tanh output range.
transform = transforms.Compose([
transforms.Resize((width_orig, width_orig)),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
# NOTE(review): hard-coded dataset path — ImageFolder expects class subfolders
# under this root; verify the CelebA layout matches.
dataset = datasets.ImageFolder(root='/root/dataset/img_align_celeba/',
transform=transform)
# NOTE(review): drop_last=True here would make the manual partial-batch skip in
# the training loop unnecessary — confirm before changing.
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=batch_size,
shuffle=True)

class Discriminator(nn.Module):
    """DCGAN-style discriminator.

    Two stride-2 convolutions downsample the input, then a fully connected
    head emits a single real/fake probability.

    Input : (batch, img_channels, width_orig, width_orig), values in [-1, 1].
    Output: (batch, 1) sigmoid probability that the input is real.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Each stride-2 conv halves the spatial size:
        # width_orig -> width_orig/2 -> width_orig/4 (= width_resize).
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * width_resize * width_resize, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),  # probability for binary real/fake classification
        )

    def forward(self, input):
        x = self.conv(input)
        # Flatten per sample. Using the explicit batch dimension (instead of
        # view(-1, N)) fails loudly if the feature size ever mismatches,
        # rather than silently producing a wrong batch size.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


class Generator(nn.Module):
    """DCGAN-style generator.

    A fully connected stem maps the latent vector to a
    (128, width_resize, width_resize) feature map, which two transposed
    convolutions upsample back to full resolution.

    Input : (batch, latent_size) noise vector.
    Output: (batch, img_channels, width_orig, width_orig), values in [-1, 1]
            (Tanh output, matching the normalized training images).
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(latent_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 128 * width_resize * width_resize),
            nn.BatchNorm1d(128 * width_resize * width_resize),
            nn.ReLU(),
        )

        # Each transposed conv doubles the spatial size:
        # width_resize -> 2*width_resize -> width_orig.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, input):
        x = self.fc(input)
        # Reshape the flat features to a spatial map; explicit batch dimension
        # (instead of -1) raises a clear error on any size mismatch.
        x = x.view(x.size(0), 128, width_resize, width_resize)
        x = self.deconv(x)
        return x
# Build the two networks.
D = Discriminator()
G = Generator()
# Optional: resume from saved checkpoints.
# D.load_state_dict(torch.load('discriminator.pth'))
# G.load_state_dict(torch.load('generator.pth'))

# Device setting
device = torch.device("cuda" if use_cuda else "cpu")
D = D.to(device)
G = G.to(device)

# Loss and optimizers. BCE matches the discriminator's Sigmoid output.
criterion = nn.BCELoss()  # binary cross entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

# Fixed label tensors, sized for a full batch; the training loop skips any
# smaller final batch so these shapes always match.
real_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)

for epoch in range(num_epochs):
    print(epoch)
    for i, (real_images, _) in enumerate(data_loader):

        # Skip the final partial batch: real_labels/fake_labels are sized for
        # exactly batch_size samples.
        if real_images.size(0) != batch_size:
            continue

        # --- Discriminator loss on real data ---
        real_images = real_images.to(device)
        outputs = D(real_images)
        d_loss_real = criterion(outputs, real_labels)

        # --- Discriminator loss on fake data ---
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)

        # --- Optimize discriminator ---
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # --- Generator loss: make D classify fresh fakes as real ---
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # --- Optimize generator ---
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if i % 200 == 0:
            print(epoch, i * batch_size, d_loss_real.item(),
                  d_loss_fake.item(), d_loss.item(), g_loss.item())

    # Checkpoint every 5 epochs.
    if epoch % 5 == 0:
        torch.save(D.state_dict(), 'discriminator' + str(epoch) + '.pth')
        torch.save(G.state_dict(), 'generator' + str(epoch) + '.pth')

    # Preview one generated sample each epoch.
    if epoch % 1 == 0:
        show_images = fake_images.view(batch_size, img_channels, width_orig, width_orig)  # [-1, 1]
        show_images = to_img(show_images)  # [0, 1]
        # BUG FIX: original indexed with an undefined name `c` and then took
        # [0], which would also reduce an RGB image to a single channel.
        # Take the first image of the batch, shape (img_channels, H, W), so
        # ToPILImage receives a full RGB tensor.
        pil_img = transforms.ToPILImage()(show_images.detach().cpu()[0])  # [0, 255]
        pil_img = pil_img.resize((128, 128), Image.BILINEAR)
        pil_img.show()