import numpy as np
import random
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image
def to_img(x):
    """Map values from the [-1, 1] range to the [0, 1] image range.

    Args:
        x: Tensor of values nominally in [-1, 1].

    Returns:
        A new tensor with ``0.5 * (x + 1)`` applied and the result clamped to
        [0, 1], so out-of-range inputs saturate instead of overflowing.
    """
    x = 0.5 * (x + 1)  # [-1, 1] => [0, 1]
    return x.clamp(0, 1)
def make_noise_with_class_labels(classes, batch_size, latent_size, num_classes):
    """Build generator input: Gaussian noise concatenated with one-hot labels.

    Args:
        classes: Sequence (list, NumPy array or tensor) of integer class
            indices, one per sample. Its length must equal ``batch_size``.
        batch_size: Number of noise vectors to draw.
        latent_size: Dimensionality of each noise vector.
        num_classes: Width of the one-hot label encoding.

    Returns:
        Tensor of shape ``(batch_size, latent_size + num_classes)`` whose first
        ``latent_size`` columns are ``torch.randn`` samples and whose remaining
        columns are the float32 one-hot encoding of ``classes``.
    """
    class_index = torch.as_tensor(classes, dtype=torch.long)
    # Build the identity matrix once and pick rows from it, rather than
    # constructing a fresh identity matrix for every sample.
    class_labels = torch.eye(num_classes, dtype=torch.float32)[class_index]
    noise = torch.randn(batch_size, latent_size)
    return torch.cat([noise, class_labels], dim=1)
def make_images_with_class_labels(images, label_images, classes, num_classes):
    """Append per-class label planes to each image along the channel axis.

    Args:
        images: Tensor of shape ``(batch, channels, height, width)``.
        label_images: Tensor indexed by class on its first axis; entry ``c``
            holds the ``num_classes`` label planes for class ``c`` and must
            match ``images`` in height and width.
        classes: Sequence (list, NumPy array or tensor) of integer class
            indices, one per image.
        num_classes: Number of label planes appended to every image.

    Returns:
        Tensor of shape ``(batch, channels + num_classes, height, width)`` on
        the same device and with the same dtype as ``images``. Gradients flow
        back to ``images`` through the concatenation.

    Raises:
        ValueError: If ``label_images`` does not carry ``num_classes`` planes.
    """
    if label_images.size(1) != num_classes:
        raise ValueError(
            "label_images has {} planes per class, expected {}".format(
                label_images.size(1), num_classes))
    # Gather all label planes in one indexing operation instead of copying
    # sample by sample in a Python loop.
    class_index = torch.as_tensor(
        classes, dtype=torch.long, device=label_images.device)
    label_planes = label_images[class_index].to(
        device=images.device, dtype=images.dtype)
    return torch.cat([images, label_planes], dim=1)
# Hyper-parameters and run configuration.

# Training schedule
num_epochs = 50
batch_size = 32

# Image geometry: the networks use two stride-2 stages, so the innermost
# feature map is a quarter of the original width.
img_channels = 1
width_orig = 28
width_resize = width_orig // 4

# Network sizes
latent_size = 64
hidden_size = 1024
num_classes = 10

# Set to True to run on a CUDA device.
use_cuda = False
# Dataset: MNIST training split, downloaded into ./data/ on first use.
# ToTensor produces values in [0, 1]; normalising with mean 0.5 and std 0.5
# rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, )),
])
mnist = torchvision.datasets.MNIST(
    './data/', train=True, transform=transform, download=True)
# Batches are reshuffled every epoch.
data_loader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size, shuffle=True)
# model
class Discriminator(nn.Module):
    """Conditional discriminator: scores image + label planes as real or fake.

    The input carries ``img_channels + num_classes`` channels. Two stride-2
    convolutions each halve the spatial size, after which the
    ``128 * width_resize * width_resize`` features feed a two-layer classifier
    ending in a sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels + num_classes, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * width_resize * width_resize, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),  # use sigmoid for 2 class classification
        )

    def forward(self, input):
        """Return a ``(batch, 1)`` tensor of sigmoid outputs in [0, 1].

        Args:
            input: Tensor of shape
                ``(batch, img_channels + num_classes, height, width)``.
        """
        x = self.conv(input)
        # Flatten per sample. Pinning the batch dimension means a wrong spatial
        # size fails at the Linear layer instead of silently changing the
        # batch size.
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class Generator(nn.Module):
    """Conditional generator: maps noise + one-hot label to an image.

    A two-layer MLP expands the ``latent_size + num_classes`` input to
    ``128 * width_resize * width_resize`` features, which are reshaped to a
    128-channel map and upsampled by two stride-2 transposed convolutions.
    The final ``Tanh`` bounds the output to [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(latent_size + num_classes, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 128 * width_resize * width_resize),
            nn.BatchNorm1d(128 * width_resize * width_resize),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent vectors.

        Args:
            input: Tensor of shape ``(batch, latent_size + num_classes)``.

        Returns:
            Tensor with ``img_channels`` channels and values in [-1, 1].
        """
        x = self.fc(input)
        # Reshape the flat features into a 128-channel map, one per sample.
        x = x.view(x.size(0), 128, width_resize, width_resize)
        x = self.deconv(x)
        return x
# Target device for the models and the loss targets.
device = torch.device("cuda" if use_cuda else "cpu")

# Build both networks and place them on the target device.
D = Discriminator().to(device)
G = Generator().to(device)
# To resume from saved weights, load the state dicts here, e.g.:
# D.load_state_dict(torch.load('discriminator.pth'))
# G.load_state_dict(torch.load('generator.pth'))

# Binary cross entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4)

# Loss targets: 1 marks "real", 0 marks "fake".
real_labels = torch.ones(batch_size, 1, device=device)
fake_labels = torch.zeros(batch_size, 1, device=device)
# label images: label_images[c] is a stack of num_classes constant planes,
# all -1 except plane c, which is +1 (a spatial one-hot encoding of class c).
label_images = -torch.ones(num_classes, num_classes, width_orig, width_orig)
for i in range(num_classes):
    label_images[i, i, :, :] = -label_images[i, i, :, :]
# Training loop: alternate one discriminator update and one generator update
# per batch, checkpoint every 5 epochs, and preview one sample per class at
# the end of every epoch.
for epoch in range(num_epochs):
    print(epoch)
    for i, (real_images, targets) in enumerate(data_loader):
        targets = targets.numpy()
        # The last batch of an epoch may hold fewer than batch_size samples;
        # size the "real" targets to the batch actually received.
        num_real = real_images.size(0)

        # loss of discriminator for real data
        real_images_with_label = make_images_with_class_labels(
            real_images, label_images, targets, num_classes).to(device)
        outputs = D(real_images_with_label)
        d_loss_real = criterion(outputs, real_labels[:num_real])

        # loss of discriminator for fake data
        random_classes = [random.randint(0, num_classes - 1) for _ in range(batch_size)]
        z = make_noise_with_class_labels(random_classes, batch_size, latent_size, num_classes).to(device)
        # Detach: the discriminator step must not back-propagate into G, and
        # any gradients it left there would be zeroed before the G step anyway.
        fake_images = G(z).detach()
        fake_images_with_label = make_images_with_class_labels(
            fake_images.cpu(), label_images, random_classes, num_classes).to(device)
        outputs = D(fake_images_with_label)
        d_loss_fake = criterion(outputs, fake_labels)

        # optimize discriminator
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # loss of generator: G is rewarded when D labels its samples as real
        random_classes = [random.randint(0, num_classes - 1) for _ in range(batch_size)]
        z = make_noise_with_class_labels(random_classes, batch_size, latent_size, num_classes).to(device)
        fake_images = G(z)
        fake_images_with_label = make_images_with_class_labels(
            fake_images.cpu(), label_images, random_classes, num_classes).to(device)
        outputs = D(fake_images_with_label)
        g_loss = criterion(outputs, real_labels)

        # optimize generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # save model
    if epoch % 5 == 0:
        torch.save(D.state_dict(), 'discriminator' + str(epoch) + '.pth')
        torch.save(G.state_dict(), 'generator' + str(epoch) + '.pth')

    # Preview interval in epochs (1 = every epoch).
    if epoch % 1 == 0:
        z = make_noise_with_class_labels(
            list(range(num_classes)), num_classes, latent_size, num_classes).to(device)
        # Sampling is for display only, so skip building the autograd graph.
        with torch.no_grad():
            fake_images = G(z)
        show_images = fake_images.view(num_classes, img_channels, width_orig, width_orig)  # [-1, 1]
        show_images = to_img(show_images).detach().cpu()
        for c in range(num_classes):
            pilImg = transforms.ToPILImage()(show_images[c][0])  # [0, 255]
            pilImg = pilImg.resize((128, 128), Image.BILINEAR)
            pilImg.show()