import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from PIL import Image

def to_img(x):
    """Map a tensor from the [-1, 1] range to [0, 1], clipping anything outside.

    Args:
        x: Tensor whose values are nominally in [-1, 1].

    Returns:
        A new tensor of the same shape with values in [0, 1].
    """
    rescaled = (x + 1) * 0.5  # shift to [0, 2], then halve
    return rescaled.clamp(min=0, max=1)

# --- Training configuration -------------------------------------------------
num_epochs = 50
batch_size_tmp = 320  # loader batch size, i.e. before the target class is extracted
target = 7  # sneaker
use_cuda = False

# --- Model dimensions -------------------------------------------------------
img_channels = 1
width_orig = 28  # side length used when reshaping generated images
width_resize = width_orig // 4  # side length of the 128-channel feature maps
hidden_size = 1024
latent_size = 64

# Convert PIL images to tensors in [0, 1], then normalize to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Fashion-MNIST training split (downloaded into ./data/ when missing).
mnist = torchvision.datasets.FashionMNIST(
    './data/',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches over the full training split; the target class is
# filtered out of each batch later, inside the training loop.
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size_tmp,
    shuffle=True,
)

class Discriminator(nn.Module):
    """Convolutional binary classifier that scores images as real or fake.

    Two stride-2 convolutions feed a two-layer fully connected head that ends
    in a sigmoid, producing one score per sample. Layer sizes come from the
    module-level ``img_channels``, ``width_resize`` and ``hidden_size``.
    """

    def __init__(self):
        super().__init__()

        # Feature count of the conv output once flattened per sample.
        flat_features = 128 * width_resize * width_resize

        # Convolutional feature extractor: img_channels -> 64 -> 128 channels.
        conv_layers = [
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*conv_layers)

        # Classification head; the sigmoid makes the output usable with BCELoss.
        fc_layers = [
            nn.Linear(flat_features, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*fc_layers)

    def forward(self, input):
        """Return a ``(N, 1)`` tensor of sigmoid scores for a batch of images."""
        features = self.conv(input)
        flattened = features.view(-1, 128 * width_resize * width_resize)
        return self.fc(flattened)


class Generator(nn.Module):
    """Generate images from latent vectors.

    A fully connected stack expands a ``latent_size`` vector into a
    128-channel ``width_resize`` x ``width_resize`` feature map, which two
    stride-2 transposed convolutions turn into an ``img_channels``-channel
    image. The final tanh keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # Number of values in the 128-channel feature map fed to the deconvs.
        map_features = 128 * width_resize * width_resize

        # Latent vector -> flattened feature map.
        fc_layers = [
            nn.Linear(latent_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, map_features),
            nn.BatchNorm1d(map_features),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*fc_layers)

        # Feature map -> image: 128 -> 64 -> img_channels channels.
        deconv_layers = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*deconv_layers)

    def forward(self, input):
        """Return generated images for a ``(N, latent_size)`` batch of latents."""
        hidden = self.fc(input)
        feature_map = hidden.view(-1, 128, width_resize, width_resize)
        return self.deconv(feature_map)

# Select the compute device up front so each network can be moved as it is built.
device = torch.device("cuda" if use_cuda else "cpu")

# Build the discriminator and the generator on that device.
# To resume from saved weights, load them into the models first, e.g.:
#   D.load_state_dict(torch.load('discriminator.pth'))
#   G.load_state_dict(torch.load('generator.pth'))
D = Discriminator().to(device)
G = Generator().to(device)

# Binary cross entropy on the discriminator's sigmoid output, with one Adam
# optimizer per network.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4)

# Adversarial training on the images of a single class (`target`).
# Each step first updates the discriminator on real + generated images and
# then updates the generator so that its images are scored as real.
for epoch in range(num_epochs):
    print(epoch)
    fake_images = None  # last generated batch of this epoch, used for the preview

    for i, (real_images, targets) in enumerate(data_loader):

        # extract target data; the effective batch size varies per batch
        real_images = real_images[targets == target].to(device)
        batch_size = real_images.size(0)

        # BatchNorm1d in training mode rejects a single-sample batch and an
        # empty batch carries no training signal, so skip such batches
        # instead of crashing the whole run.
        if batch_size < 2:
            continue

        # labels: 1 for real images, 0 for generated ones
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # loss of discriminator for real data
        outputs = D(real_images)
        d_loss_real = criterion(outputs, real_labels)

        # loss of discriminator for fake data; the generated batch is detached
        # so this backward pass does not compute generator gradients that
        # would be zeroed before the generator update anyway
        z = torch.randn(batch_size, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)

        # optimize discriminator
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # loss of generator: fresh fakes should be classified as real
        z = torch.randn(batch_size, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        # optimize generator
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # report the losses once per epoch, on the first batch
        if i == 0:
            print(d_loss_real.item(), d_loss_fake.item(), d_loss.item(), g_loss.item())

    # Every fifth epoch, preview the first image of the last generated batch.
    if epoch % 5 == 0 and fake_images is not None:
        # -1 infers the batch dimension from the tensor itself, so the reshape
        # stays valid even when the final batch of the epoch was skipped.
        show_images = fake_images.view(-1, width_orig, width_orig)  # [-1, 1]
        show_images = to_img(show_images)  # [0, 1]
        pilImg = transforms.ToPILImage()(show_images.detach().cpu()[0])  # [0, 255]
        pilImg = pilImg.resize((256, 256), Image.BILINEAR)
        pilImg.show()

# save model
torch.save(D.state_dict(), 'discriminator.pth')
torch.save(G.state_dict(), 'generator.pth')