In [None]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

# Device configuration: use CUDA when PyTorch reports it is available, else run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
bs = 100

# MNIST Dataset: ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them
# to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, download=False, transform=transform)

# Data Loader (Input Pipeline): shuffle only the training split.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [None]:
class Generator_Network(nn.Module):
    """Fully connected generator mapping latent vectors to flattened images.

    Hidden widths are 256 -> 512 -> 1024, each followed by LeakyReLU(0.2).
    The last layer projects to ``output_image_dim`` and is passed through
    tanh, so every output value lies in [-1, 1].
    """

    def __init__(self, latent_vector_dim, output_image_dim):
        super().__init__()
        width = 256
        # Layers are created in the same order (fc1..fc4) so parameter
        # initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(latent_vector_dim, width)
        self.fc2 = nn.Linear(width, width * 2)
        self.fc3 = nn.Linear(width * 2, width * 4)
        self.fc4 = nn.Linear(width * 4, output_image_dim)

    def forward(self, x):
        """Map a latent batch (batch, latent_vector_dim) to (batch, output_image_dim)."""
        for hidden in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(hidden(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator_Network(nn.Module):
    """Fully connected discriminator scoring flattened images as real (1) or fake (0).

    Hidden widths are 1024 -> 512 -> 256. Each is followed by LeakyReLU(0.2)
    and dropout with p=0.3. The final single-unit layer is passed through a
    sigmoid, so outputs lie in [0, 1].
    """

    def __init__(self, input_image_dim):
        super(Discriminator_Network, self).__init__()
        self.fc1 = nn.Linear(input_image_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)

    # forward method
    def forward(self, x):
        """Return a (batch, 1) tensor of real-probabilities for x of shape (batch, input_image_dim)."""
        # F.dropout defaults to training=True, so it must be tied to the module's
        # mode explicitly; otherwise dropout keeps firing after .eval().
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [None]:
# build network
latent_vector_dim = 100
# Flattened image size (rows * cols, i.e. 28 * 28 for MNIST). `.data` is the
# current attribute; `.train_data` is a deprecated alias that warns on access.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

Generator = Generator_Network(latent_vector_dim, mnist_dim).to(device)
Discriminator = Discriminator_Network(mnist_dim).to(device)



In [None]:

def Discriminator_train(x):
    """Run one optimisation step on the discriminator.

    Args:
        x: a batch of real images as yielded by ``train_loader``; it is
            flattened to ``(batch, mnist_dim)`` before use.

    Returns:
        The summed real + fake BCE loss for this step, as a Python float.

    Uses the notebook-level ``Generator``, ``Discriminator``, ``criterion``,
    ``D_optimizer``, ``mnist_dim``, ``latent_vector_dim`` and ``device``.
    """
    #=======================Train the discriminator=======================#
    Discriminator.zero_grad()

    # Train on real images (target 1). Targets are sized from the actual batch,
    # not the global `bs`, so a short final batch does not cause a shape mismatch.
    x_real = x.view(-1, mnist_dim).to(device)
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)
    Discriminator_real_loss = criterion(Discriminator(x_real), y_real)

    # Train on generated images (target 0). detach() stops this loss from
    # back-propagating through the generator; only D's parameters are updated here.
    z = torch.randn(batch_size, latent_vector_dim, device=device)
    x_fake = Generator(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)
    Discriminator_fake_loss = criterion(Discriminator(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    Discriminator_loss = Discriminator_real_loss + Discriminator_fake_loss
    Discriminator_loss.backward()
    D_optimizer.step()

    return Discriminator_loss.item()

In [None]:
def Generator_train(x):
    """Run one optimisation step on the generator.

    Args:
        x: the current real batch; only its first dimension is used, to size
            the latent batch (it was previously ignored in favour of the global ``bs``).

    Returns:
        The generator's BCE loss for this step, as a Python float.

    Uses the notebook-level ``Generator``, ``Discriminator``, ``criterion``,
    ``G_optimizer``, ``latent_vector_dim`` and ``device``.
    """
    #=======================Train the generator=======================#
    Generator.zero_grad()

    batch_size = x.size(0)
    latent_vector = torch.randn(batch_size, latent_vector_dim, device=device)
    # The generator is trained to make the discriminator output 1 ("real") on its samples.
    labels = torch.ones(batch_size, 1, device=device)

    Generator_output = Generator(latent_vector)
    Discriminator_output = Discriminator(Generator_output)
    Generator_loss = criterion(Discriminator_output, labels)

    # Gradients also reach D's .grad, but only G_optimizer steps here and
    # the discriminator step zeroes D's gradients before its own update.
    Generator_loss.backward()
    G_optimizer.step()

    return Generator_loss.item()


In [None]:
# Binary cross-entropy between the discriminator's sigmoid output and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, sharing the same learning rate
# (all other Adam hyper-parameters are left at their defaults).
lr = 0.0002
G_optimizer, D_optimizer = (optim.Adam(net.parameters(), lr=lr) for net in (Generator, Discriminator))

In [None]:
n_epoch = 200
for epoch in range(1, n_epoch + 1):
    D_losses, G_losses = [], []
    for x, _ in train_loader:
        # Alternate one discriminator step and one generator step per real batch.
        D_losses.append(Discriminator_train(x))
        G_losses.append(Generator_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch, n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

    # Every 20 epochs, save a grid of `bs` generated digits for visual inspection.
    if epoch % 20 == 0:
        with torch.no_grad():
            test_z = torch.randn(bs, latent_vector_dim, device=device)
            generated = Generator(test_z)
            # The generator ends in tanh, so samples lie in [-1, 1]. save_image clamps
            # to [0, 1], so undo Normalize((0.5,), (0.5,)) first instead of clipping negatives to black.
            images = generated.view(generated.size(0), 1, 28, 28) * 0.5 + 0.5
            save_image(images, 'sample_' + str(epoch) + '.png')


[1/200]: loss_d: 0.418, loss_g: 3.400
[2/200]: loss_d: 0.475, loss_g: 3.317
[3/200]: loss_d: 0.474, loss_g: 3.103
[4/200]: loss_d: 0.502, loss_g: 2.986
[5/200]: loss_d: 0.532, loss_g: 2.814
[6/200]: loss_d: 0.580, loss_g: 2.571
[7/200]: loss_d: 0.656, loss_g: 2.407
[8/200]: loss_d: 0.662, loss_g: 2.228
[9/200]: loss_d: 0.675, loss_g: 2.277
[10/200]: loss_d: 0.708, loss_g: 2.145
[11/200]: loss_d: 0.782, loss_g: 1.852
[12/200]: loss_d: 0.772, loss_g: 1.909
[13/200]: loss_d: 0.770, loss_g: 2.030
[14/200]: loss_d: 0.842, loss_g: 1.779
[15/200]: loss_d: 0.841, loss_g: 1.759
[16/200]: loss_d: 0.856, loss_g: 1.720
[17/200]: loss_d: 0.873, loss_g: 1.720
[18/200]: loss_d: 0.877, loss_g: 1.698
[19/200]: loss_d: 0.894, loss_g: 1.622
[20/200]: loss_d: 0.942, loss_g: 1.539
inside
[21/200]: loss_d: 0.955, loss_g: 1.516
[22/200]: loss_d: 0.941, loss_g: 1.507
[23/200]: loss_d: 0.970, loss_g: 1.468
[24/200]: loss_d: 0.965, loss_g: 1.461
[25/200]: loss_d: 1.005, loss_g: 1.392
[26/200]: loss_d: 1.004, lo

In [None]:
      # Save one more grid of samples, e.g. after the training cell was interrupted.
      # NOTE(review): relies on `epoch` left over from the training-loop cell (hidden state);
      # the leading indent is kept from the original and is stripped by IPython.
      with torch.no_grad():
        test_z = torch.randn(bs, latent_vector_dim, device=device)
        generated = Generator(test_z)
        # Generator output is tanh in [-1, 1]; map to [0, 1] so save_image does not clip negatives.
        images = generated.view(generated.size(0), 1, 28, 28) * 0.5 + 0.5
        save_image(images, 'sample_' + str(epoch) + '.png')

inside
