import numpy as np
import random
import torch
import torch.nn as nn
import torch.functional as F
from torchvision import transforms, datasets
from PIL import Image

def to_img(x):
    """Rescale a tensor from the [-1, 1] range to [0, 1].

    Values that fall outside [0, 1] after rescaling are clipped to the
    nearest bound. The input tensor is not modified.
    """
    rescaled = (x + 1) * 0.5  # [-1, 1] => [0, 1]
    return rescaled.clamp(0, 1)

def make_images_with_class_labels(images, classes, num_classes):
    """Append one-hot class planes to a batch of images along the channel axis.

    Args:
        images: tensor of shape [batch_size, img_channels, height, width].
        classes: one integer class index per image (NumPy array, tensor or
            sequence of ints) with batch_size entries.
        num_classes: total number of classes; this many planes are appended.

    Returns:
        Tensor of shape [batch_size, img_channels + num_classes, height, width].
        For a sample of class ``k`` the appended plane ``k`` is filled with
        ones and every other appended plane with zeros.
    """
    # Allocate the label planes on the same device as the images so the
    # concatenation below also works for GPU-resident batches.
    class_index = torch.as_tensor(classes, dtype=torch.long, device=images.device)
    one_hot = torch.eye(num_classes, device=images.device)[class_index]  # [batch_size, num_classes]
    # Broadcast every one-hot entry over the spatial grid without copying:
    # [batch_size, num_classes, 1, 1] -> [batch_size, num_classes, height, width]
    planes = one_hot[:, :, None, None].expand(-1, -1, images.size(2), images.size(3))
    return torch.cat([images, planes], dim=1)

def make_noise_with_class_labels(classes, num_classes, latent_size):
    """Build generator inputs: Gaussian noise followed by a one-hot class code.

    Args:
        classes: one integer class index per sample (NumPy array, tensor or
            sequence of ints).
        num_classes: total number of classes, i.e. the width of the one-hot code.
        latent_size: number of noise values drawn per sample.

    Returns:
        Float32 CPU tensor of shape [len(classes), latent_size + num_classes].
        The first ``latent_size`` columns are standard-normal noise and the
        remaining columns are the one-hot encoding of ``classes``.
    """
    class_index = torch.as_tensor(classes, dtype=torch.long).cpu()
    # Index a single identity matrix instead of rebuilding one per sample;
    # this also keeps the [0, num_classes] shape for an empty batch.
    one_hot = torch.eye(num_classes, dtype=torch.float32)[class_index]
    noise = torch.randn(class_index.shape[0], latent_size)
    return torch.cat([noise, one_hot], dim=1)

# Hyper-parameters
num_epochs = 50      # number of passes over the data loader
batch_size = 32      # samples per optimisation step
width_orig = 64      # images are resized to width_orig x width_orig
img_channels = 3     # channels per image (matches the 3-value mean/std used for normalisation)
latent_size = 100    # number of noise values fed to the generator per sample
nch_d = nch_g = 64   # base channel width of the discriminator / generator
# Only request the GPU when one is actually present; a hard-coded True makes
# the later .to(device) calls fail on CPU-only machines.
use_cuda = torch.cuda.is_available()

# Data pipeline: resize to a square, convert to a tensor and normalise each
# channel with mean 0.5 / std 0.5 (maps [0, 1] to [-1, 1]).
transform = transforms.Compose(
    [
        transforms.Resize(size=(width_orig, width_orig)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
# ImageFolder derives one class per sub-directory of the root folder.
dataset = datasets.ImageFolder('/root/dataset/celeba_hair/', transform=transform)
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
)
num_classes = len(dataset.classes)

class View(nn.Module):
    """Parameter-free layer that flattens its input into a 1-D tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # -1 collapses every dimension, including the batch dimension.
        flattened = x.view(-1)
        return flattened

class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per sample.

    Four stride-2 convolution blocks (Conv2d -> BatchNorm2d -> LeakyReLU)
    halve the spatial size each time; a final 4x4 convolution reduces the
    result to a single channel, which is flattened to 1-D and passed through
    a sigmoid. For 64x64 inputs the output has one value per sample.

    Args:
        nch: number of input channels.
        nch_d: channel width of the first block; doubled in each later block.
    """

    def __init__(self, nch=3, nch_d=64):
        super().__init__()
        widths = [nch, nch_d, nch_d * 2, nch_d * 4, nch_d * 8]
        stack = []
        # Down-sampling blocks, created in input-to-output order.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(negative_slope=0.2))
        # Head: collapse the remaining 4x4 map to one score per sample.
        stack.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=0))
        stack.append(View())
        stack.append(nn.Sigmoid())
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.layers(x)
        return scores

class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    The latent vector is treated as a 1x1 feature map, expanded to 4x4 by the
    first transposed convolution and then doubled in size four times, giving
    a 64x64 output with values in [-1, 1] (Tanh).

    Args:
        latent_size: length of the input vector per sample.
        img_channels: number of channels of the generated image.
        nch_g: channel width of the last hidden block; earlier blocks use
            8x, 4x and 2x this width.
    """

    def __init__(self, latent_size, img_channels, nch_g):
        super().__init__()
        self.latent_size = latent_size
        widths = (nch_g * 8, nch_g * 4, nch_g * 2, nch_g)

        # Stem: 1x1 -> 4x4.
        blocks = [
            nn.ConvTranspose2d(latent_size, widths[0], kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
        ]
        # Up-sampling blocks: each doubles the spatial size.
        for c_in, c_out in zip(widths, widths[1:]):
            blocks += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Output layer: final doubling plus Tanh to bound values to [-1, 1].
        blocks += [
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        # Reshape [batch, latent_size] into a 1x1 feature map per sample.
        latent_map = x.view(-1, self.latent_size, 1, 1)
        return self.layers(latent_map)

# Conditional GAN: both networks receive the one-hot class code in addition to
# their usual input (extra image channels for D, extra latent entries for G).
D = Discriminator(img_channels + num_classes, nch_d)
G = Generator(latent_size + num_classes, img_channels, nch_g)

# Device setting
# Fall back to the CPU when CUDA was requested but is not available instead of
# failing inside .to(device).
device = torch.device("cuda" if use_cuda and torch.cuda.is_available() else "cpu")
D = D.to(device)
G = G.to(device)

# Loss and Optimizer
criterion = nn.BCELoss()  # binary cross entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

# Targets for the loss: 1 = real, 0 = fake.
# They are 1-D because the discriminator flattens its output to [batch_size];
# BCELoss requires input and target to have the same shape.
real_labels = torch.ones(batch_size, device=device)
fake_labels = torch.zeros(batch_size, device=device)

# Training: for every full mini-batch run one discriminator update followed by
# one generator update; after each epoch preview one generated image per class.
for epoch in range(num_epochs):
    print(epoch)
    for i, (real_images, real_classes) in enumerate(data_loader):

        # Skip a short final batch: the target tensors are sized for batch_size.
        if real_images.size(0) != batch_size:
            continue

        # ---- discriminator: real images should be scored as real ----
        real_images_with_labels = make_images_with_class_labels(
            real_images, real_classes.numpy(), num_classes).to(device)
        outputs = D(real_images_with_labels)
        # BCELoss requires identical input/target shapes, so align the target
        # with whatever shape the discriminator returns.
        d_loss_real = criterion(outputs, real_labels.view_as(outputs))

        # ---- discriminator: generated images should be scored as fake ----
        random_classes = np.array(
            [random.randint(0, num_classes - 1) for _ in range(batch_size)])
        z = make_noise_with_class_labels(random_classes, num_classes, latent_size).to(device)
        # Detach so the discriminator update does not backpropagate through G.
        fake_images = G(z).detach()
        # The CPU round trip guarantees images and label planes share a device.
        fake_images_with_label = make_images_with_class_labels(
            fake_images.cpu(), random_classes, num_classes).to(device)
        outputs = D(fake_images_with_label)
        d_loss_fake = criterion(outputs, fake_labels.view_as(outputs))

        # optimize discriminator
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- generator: fresh fakes should be scored as real by D ----
        random_classes = np.array(
            [random.randint(0, num_classes - 1) for _ in range(batch_size)])
        z = make_noise_with_class_labels(random_classes, num_classes, latent_size).to(device)
        fake_images = G(z)
        # .cpu()/.to(device) are differentiable, so gradients still reach G.
        fake_images_with_label = make_images_with_class_labels(
            fake_images.cpu(), random_classes, num_classes).to(device)
        outputs = D(fake_images_with_label)
        g_loss = criterion(outputs, real_labels.view_as(outputs))

        # optimize generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if i % 200 == 0:
            print(epoch, i * batch_size,
                  d_loss_real.item(), d_loss_fake.item(), d_loss.item(), g_loss.item())

    # Preview one sample per class. The modulus of 1 means "every epoch";
    # raise it to preview less often.
    if epoch % 1 == 0:
        z = make_noise_with_class_labels(
            np.arange(num_classes), num_classes, latent_size).to(device)
        # No gradients are needed for the preview.
        with torch.no_grad():
            fake_images = G(z)
        show_images = fake_images.view(num_classes, img_channels, width_orig, width_orig)  # [-1, 1]
        show_images = to_img(show_images).cpu()  # [0, 1]
        to_pil = transforms.ToPILImage()
        for c in range(num_classes):
            pilImg = to_pil(show_images[c])  # [0, 255]
            pilImg = pilImg.resize((128, 128), Image.BILINEAR)
            pilImg.show()