# -*- coding: utf-8 -*-
"""GAN_Pytorch_Fashion-MNIST.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1g85f0DkXmfrygkdlXd1GMYA8divxacf9
"""

# Commented out IPython magic to ensure Python compatibility.
# %load_ext tensorboard
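# Once some events have been logged, the dashboard can be opened in-notebook
# with (the log dir matches the SummaryWriter created below):
# %tensorboard --logdir diff-run/py-gan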

import torch
import numpy as np
import argparse
import os
import tqdm
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import datetime

# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=2e-4, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--latent_dim", type=int, default=100, help="dimension of the latent space (generator's input)")
parser.add_argument("--img_size", type=int, default=28, help="image size")
parser.add_argument("--channels", type=int, default=1, help="image channels")
# parse_known_args tolerates the extra flags a Jupyter/Colab kernel may inject
args, _ = parser.parse_known_args()

torch.manual_seed(1)

os.makedirs('diff-run/py-gan', exist_ok=True)
os.makedirs('diff-run/images', exist_ok=True)

writer = SummaryWriter('diff-run/py-gan')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])
train_dataset = datasets.FashionMNIST(root='./data/', train=True, transform=train_transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
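
# Optional sanity check: each batch has shape [batch_size, 1, 28, 28] and, after
# Normalize([0.5], [0.5]), pixel values lie in [-1, 1], which matches the Tanh
# output range of the generator defined below.
# images, _ = next(iter(train_loader))
# print(images.shape, images.min().item(), images.max().item())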

image_shape = (args.channels, args.img_size, args.img_size)
image_dim = int(np.prod(image_shape))
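# with the defaults this is 1 * 28 * 28 = 784 flattened pixel values per image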

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # each block widens the representation; BatchNorm takes an explicit
        # momentum keyword (a bare second positional argument would set eps)
        self.model = nn.Sequential(nn.Linear(args.latent_dim, 128),
                                   nn.LeakyReLU(0.2, inplace=True),
                                   nn.Linear(128, 256),
                                   nn.BatchNorm1d(256, momentum=0.8),
                                   nn.LeakyReLU(0.2, inplace=True),
                                   nn.Linear(256, 512),
                                   nn.BatchNorm1d(512, momentum=0.8),
                                   nn.LeakyReLU(0.2, inplace=True),
                                   nn.Linear(512, 1024),
                                   nn.BatchNorm1d(1024, momentum=0.8),
                                   nn.LeakyReLU(0.2, inplace=True),
                                   nn.Linear(1024, image_dim),
                                   nn.Tanh())
    
    def forward(self, noise_vector): 
        image = self.model(noise_vector)
        image = image.view(image.size(0), *image_shape)
        return image
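
# Example (optional): one forward pass maps a batch of latent noise vectors to
# image tensors, e.g. for a freshly constructed (untrained) generator:
# z = torch.randn(16, args.latent_dim)
# fake = Generator()(z)   # -> shape [16, 1, 28, 28], values in [-1, 1]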

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(nn.Linear(image_dim, 512),
                                  nn.LeakyReLU(0.2, inplace=True),
                                  nn.Linear(512, 256),
                                  nn.LeakyReLU(0.2, inplace=True),
                                  nn.Linear(256, 1),
                                  nn.Sigmoid())
    
    def forward(self, image):
        image_flattened = image.view(image.size(0), -1)
        result = self.model(image_flattened)
        return result
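
# Example (optional): the discriminator flattens any [N, 1, 28, 28] batch and
# returns one real/fake probability per image:
# scores = Discriminator()(torch.randn(16, *image_shape))   # -> shape [16, 1]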

generator = Generator().to(device)
discriminator = Discriminator().to(device)

#torch.save(generator.state_dict(), 'generator.pth')

# for layer in generator.children():
#     print(layer.type)

# torchsummary builds its dummy input on CUDA by default, so pass the active
# device explicitly to keep this working on CPU-only machines
summary(generator, (args.latent_dim,), device=device.type)

summary(discriminator, (args.channels, args.img_size, args.img_size), device=device.type)

adversarial_loss = nn.BCELoss()

G_optimizer = optim.Adam(generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
D_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

D_loss_plot, G_loss_plot = [], []
for epoch in range(1, args.n_epochs + 1):

    D_loss_list, G_loss_list = [], []
   
    for index, (real_images, _) in enumerate(train_loader):
        # ---- train the discriminator on real and generated batches ----
        D_optimizer.zero_grad()
        real_images = real_images.to(device)
        real_target = torch.ones(real_images.size(0), 1, device=device)
        fake_target = torch.zeros(real_images.size(0), 1, device=device)

        D_real_loss = adversarial_loss(discriminator(real_images), real_target)
        # print(discriminator(real_images))

        noise_vector = torch.randn(real_images.size(0), args.latent_dim, device=device)
        generated_image = generator(noise_vector)

        # detach the fake batch so this update does not backpropagate into the generator
        D_fake_loss = adversarial_loss(discriminator(generated_image.detach()), fake_target)

        D_total_loss = D_real_loss + D_fake_loss
        D_loss_list.append(D_total_loss.item())
        D_total_loss.backward()
        D_optimizer.step()

        # ---- train the generator to fool the discriminator ----
        G_optimizer.zero_grad()
        generated_image = generator(noise_vector)
        G_loss = adversarial_loss(discriminator(generated_image), real_target)
        G_loss_list.append(G_loss.item())

        G_loss.backward()
        G_optimizer.step()

        # log per-batch losses; the global step counts batches from the start of training
        global_step = (epoch - 1) * len(train_loader) + index
        writer.add_scalar('Discriminator Loss', D_total_loss.item(), global_step)
        writer.add_scalar('Generator Loss', G_loss.item(), global_step)


    epoch_D_loss = torch.mean(torch.FloatTensor(D_loss_list)).item()
    epoch_G_loss = torch.mean(torch.FloatTensor(G_loss_list)).item()
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            epoch, args.n_epochs, epoch_D_loss, epoch_G_loss))

    D_loss_plot.append(epoch_D_loss)
    G_loss_plot.append(epoch_G_loss)
    save_image(generated_image.data[:90], 'diff-run/images/sample_%d.png' % epoch, nrow=10, normalize=True)
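
# flush pending events and close the TensorBoard writer once training finishes
writer.close()

# Optionally persist the trained generator (cf. the commented torch.save above)
# and render one final grid from fixed noise; the paths here are only suggestions:
# torch.save(generator.state_dict(), 'diff-run/py-gan/generator.pth')
# generator.eval()
# with torch.no_grad():
#     fixed_noise = torch.randn(100, args.latent_dim, device=device)
#     save_image(generator(fixed_noise), 'diff-run/images/final_grid.png',
#                nrow=10, normalize=True)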