In [2]:
# Create the run directory; -p makes the cell idempotent so "Restart & Run All"
# does not fail when the directory already exists.
!mkdir -p diff-run

In [3]:
# Create the directory for per-epoch sample grids; -p creates the parent if
# needed and does not fail when the directory already exists.
!mkdir -p diff-run/images

In [4]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import datetime

In [5]:
# Seed PyTorch's global random number generator so that weight initialisation
# and noise sampling are repeatable between runs of this notebook.
torch.manual_seed(1)

<torch._C.Generator at 0x7f645c183f90>

In [6]:
# TensorBoard event files for this run are written under diff-run/py-gan.
writer = SummaryWriter(log_dir='diff-run/py-gan')

In [7]:
# Number of images per mini-batch fed to the networks.
batch_size = 128

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5
# (single channel).
train_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# FashionMNIST training split, downloaded into ./data/ on first use.
train_dataset = datasets.FashionMNIST(
    './data/',
    train=True,
    download=True,
    transform=train_transform,
)

# Shuffled mini-batches of `batch_size` images.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(HTML(value=''), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0…

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [9]:
# Length of the noise vector fed to the generator.
latent_dim = 100

# Shape of one image as (channels, height, width) and its flattened length.
image_shape = (1, 28, 28)
image_dim = int(np.asarray(image_shape).prod())

In [10]:
class Generator(nn.Module):
    """Fully connected generator mapping a noise vector to an image.

    The network widens the input through hidden layers of 128, 256, 512 and
    1024 units (LeakyReLU activations, batch normalisation on the last three
    hidden layers) and ends with a Tanh, so every output value lies in
    [-1, 1].

    The layer order inside ``self.model`` is kept stable so that
    ``state_dict`` keys (``model.<index>.*``) do not change.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # NOTE: the second positional argument of nn.BatchNorm1d is `eps`, not
        # `momentum`. Passing 0.8 positionally therefore replaced the tiny
        # numerical-stability epsilon with 0.8 and distorted the normalisation.
        # The value is passed as `momentum` here, which is presumably what was
        # intended -- confirm if a different running-statistics rate is wanted.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, image_dim),
            nn.Tanh(),
        )

    def forward(self, noise_vector):
        """Generate a batch of images.

        Args:
            noise_vector: tensor of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, *image_shape)`` with values in [-1, 1].
        """
        image = self.model(noise_vector)
        # Reshape the flat image_dim-long vectors back into image tensors.
        image = image.view(image.size(0), *image_shape)
        return image




In [11]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real or generated.

    Flattens each image to ``image_dim`` values, passes it through hidden
    layers of 512 and 256 units (LeakyReLU activations) and ends with a
    single Sigmoid unit, so each image gets one score in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Layers are created in the order they are applied.
        layers = [
            nn.Linear(image_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, image):
        """Return a ``(batch, 1)`` tensor of scores for a batch of images."""
        batch = image.size(0)
        # Collapse every non-batch dimension into one flat feature vector.
        return self.model(image.view(batch, -1))

In [12]:
# Instantiate both networks (generator first, then discriminator) and move
# each one to the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

In [13]:
#torch.save(generator.state_dict(), 'generator.pth')

In [14]:
# for layer in generator.children():
#     print(layer.type)

In [15]:
# Print a layer-by-layer summary of the generator. The input size is taken
# from latent_dim so this cell stays correct if the latent size is changed.
summary(generator, (latent_dim,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 128]          12,928
         LeakyReLU-2                  [-1, 128]               0
            Linear-3                  [-1, 256]          33,024
       BatchNorm1d-4                  [-1, 256]             512
         LeakyReLU-5                  [-1, 256]               0
            Linear-6                  [-1, 512]         131,584
       BatchNorm1d-7                  [-1, 512]           1,024
         LeakyReLU-8                  [-1, 512]               0
            Linear-9                 [-1, 1024]         525,312
      BatchNorm1d-10                 [-1, 1024]           2,048
        LeakyReLU-11                 [-1, 1024]               0
           Linear-12                  [-1, 784]         803,600
             Tanh-13                  [-1, 784]               0
Total params: 1,510,032
Trainable param

In [16]:
# Print a layer-by-layer summary of the discriminator. The input size is taken
# from image_shape so this cell stays correct if the image shape is changed.
summary(discriminator, image_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 512]         401,920
         LeakyReLU-2                  [-1, 512]               0
            Linear-3                  [-1, 256]         131,328
         LeakyReLU-4                  [-1, 256]               0
            Linear-5                    [-1, 1]             257
           Sigmoid-6                    [-1, 1]               0
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05
----------------------------------------------------------------


In [17]:
# Binary cross-entropy between discriminator scores and 0/1 targets,
# averaged over the batch (the default reduction, spelled out here).
adversarial_loss = nn.BCELoss(reduction='mean')

In [18]:
# Both networks use Adam with the same learning rate and betas=(0.5, 0.999).
learning_rate = 2e-4

G_optimizer = optim.Adam(
    generator.parameters(),
    betas=(0.5, 0.999),
    lr=learning_rate,
)
D_optimizer = optim.Adam(
    discriminator.parameters(),
    betas=(0.5, 0.999),
    lr=learning_rate,
)

In [19]:
# Whether a CUDA device is available, and the matching float tensor type.
cuda = torch.cuda.is_available()
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

In [None]:
# Adversarial training loop.
#
# For every mini-batch:
#   1. update the discriminator on a real batch (target 1) and a generated
#      batch (target 0);
#   2. update the generator so that the discriminator scores its images as
#      real (target 1).
# Per-batch losses are logged to TensorBoard; per-epoch mean losses are printed
# and collected in D_loss_plot / G_loss_plot, and a grid of generated samples
# is saved to diff-run/images after every epoch.
num_epochs = 500
D_loss_plot, G_loss_plot = [], []  # per-epoch mean losses (Python floats)
steps_per_epoch = len(train_loader)

for epoch in range(1, num_epochs + 1):

    # Per-batch losses of this epoch, stored as plain floats. Storing the loss
    # tensors themselves would keep every batch's autograd graph alive until
    # the end of the epoch.
    D_loss_list, G_loss_list = [], []

    for index, (real_images, _) in enumerate(train_loader):
        real_images = real_images.to(device)
        # Use the actual batch size: the final batch of an epoch can be
        # smaller than batch_size.
        n_samples = real_images.size(0)
        real_target = torch.ones(n_samples, 1, device=device)
        fake_target = torch.zeros(n_samples, 1, device=device)

        # ---- Discriminator update ----
        D_optimizer.zero_grad()
        D_real_loss = adversarial_loss(discriminator(real_images), real_target)

        noise_vector = torch.randn(n_samples, latent_dim, device=device)
        generated_image = generator(noise_vector)

        # detach(): the discriminator step must not backpropagate into the
        # generator; those gradients would be computed and then thrown away.
        D_fake_loss = adversarial_loss(discriminator(generated_image.detach()),
                                       fake_target)

        D_total_loss = D_real_loss + D_fake_loss
        D_total_loss.backward()
        D_optimizer.step()

        # ---- Generator update ----
        # The generator's weights have not changed since generated_image was
        # produced, so the same batch (with its graph) is reused here and
        # scored by the freshly updated discriminator.
        G_optimizer.zero_grad()
        G_loss = adversarial_loss(discriminator(generated_image), real_target)
        G_loss.backward()
        G_optimizer.step()

        D_loss_value = D_total_loss.item()
        G_loss_value = G_loss.item()
        D_loss_list.append(D_loss_value)
        G_loss_list.append(G_loss_value)

        # Zero-based global step: epoch is 1-based, so subtract one to make the
        # first batch of the first epoch land on step 0.
        global_step = (epoch - 1) * steps_per_epoch + index
        writer.add_scalar('Discriminator Loss', D_loss_value, global_step)
        writer.add_scalar('Generator Loss', G_loss_value, global_step)

    D_epoch_loss = sum(D_loss_list) / len(D_loss_list)
    G_epoch_loss = sum(G_loss_list) / len(G_loss_list)

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        epoch, num_epochs, D_epoch_loss, G_epoch_loss))

    D_loss_plot.append(D_epoch_loss)
    G_loss_plot.append(G_epoch_loss)

    # Save up to 90 images from the last batch of the epoch as a 10-per-row grid.
    save_image(generated_image.detach()[:90],
               'diff-run/images/sample_%d.png' % epoch,
               nrow=10, normalize=True)