# IDL 11785 Fall 2024 Lab 12: AutoEncoders (AEs) and Variational AutoEncoders (VAEs)

## Presented By: Gabrial Zencha & Puru Samal
#### Picture Credits: Dareen Alharthi, Harshith Kumar, Yuzhou Wang

# What are AutoEncoders (AEs)?


**AutoEncoders** (AEs) are a type of neural network designed to learn efficient representations of data, typically in an unsupervised manner. Their primary goal is to compress data (encode) and then reconstruct it (decode) as accurately as possible.

## Structure of AutoEncoders
- **Encoder**: The encoder compresses the input data into a lower-dimensional representation (often called the "latent space" or "latent vector").
- **Latent Space**: The latent space represents the compressed information and captures the essential features of the input data.
- **Decoder**: The decoder reconstructs the original data from the latent space representation.
 ![AutoEncoder Structure](https://drive.google.com/uc?export=view&id=1uofZ-KL_BrpQqq8fqKQpnttulZ74pGPx)


## Key Features of AutoEncoders
- **Unsupervised Learning**: They do not require labeled data since they aim to reconstruct the input itself.

- **Compression**: AEs learn to reduce the dimensionality of data, making them useful for tasks like data compression and feature extraction.

- **Reconstruction Loss**: The training objective is to minimize the difference between the input and reconstructed output, typically using a loss function like Mean Squared Error (MSE).


## Objective Formulation of AutoEncoders

- **Input:** $x$
- **Encoding:** $z=f(x)$
- **Latent Vector:** $z$, obtained from the output of the encoder
- **Decoding:** $\hat{x} = g(z)$, which reconstructs the input from the latent vector
- **Loss:** The goal is to minimize the reconstruction error, often measured by Mean Squared Error (MSE) or Binary Cross-Entropy (BCE), depending on the type of data. Example with MSE <br>
$Loss = || x - \hat{x} ||^2 = ||x - g(f(x))||^2$ <br>
Where:  
- $x$: Original input data
- $\hat{x} = g(f(x))$: Reconstructed data from the encoded representation
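To make the objective concrete, here is a minimal numeric sketch that assumes a *linear* encoder $f(x) = Wx$ and decoder $g(z) = W^\top z$, where the hypothetical matrix `W` simply keeps the first two of three input coordinates. Whatever the bottleneck discards shows up as reconstruction error.

```python
import numpy as np

# Hypothetical linear encoder: keep the first 2 of 3 coordinates (latent dim 2).
W = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0]])

def encode(x: np.ndarray) -> np.ndarray:
    """f(x) = W x: project the input onto the latent space."""
    return W @ x

def decode(z: np.ndarray) -> np.ndarray:
    """g(z) = W^T z: map the latent vector back to the input space."""
    return W.T @ z

def reconstruction_loss(x: np.ndarray) -> float:
    """Squared L2 error ||x - g(f(x))||^2."""
    x_hat = decode(encode(x))
    return float(np.sum((x - x_hat) ** 2))

x = np.array([3.0, 4.0, 0.5])
print(encode(x))               # [3. 4.]
print(decode(encode(x)))       # [3. 4. 0.]
print(reconstruction_loss(x))  # 0.25
```

A trained AE learns nonlinear $f$ and $g$, but the idea is the same: the loss measures what the bottleneck fails to preserve.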

## Training Objective
The objective of training an AutoEncoder is to minimize this loss function, which forces the model to learn a latent representation that captures the important features of the input data while preserving as much information as possible during the reconstruction.

## Limitations of AutoEncoders
- **Limited Control Over the Latent Space Representation**: AutoEncoders learn a latent representation in an unsupervised manner without imposing a specific structure or distribution on the latent space. As a result, the learned latent space might not have properties like smoothness or continuity, making it difficult to interpret or control.
- **Overfitting**: Overfitting can occur if the AutoEncoder is too complex (e.g., having too many parameters or layers) relative to the training data size. If the model memorizes the data rather than generalizing patterns, it may reconstruct training data well but fail to perform effectively on unseen data.

- **Limitations in Generation**: Although AutoEncoders are not inherently designed for generation, they can still generate new data by sampling from the learned latent space and passing the sample through the decoder. However, the lack of structure in the latent space makes this challenging, especially for controlled generation: a randomly chosen latent vector may fall in a region the decoder never saw during training.

## How to overcome these limitations? See Variational AutoEncoders

# What are Variational AutoEncoders (VAEs)?


Variational AutoEncoders (VAEs) are an extension of AutoEncoders that introduce a probabilistic approach to the latent space. They are commonly used in generative modeling.

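## Structure of Variational AutoEncoders
The structure is similar to the AutoEncoder described above, except that the encoder outputs the parameters of a distribution over the latent space rather than a single latent vector.
![Variational AutoEncoder Structure](https://drive.google.com/uc?export=view&id=1bRxVo_vWxNy44BQgEOLJ6fJ_bPiB4DG-)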


## Key Properties of VAEs
- VAEs are probabilistic models to learn data distribution
- Map inputs to a probability distribution
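- Objective is to maximize the evidence lower bound (ELBO), a lower bound on the log-likelihood of the data; equivalently, to minimize an upper bound on the negative log-likelihood (NLL)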
- Allows learning of a structured latent space representation


### Loss function of VAE
\begin{equation}
\mathcal{L}(\theta, \phi; x^{(i)}) = -D_{KL}\left(q_\phi(z|x^{(i)}) \,||\, p_\theta(z)\right) + \mathbb{E}_{q_\phi(z|x^{(i)})} \left[ \log p_\theta(x^{(i)}|z) \right]
\end{equation}
where:


- $\mathcal{L}(\theta, \phi; x^{(i)})$: the evidence lower bound (ELBO) for a single data point $x^{(i)}$. Training maximizes it, which is the same as minimizing the VAE loss $-\mathcal{L}$.
- $\theta$: Parameters of the decoder, which maps the latent variable $z$ back to the original space to reconstruct $x$.
- $\phi$: Parameters of the encoder, which maps the input $x$ to the latent distribution $q_\phi(z|x)$.
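For a toy model where everything is Gaussian, the ELBO can be computed in closed form and compared with the exact log-evidence. The sketch below assumes a one-dimensional prior $p(z) = \mathcal{N}(0, 1)$, a likelihood $p(x|z) = \mathcal{N}(z, 1)$ (so that $p(x) = \mathcal{N}(0, 2)$), and an approximate posterior $q(z|x) = \mathcal{N}(m, s^2)$; the function names are illustrative. The ELBO never exceeds $\log p(x)$, and it matches it exactly when $q$ equals the true posterior $\mathcal{N}(x/2, 1/2)$.

```python
import math

def log_normal(x: float, mean: float, var: float) -> float:
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

def elbo(x: float, m: float, s2: float) -> float:
    """ELBO for prior N(0, 1), likelihood N(z, 1), and q(z|x) = N(m, s2)."""
    # E_q[log p(x|z)] = -0.5 log(2 pi) - 0.5 * ((x - m)^2 + s2)
    expected_loglik = -0.5 * math.log(2 * math.pi) - 0.5 * ((x - m) ** 2 + s2)
    # KL(N(m, s2) || N(0, 1)) in closed form
    kl = 0.5 * (s2 + m ** 2 - 1.0 - math.log(s2))
    return expected_loglik - kl

x = 1.5
log_evidence = log_normal(x, 0.0, 2.0)           # exact log p(x)
print(elbo(x, m=x / 2, s2=0.5) - log_evidence)   # ~0: q is the true posterior
print(elbo(x, m=0.0, s2=1.0) - log_evidence)     # negative: q is just the prior
```

The gap between $\log p(x)$ and the ELBO is the KL divergence from $q$ to the true posterior, which is why maximizing the ELBO also pulls $q$ toward that posterior.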


### First Term: KL Divergence $-D_{KL}\left(q_\phi(z|x^{(i)}) \,||\, p_\theta(z)\right)$

- This term encourages the encoder's distribution $q_\phi(z|x^{(i)})$ to stay close to the prior $p_\theta(z)$, typically $\mathcal{N}(0, I)$. This regularizes the latent space so that the encoder outputs occupy the same region from which we later sample.

- By minimizing this KL divergence (i.e., maximizing its negative in $\mathcal{L}$), we keep the latent representations close to the prior distribution, which makes it easier to generate new data points by sampling from the prior.
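When the encoder outputs a diagonal Gaussian $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and the prior is $\mathcal{N}(0, I)$, the KL term has the closed form $D_{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, the same expression as the `kl_div` line in `VariationalAutoencoder.loss_function`. The NumPy sketch below (hypothetical function names) checks the formula against a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$.

```python
import numpy as np

def kl_closed_form(mu: np.ndarray, log_var: np.ndarray) -> float:
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ), summed over dimensions."""
    return float(-0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var)))

def kl_monte_carlo(mu, log_var, n_samples=200_000, seed=0) -> float:
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * log_var)
    z = mu + std * rng.standard_normal((n_samples, mu.size))
    # Log-densities of the diagonal Gaussian q and the standard normal prior p
    log_q = -0.5 * np.sum(np.log(2 * np.pi) + log_var + ((z - mu) / std) ** 2, axis=1)
    log_p = -0.5 * np.sum(np.log(2 * np.pi) + z ** 2, axis=1)
    return float(np.mean(log_q - log_p))

mu = np.array([0.5, -1.0])
log_var = np.array([0.0, np.log(0.25)])
print(kl_closed_form(mu, log_var))  # ≈ 0.9431
print(kl_monte_carlo(mu, log_var))  # close to the closed-form value
```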

### Second Term: Reconstruction Loss $\mathbb{E}_{q_\phi(z|x^{(i)})} \left[ \log p_\theta(x^{(i)}|z) \right]$

- The second term is the expected log-likelihood of the input. Its negative is the reconstruction loss, so maximizing this term encourages accurate reconstruction of the input.

- This term measures the likelihood of reconstructing the input $x^{(i)}$ given a latent variable $z$ sampled from $q_\phi(z|x^{(i)})$.

- The expectation $\mathbb{E}_{q_\phi(z|x^{(i)})}$ averages over latent samples drawn from $q_\phi(z|x^{(i)})$. In practice it is estimated by Monte Carlo, often with a single sample per data point.

- $\log p_\theta(x^{(i)}|z)$ is the log-likelihood of the input under the decoder's output distribution; it is high when the reconstruction is close to the original.

- This term encourages the VAE to retain enough information in the latent representation to reconstruct the original input.
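If one assumes a Gaussian decoder $p_\theta(x|z) = \mathcal{N}(\hat{x}, \sigma^2 I)$ with a fixed $\sigma$ (an assumption for this sketch; the lab code simply uses a summed MSE loss), the negative log-likelihood is a scaled sum of squared errors plus a constant that does not depend on the decoder output. That is why an MSE reconstruction loss fits this framework. A NumPy sketch with illustrative names:

```python
import numpy as np

def gaussian_log_likelihood(x, x_hat, sigma=1.0) -> float:
    """log N(x; x_hat, sigma^2 I), summed over all elements."""
    d = x.size
    sq_err = np.sum((x - x_hat) ** 2)
    return float(-0.5 * d * np.log(2 * np.pi * sigma ** 2) - sq_err / (2 * sigma ** 2))

x = np.array([0.2, 0.8, 0.5])       # toy "image"
x_hat = np.array([0.1, 0.9, 0.5])   # toy decoder output
sse = np.sum((x - x_hat) ** 2)             # what F.mse_loss(..., reduction='sum') computes
const = 0.5 * x.size * np.log(2 * np.pi)   # independent of the decoder output
# With sigma = 1: -log p(x|z) = 0.5 * SSE + const
print(-gaussian_log_likelihood(x, x_hat), 0.5 * sse + const)
```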


## Solving Issues with AEs
- **Limited Control Over the Latent Space Representation** <br>
 **Solution:**
  - VAEs enforce a structured latent space by assuming that the latent variables $z$ are drawn from a known prior distribution, typically a Gaussian $p(z) = \mathcal{N}(0, I)$.
  - The encoder outputs a mean ($\mu$) and variance ($\sigma^2$) for each latent dimension, and the latent variable is sampled as $z \sim \mathcal{N}(\mu, \sigma^2)$.
  - This encourages a smooth, continuous latent space, making it easier to sample new data points or interpolate between latent representations.



- **Overfitting** <br>
 **Solution:**
  - The KL divergence term in the VAE loss regularizes the latent space by encouraging the learned distribution $q(z|x)$ to stay close to the prior $p(z)$.
  - This regularization helps reduce overfitting by discouraging the model from over-specializing to individual training samples and instead promoting generalized patterns.
  - Additionally, sampling from $\mathcal{N}(\mu, \sigma^2)$ injects noise into the latent code, which acts as a form of regularization and makes pure memorization of the data harder.

- **Data Generation Limitation** <br>
 **Solution:**
  - VAEs are explicitly designed as generative models. By enforcing a Gaussian prior on the latent space, the latent variables $z$ are drawn from a smooth and well-defined distribution.
  - New samples can be generated by simply sampling $z \sim \mathcal{N}(0, I)$ and passing these samples through the decoder $p_\theta(x|z)$.
  - The continuity of the latent space means that interpolating between points in $z$-space tends to produce meaningful, realistic outputs.
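Generation and interpolation both happen purely in latent space, with the decoder applied afterwards. The sketch below assumes a latent size of 256 (the default `latent_space_size` used later in this lab) and only builds the latent vectors; `interpolate` is an illustrative helper using simple linear interpolation.

```python
import numpy as np

latent_dim = 256  # default latent_space_size in this lab
rng = np.random.default_rng(0)

# Generation: draw latent vectors directly from the prior N(0, I)
z_samples = rng.standard_normal((8, latent_dim))

def interpolate(z_start: np.ndarray, z_end: np.ndarray, steps: int) -> np.ndarray:
    """Return `steps` latent vectors evenly spaced on the line from z_start to z_end."""
    alphas = np.linspace(0.0, 1.0, steps)[:, None]  # shape (steps, 1)
    return (1 - alphas) * z_start + alphas * z_end    # shape (steps, latent_dim)

path = interpolate(z_samples[0], z_samples[1], steps=5)
print(z_samples.shape, path.shape)  # (8, 256) (5, 256)
```

With a trained model, something like `VAEModel.decode(torch.from_numpy(path).float().to(device))` would turn each row into an image.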


# Coding AEs and VAEs



Credits: [facenet-pytorch](https://github.com/timesler/facenet-pytorch) and [LFW-Face dataset](https://vis-www.cs.umass.edu/lfw)


# Preliminaries

## Libraries

In [None]:
!pip install facenet-pytorch --quiet
!pip install torchinfo --quiet

## LFW Face Dataset

In [None]:
!wget http://vis-www.cs.umass.edu/lfw/lfw.tgz
!tar -xvzf lfw.tgz

In [None]:
!pip install gdown

# Get Checkpoints

In [None]:
import gdown


In [None]:
ae_url = "https://drive.google.com/uc?id=19DfNOXsiEfFg8nhV-h-Lf-Vda-vi8NB_"
ae_output = "ae.tar.gz"
gdown.download(ae_url, ae_output)
vae_url = "https://drive.google.com/uc?id=17s2-XgCfBOFpVHa5Gzb38UweN5Joiwjy"
vae_output = "vae.tar.gz"
gdown.download(vae_url, vae_output)

In [None]:
!tar xvzf /content/vae.tar.gz
!tar xvzf /content/ae.tar.gz

## Imports

In [None]:
!pip install --upgrade Pillow

In [None]:
from facenet_pytorch import MTCNN
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchinfo import summary
from typing import List, Tuple
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gc
import os
workers = 0 if os.name == 'nt' else 4
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

# Data Preparation

In [None]:
# Set the path to the LFW dataset you downloaded
data_dir = '/content/lfw'
batch_size = 1024

### -------------------------------------------------------------------------------------------------------------

def collate_fn(batch: List[Tuple[Image.Image, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collate function for the DataLoader to process a batch of images and labels.

    Args:
        batch (List[Tuple[Image.Image, int]]): A list of tuples, where each tuple contains:
            - A PIL Image (Image.Image) representing an image.
            - An integer label (int) associated with the image.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            - A tensor of images of shape (batch_size, C, H, W), where C is the number of channels,
              H is the height, and W is the width of the images. Each image is converted to a tensor
              from its original PIL format.
            - A tensor of labels of shape (batch_size), containing the corresponding integer labels.
    """
    # Separate images and labels from the batch
    images, labels = zip(*batch)

    # Convert each PIL image to a tensor
    images = [transforms.ToTensor()(img) for img in images]

    # Stack all images into a single tensor of shape (batch_size, C, H, W)
    images = torch.stack(images)

    # Convert labels to a tensor
    labels = torch.tensor(labels)

    return images, labels

### -------------------------------------------------------------------------------------------------------------


dataset = datasets.ImageFolder(data_dir)
# We add the idx_to_class attribute to the dataset to enable easy recoding of label indices to identity names later on.
dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}
loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers, batch_size=batch_size)

image_size = dataset[0][0].size[0]
print("Number of classes    : ", len(dataset.classes))
print("No. of train images  : ", len(dataset))
print("Type of image        : ", type(dataset[0][0]))
print("Image Dim            : ", f'{image_size}x{image_size}')
print("Num Batches          : ", f'{len(loader)}')



# Plot one batch

In [None]:
# Set up a 4x4 grid of subplots for the first 16 images in the batch
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
num_plots = 4 * 4  # Number of images to plot from the batch
# Take only the first batch from the data loader
for batch in loader:
    # Get the batch of images and labels
    images, labels = batch
    # Plot the first num_plots images of the batch
    for j in range(min(num_plots, images.size(0))):
        img = images[j].permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        row = j // 4  # Row index for subplot
        col = j % 4   # Column index for subplot
        ax[row, col].imshow(img)
        ax[row, col].axis('off')
    break

plt.tight_layout()
plt.show()


# Data Preprocessing: MTCNN For Face Detection and Cropping
- `MTCNN` (Multi-task Cascaded Convolutional Networks) is a deep learning model for face detection and facial landmark detection; the landmarks can be used for face alignment. It uses a cascade of three neural networks (P-Net, R-Net, O-Net) that progressively refine the candidate face boxes and predict key facial landmarks. It is known for its high accuracy and efficiency across various face sizes and orientations.
- We will use the `mtcnn.detect()` function to obtain bounding boxes for detected faces, and then crop each image to its first detected face with PIL. See `help(mtcnn.detect)` for details.
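`mtcnn.detect()` returns boxes as `[x1, y1, x2, y2]` floats, and the crop itself is done with PIL's `Image.crop`. For faces near the image border a box can extend past the image, in which case PIL fills the out-of-range area with black pixels. The hypothetical helper below (not part of facenet-pytorch) clamps a box to the image bounds before cropping; it is pure Python so it can be checked without loading a model.

```python
from typing import Sequence, Tuple

def clamp_box(box: Sequence[float], width: int, height: int) -> Tuple[int, int, int, int]:
    """Clamp an [x1, y1, x2, y2] box to the image and round to integer pixels."""
    x1, y1, x2, y2 = box
    x1 = max(0, min(int(round(x1)), width))
    y1 = max(0, min(int(round(y1)), height))
    x2 = max(x1, min(int(round(x2)), width))   # keep x2 >= x1
    y2 = max(y1, min(int(round(y2)), height))  # keep y2 >= y1
    return x1, y1, x2, y2

# A box that spills past the left and bottom edges of a 250x250 LFW image
print(clamp_box([-12.4, 40.0, 180.7, 263.2], width=250, height=250))  # (0, 40, 181, 250)
```

Inside `apply_mtcnn_crop` this could be used as `img_pil.crop(clamp_box(boxes[0].tolist(), *img_pil.size))`, since PIL's `size` is `(width, height)`.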

In [None]:
mtcnn = MTCNN(keep_all=True, device=device)

def apply_mtcnn_crop(images: torch.Tensor, target_size=(128,128)) -> torch.Tensor:
    """
    Function to apply MTCNN face detection and crop the images in a batch.

    Parameters:
        images (torch.Tensor): A batch of images of shape (batch_size, C, H, W).
        target_size (tuple): The target size (width, height) passed to PIL's resize after cropping.

    Returns:
        torch.Tensor: A batch of cropped images.
    """
    # Convert tensor batch to PIL images
    images_pil = [transforms.ToPILImage()(img) for img in images]

    cropped_images = []

    for img_pil in images_pil:
        boxes, _ = mtcnn.detect(img_pil)  # Detect faces with MTCNN
        if boxes is not None:
            # Crop the image using the first detected bounding box
            x1, y1, x2, y2 = boxes[0].tolist()
            img_pil = img_pil.crop((x1, y1, x2, y2))  # Crop the image
        # Resize the image to the target size
        img_pil = img_pil.resize(target_size)
        # Convert back to tensor after cropping (if cropped)
        img_tensor = transforms.ToTensor()(img_pil)
        cropped_images.append(img_tensor)

    # Stack the cropped images into a single tensor
    return torch.stack(cropped_images)

## MTCNN Processed Dataset (So you can use your own images later!)

In [None]:
class MTCNNPreprocessedDataset(Dataset):
    def __init__(self, original_dataset, target_size=(128, 128), transform=None):
        """
        Args:
            original_dataset (Dataset): The original dataset to use (e.g., a standard ImageFolder).
            target_size (tuple): The target size for resized images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.original_dataset = original_dataset  # Original dataset (e.g., ImageFolder, etc.)
        self.target_size = target_size  # The target size after cropping
        self.transform = transform  # Optional transform to apply to the image

        # Preprocess the images during initialization
        self.images = []
        self.labels = []
        self._preprocess_images()

    def __len__(self):
        return len(self.original_dataset)

    def __getitem__(self, idx):
        """
        Return preprocessed image and its label.
        """
        img, label = self.images[idx], self.labels[idx]

        # Apply additional transformations if provided (e.g., normalization)
        if self.transform:
            img = self.transform(img)

        return img, label

    def _preprocess_images(self):
        """
        Apply MTCNN face detection, crop, and resize for all images in the dataset.
        """
        print("Cropping images with MTCNN: ")
        for idx in tqdm(range(len(self.original_dataset))):
            img, label = self.original_dataset[idx]  # Get the image and label from the original dataset
            img = self.apply_mtcnn_crop(img, target_size=self.target_size)  # Apply MTCNN crop
            self.images.append(img)
            self.labels.append(label)

    def apply_mtcnn_crop(self, img: Image.Image, target_size=(128, 128)):
        """
        Apply MTCNN cropping and resize the image.

        Args:
            img (PIL.Image.Image): The input image.
            target_size (tuple): The desired output size (width, height) after cropping and resizing.

        Returns:
            PIL.Image.Image: The cropped and resized image.
        """
        boxes, _ = mtcnn.detect(img)  # Detect faces with MTCNN
        if boxes is not None:
            # Crop the image using the first detected bounding box
            x1, y1, x2, y2 = boxes[0].tolist()
            img = img.crop((x1, y1, x2, y2))  # Crop the image
        # Resize the image to the target size
        img = img.resize(target_size)
        return img


## Create a MTCNN-PreProcessed dataset/dataloader

In [None]:
# transform: Convert to Tensor
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Create the custom MTCNN preprocessed dataset
mtcnn_preprocessed_dataset = MTCNNPreprocessedDataset(dataset, target_size=(128, 128), transform=transform)

# Create a DataLoader for the preprocessed images
batch_size = 1024
mtcnn_preprocessed_loader = DataLoader(mtcnn_preprocessed_dataset, batch_size=batch_size, shuffle=True)


## Plot Cropped Face Images detected by MTCNN

In [None]:
# Print number of batches and shape of each batch
print(f'Number of batches: {len(mtcnn_preprocessed_loader)}')

# Set up a 4x4 grid of subplots for the first 16 images in the batch
fig, ax = plt.subplots(4, 4, figsize=(10, 10))
num_plots = 4 * 4  # Number of images to plot from the batch
# Take only the first batch from the data loader
for batch in mtcnn_preprocessed_loader:
    # Get the batch of images and labels
    images, labels = batch
    print(f'  Image batch shape: {images.shape}')
    print(f'  Label batch shape: {labels.shape}')
    # Plot the first num_plots images of the batch
    for j in range(min(num_plots, images.size(0))):
        img = images[j].permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        row = j // 4  # Row index for subplot
        col = j % 4   # Column index for subplot
        ax[row, col].imshow(img)
        ax[row, col].axis('off')
    break

plt.tight_layout()
plt.show()

# The AutoEncoder

This Autoencoder architecture uses a series of convolutional layers to compress input images into a 256-dimensional latent space (the default `latent_space_size`) and then reconstructs them using upsampling and convolutional layers. A Mean Squared Error reconstruction loss guides the network to make the output images as close as possible to the originals.
- **Input**: Images of shape `[batch_size, 3, 128, 128]` (3 channels, 128x128 resolution).
- **Encoder**: Series of convolutional layers that progressively reduce the spatial dimensions, capturing higher-level features. The final output from the encoder is flattened to a 1D vector and then passed through a fully connected layer to produce the latent representation (dimensionality defined by `latent_space_size`).
- **Latent Space**: Fully connected layer compressing the encoded features into a `latent_space_size`-dimensional vector (256 by default), the latent representation.
- **Decoder**: Series of upsampling + convolutional layers that bring the latent vector back to the original input shape.
- **Output**: Reconstructed image with the same shape as the input: `[batch_size, 3, 128, 128]`.
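The spatial sizes follow from the standard convolution output-size formula $\lfloor (n + 2p - k)/s \rfloor + 1$. With the kernel 4, stride 2, padding 1 used in the encoder, each layer halves the resolution, so the flattened encoder output has `ndf*8*8*8` features. A small pure-Python check (the helper name `conv_out` is illustrative):

```python
def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a Conv2d along one spatial dimension (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1

ndf, size = 32, 128
channels = [ndf, ndf * 2, ndf * 4, ndf * 8]
for c in channels:
    size = conv_out(size, kernel=4, stride=2, padding=1)
    print(f"-> [batch_size, {c}, {size}, {size}]")  # 64, 32, 16, 8

flattened_size = channels[-1] * size * size
print(flattened_size)  # 16384 == ndf*8*8*8

# Decoder side: a 3x3 conv with stride 1 and padding 1 preserves the size,
# so each Upsample(scale_factor=2) is what doubles the resolution.
print(conv_out(16, kernel=3, stride=1, padding=1))  # 16
```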


### Loss Function
The reconstruction loss used is **Mean Squared Error (MSE)**, which measures the squared difference between the original input and the reconstructed output. The code sums it over all pixels and all images in the batch:

$\mathcal{L}_{\text{recon}} = \sum_i \left(\hat{x}_i - x_i\right)^2$

Where:
- $x$ is the original image.
- $\hat{x}$ (`recon_x` in the code) is the reconstructed image from the decoder.
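The code computes this with `F.mse_loss(recon_x, x, reduction='sum')` rather than PyTorch's default `reduction='mean'`. The two reductions differ only by the number of elements, which matters when balancing the reconstruction term against other loss terms. A NumPy sketch of that relationship on a made-up toy batch:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 4))        # toy batch: [batch_size, C, H, W] in [0, 1)
recon_x = rng.random((2, 3, 4, 4))  # toy reconstruction

sq_err = (recon_x - x) ** 2
loss_sum = sq_err.sum()    # analogous to reduction='sum'
loss_mean = sq_err.mean()  # analogous to reduction='mean'
print(loss_sum, loss_mean * x.size)  # equal up to floating-point rounding
```

For a 128x128 RGB image the summed loss has 3 x 128 x 128 = 49,152 terms per image, so its magnitude is far larger than the mean.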

In [None]:
class Autoencoder(nn.Module):
    def __init__(self, in_ch=3, ndf=32, latent_space_size=256):
        super(Autoencoder, self).__init__()

        self.latent_space_size = latent_space_size
        self.ndf = ndf

        # Encoder: Convolutional layers to extract features
        # [batch_size,     3, 128, 128]  -> [batch_size,   ndf, 64, 64]
        # [batch_size,   ndf,  64,  64]  -> [batch_size, ndf*2, 32, 32]
        # [batch_size, ndf*2,  32,  32]  -> [batch_size, ndf*4, 16, 16]
        # [batch_size, ndf*4,  16,  16]  -> [batch_size, ndf*8,  8,  8]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, ndf, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf),
            nn.LeakyReLU(),
            nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*2),
            nn.LeakyReLU(),
            nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*4),
            nn.LeakyReLU(),
            nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*8),
            nn.LeakyReLU(),
        )

        # Size of the flattened encoder output: ndf*8 channels x 8 x 8 (for 128x128 inputs)
        self.flattened_size = ndf*8*8*8
        # FC layer mapping the flattened features to the latent vector
        self.fc = nn.Linear(self.flattened_size, self.latent_space_size)
        # FC layer mapping the latent vector back to the flattened feature size for the decoder
        self.fc_decode = nn.Linear(self.latent_space_size, self.flattened_size)

        # Decoder: Upsample + convolutional layers to reconstruct the image
        # [batch_size, ndf*8,  8,  8]  -> [batch_size, ndf*4,  16,  16]
        # [batch_size, ndf*4, 16, 16]  -> [batch_size, ndf*2,  32,  32]
        # [batch_size, ndf*2, 32, 32]  -> [batch_size,   ndf,  64,  64]
        # [batch_size, ndf,   64, 64]  -> [batch_size,     3, 128, 128]
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*8, ndf*4, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf*4),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*4, ndf*2, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf*2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*2, ndf, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf, in_ch, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )


    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encodes the input image to the latent representation.
        :param x: Input image tensor with shape [batch_size, 3, 128, 128]
        :return: Latent representation with shape [batch_size, latent_space_size]
        """
        x = self.encoder(x)  # Output shape: [batch_size, ndf*8, 8, 8]
        x = x.view(x.size(0), -1)  # Flatten the output
        return self.fc(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decodes the latent space representation back into the reconstructed image.
        :param z: Latent space tensor with shape [batch_size, latent_space_size]
        :return: Reconstructed image tensor with shape [batch_size, 3, 128, 128]
        """
        z = self.fc_decode(z)  # Output shape: [batch_size, flattened_size]
        z = z.view(z.size(0), self.ndf*8, 8, 8)  # Reshape
        return self.decoder(z)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: Encodes and decodes the input image.
        :param x: Input image tensor with shape [batch_size, 3, 128, 128]
        :return: Reconstructed image tensor with shape [batch_size, 3, 128, 128]
        """
        latent = self.encode(x)  # Output shape: [batch_size, latent_space_size]
        decoded = self.decode(latent)  # Output shape: [batch_size, 3, 128, 128]
        return decoded

    def loss_function(self, recon_x: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates the Autoencoder loss, which consists of the reconstruction loss (MSE).
        :param recon_x: Reconstructed image tensor with shape [batch_size, 3, 128, 128]
        :param x: Original image tensor with shape [batch_size, 3, 128, 128]
        :return: Total loss value (MSE summed over all elements)
        """
        # Reconstruction loss: sum of squared errors (the decoder output already passes through a sigmoid)
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')
        return recon_loss


In [None]:
# Instantiate the autoencoder
AEModel = Autoencoder()
summary(AEModel.to(device), input_data=[images.to(device)])

# The Variational AutoEncoder (VAE)

A **Variational Autoencoder (VAE)** is a generative model that learns the distribution of data in a latent space, allowing for the generation of new data points similar to the input data. VAEs consist of two main parts: the **Encoder** and the **Decoder**. The encoder maps input data to a distribution in the latent space, while the decoder reconstructs data from the latent space representation. A VAE introduces a probabilistic twist to the traditional autoencoder by modeling the latent space as a distribution instead of a fixed vector, and optimizing the model using both reconstruction loss and the Kullback-Leibler (KL) divergence.

### 1. **Encoder**

A key difference from the AE encoder is that instead of directly outputting a single point in the latent space, the encoder outputs two values:
- **Mean (μ)** of the distribution.
- **Log variance (log(σ²))** to define the spread of the distribution.

This means that the encoder outputs parameters for a Gaussian distribution, which will be used for sampling a point in the latent space.

### 2. **Latent Space Sampling (Reparameterization Trick)**

Once we have the mean (μ) and log variance (log(σ²)) from the encoder, we need to sample a point from the latent space. However, sampling is not a differentiable operation with respect to μ and σ, so gradients cannot flow through a direct draw from N(μ, σ²) during backpropagation. The **reparameterization trick** moves the randomness into an auxiliary noise variable, so that the sample becomes a differentiable function of μ and σ.

This trick works as follows:
- **std = exp(0.5 * log_var)**: The standard deviation is derived from the log variance.
- **z = mu + eps * std**: We sample `eps` from a standard normal distribution (N(0, 1)), then scale it by σ and shift it by μ, which gives a sample from N(μ, σ²).
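A NumPy sketch of these two steps, using made-up values for `mu` and `log_var`. Predicting the log variance lets the network output any real number, while `exp` guarantees a positive standard deviation. Because `eps` is the only random quantity, `z` is a deterministic function of `mu` and `log_var`, and the empirical statistics of many samples recover μ and σ.

```python
import numpy as np

def reparameterize(mu: np.ndarray, log_var: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    std = np.exp(0.5 * log_var)           # sigma = exp(0.5 * log(sigma^2)) > 0
    eps = rng.standard_normal(std.shape)  # eps ~ N(0, 1), independent of mu and sigma
    return mu + eps * std                 # z ~ N(mu, sigma^2)

rng = np.random.default_rng(0)
mu = np.full(100_000, 2.0)
log_var = np.full(100_000, np.log(0.09))  # sigma^2 = 0.09, sigma = 0.3
z = reparameterize(mu, log_var, rng)
print(z.mean(), z.std())                  # approximately 2.0 and 0.3
# dz/dmu = 1 and dz/dstd = eps, so gradients can reach the encoder outputs.
```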

### 3. **Decoder**

The decoder takes the sampled latent vector (`z`) and maps it back to the data space (e.g., the original image). It is identical to the AE decoder.

### 4. **Loss Function**

The loss function for the VAE consists of two parts:
1. **Reconstruction Loss**: This measures how well the decoder can reconstruct the original input from the latent space.
   - `recon_loss = F.mse_loss(recon_x, x, reduction='sum')`

2. **KL Divergence**: This measures the difference between the learned distribution (the distribution defined by the encoder's outputs μ and log(σ²)) and the prior distribution, which is typically a standard normal distribution (N(0, 1)). The KL divergence term helps to regularize the latent space by pushing the learned distribution closer to a standard normal distribution.
   - `kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())`

The final VAE loss is the sum of these two components:
- `total_loss = recon_loss + kl_weight * kl_div`
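Putting the two pieces together, a NumPy mirror of `VariationalAutoencoder.loss_function` might look like the sketch below. The batch shapes and values are made up for illustration, and the default `kl_weight=0.00025` is taken from the class that follows. When the encoder outputs exactly the prior (μ = 0, log σ² = 0), the KL term vanishes and only the reconstruction error remains.

```python
import numpy as np

def vae_loss(recon_x, x, mu, log_var, kl_weight=0.00025):
    """Summed MSE reconstruction loss plus weighted KL(q || N(0, I))."""
    recon_loss = np.sum((recon_x - x) ** 2)
    kl_div = -0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var))
    return float(recon_loss + kl_weight * kl_div)

rng = np.random.default_rng(0)
x = rng.random((4, 3, 8, 8))  # toy batch of images
recon_x = np.clip(x + 0.05 * rng.standard_normal(x.shape), 0.0, 1.0)
mu, log_var = np.zeros((4, 16)), np.zeros((4, 16))  # 16-dim toy latent

print(vae_loss(recon_x, x, mu, log_var))        # reconstruction error only
print(vae_loss(recon_x, x, mu + 1.0, log_var))  # adds 0.00025 * 0.5 * 64 = 0.008
```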


In [None]:
class VariationalAutoencoder(nn.Module):
    def __init__(self, in_ch=3, ndf=32, latent_space_size=256):
        super(VariationalAutoencoder, self).__init__()

        self.latent_space_size = latent_space_size
        self.ndf = ndf

        # Encoder: Convolutional layers to extract features
        # [batch_size,     3, 128, 128]  -> [batch_size,    ndf, 64, 64]
        # [batch_size,   ndf,  64,  64]  -> [batch_size,  ndf*2, 32, 32]
        # [batch_size, ndf*2,  32,  32]  -> [batch_size,  ndf*4, 16, 16]
        # [batch_size, ndf*4,  16,  16]  -> [batch_size,  ndf*8,  8,  8]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_ch, ndf, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf),
            nn.LeakyReLU(),
            nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*2),
            nn.LeakyReLU(),
            nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*4),
            nn.LeakyReLU(),
            nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(ndf*8),
            nn.LeakyReLU(),
        )

        # Latent space size (for VAE, we need mu and log_var for the latent distribution)
        self.flattened_size = ndf*8*8*8

        # Fully connected layer to output the mean and log variance for the latent space
        self.fc_mu = nn.Linear(self.flattened_size, self.latent_space_size)  # Mean of latent space
        self.fc_log_var = nn.Linear(self.flattened_size, self.latent_space_size)  # Log variance of latent space


        # FC layer mapping the sampled latent vector back to the flattened feature size
        self.fc_decode = nn.Linear(self.latent_space_size, self.flattened_size)

        # Decoder: Upsample + convolutional layers to reconstruct the image
        # [batch_size,   ndf*8,  8,  8]  -> [batch_size, ndf*4,  16,  16]
        # [batch_size,   ndf*4, 16, 16]  -> [batch_size, ndf*2,  32,  32]
        # [batch_size,   ndf*2, 32, 32]  -> [batch_size,   ndf,  64,  64]
        # [batch_size,     ndf, 64, 64]  -> [batch_size,     3, 128, 128]
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*8, ndf*4, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf*4),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*4, ndf*2, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf*2),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf*2, ndf, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(ndf),
            nn.LeakyReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(ndf, in_ch, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )


    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the input image to the latent distribution (mean and log variance).
        :param x: Input image tensor with shape [batch_size, 3, 128, 128]
        :return: mean (mu) and log variance (log_var) tensors with shape [batch_size, latent_space_size]
        """
        x = self.encoder(x)  # Output shape: [batch_size, ndf*8, 8, 8]
        x = x.view(x.size(0), -1)  # Flatten the output
        mu = self.fc_mu(x)  # Mean of latent space [batch_size, latent_space_size]
        log_var = self.fc_log_var(x)  # Log variance of latent space [batch_size, latent_space_size]
        return mu, log_var

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Reparameterization trick: Sample z from N(mu, sigma^2) using mu and log_var.
        :param mu: Mean of latent space with shape [batch_size, latent_space_size]
        :param log_var: Log variance of latent space with shape [batch_size, latent_space_size]
        :return: Latent vector z sampled from N(mu, sigma^2)
        """
        std = torch.exp(0.5 * log_var)  # Standard deviation from log_var
        eps = torch.randn_like(std)     # Sample from N(0, 1)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decodes the latent space representation back into the reconstructed image.
        :param z: Latent space tensor with shape [batch_size, latent_space_size]
        :return: Reconstructed image tensor with shape [batch_size, 3, 128, 128]
        """
        z = self.fc_decode(z)  # Output shape: [batch_size, flattened_size]
        z = z.view(z.size(0), self.ndf*8, 8, 8)  # Reshape
        return self.decoder(z)  # Output shape: [batch_size, 3, 128, 128]

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass: Encodes the input, applies the reparameterization trick, and decodes.
        :param x: Input image tensor with shape [batch_size, 3, 128, 128]
        :return: Tuple (reconstructed image [batch_size, 3, 128, 128], mu, log_var), where mu and log_var have shape [batch_size, latent_space_size]
        """
        mu, log_var = self.encode(x)          # Output shape: [batch_size, latent_space_size]
        z = self.reparameterize(mu, log_var)  # Sample from latent distribution
        decoded = self.decode(z)
        return decoded, mu, log_var

    def loss_function(self, recon_x: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, log_var: torch.Tensor, kl_weight:float=0.00025) -> torch.Tensor:
        """
        Calculates the VAE loss: the reconstruction loss plus the KL divergence scaled by kl_weight.
        :param recon_x: Reconstructed image tensor with shape [batch_size, 3, 128, 128]
        :param x: Original image tensor with shape [batch_size, 3, 128, 128]
        :param mu: Mean of the latent space with shape [batch_size, latent_space_size]
        :param log_var: Log variance of the latent space with shape [batch_size, latent_space_size]
        :param kl_weight: Weight for the KL divergence term
        :return: Total VAE loss value
        """
        # Reconstruction loss (MSE, summed over all pixels and the batch)
        recon_loss = F.mse_loss(recon_x, x, reduction='sum')

        # KL divergence between the learned distribution and the standard normal
        kl_div = -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp())

        # Total VAE loss is the sum of reconstruction loss and KL divergence
        return recon_loss + kl_weight * kl_div
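The KL term in `loss_function` is the closed form of $D_{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\frac{1}{2}\sum\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right)$, summed over latent dimensions and the batch. As a minimal sanity check that needs only PyTorch (random tensors with arbitrary shapes, not the lab's model), it can be compared against `torch.distributions.kl_divergence`:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
# Arbitrary illustrative shapes: 4 samples, 16 latent dimensions
kl_mu = torch.randn(4, 16)
kl_log_var = torch.randn(4, 16)

# Closed form used in loss_function
kl_closed_form = -0.5 * torch.sum(1 + kl_log_var - kl_mu**2 - kl_log_var.exp())

# Reference: elementwise KL(N(mu, sigma^2) || N(0, 1)) from torch.distributions, then summed
kl_std = torch.exp(0.5 * kl_log_var)
kl_reference = kl_divergence(
    Normal(kl_mu, kl_std),
    Normal(torch.zeros_like(kl_mu), torch.ones_like(kl_std)),
).sum()

print(kl_closed_form.item(), kl_reference.item())
```

The closed form is zero exactly when `mu = 0` and `log_var = 0`, which is why this term pulls every latent code towards the standard normal prior.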


In [None]:
# Instantiate the variational autoencoder
VAEModel = VariationalAutoencoder()
summary(VAEModel.to(device), input_data=[images.to(device)])
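The `reparameterize` method writes the sample as $z = \mu + \epsilon \odot \sigma$ with $\epsilon \sim \mathcal{N}(0, I)$ and $\sigma = \exp(\frac{1}{2}\log\sigma^2)$, so $z$ is a differentiable function of `mu` and `log_var` while all the randomness sits in $\epsilon$. A small standalone sketch on toy tensors (not the lab's model) makes this concrete: with $\mu = 0$ and $\log\sigma^2 = 0$, the loss $\sum z^2$ has gradient $2\epsilon$ with respect to $\mu$ and $\epsilon^2$ with respect to $\log\sigma^2$.

```python
import torch

torch.manual_seed(0)
# Hypothetical encoder outputs: 8 samples with a 4-dimensional latent, mu = 0 and log_var = 0
toy_mu = torch.zeros(8, 4, requires_grad=True)
toy_log_var = torch.zeros(8, 4, requires_grad=True)

# Same computation as VariationalAutoencoder.reparameterize
toy_std = torch.exp(0.5 * toy_log_var)
toy_eps = torch.randn_like(toy_std)  # the noise itself does not depend on mu or log_var
toy_z = toy_mu + toy_eps * toy_std

# Any scalar loss on z; backward() reaches mu and log_var through the sample
toy_loss = (toy_z ** 2).sum()
toy_loss.backward()

print(toy_mu.grad[0], toy_log_var.grad[0])
```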

# Utils

In [None]:
import math

class KLWeightScheduler:
    def __init__(self, kl_weight_max: float, total_epochs: int, low_epochs: int = 5, warmup_epochs: int = 10):
        """
        S-shaped (sigmoid) growth scheduler for the KL weight, with an initial period where the weight is zero.

        :param kl_weight_max: The KL weight approached at the end of warm-up.
        :param total_epochs: The total number of training epochs (stored, but not used by the schedule).
        :param low_epochs: The number of initial epochs during which the KL weight is zero.
        :param warmup_epochs: The number of epochs over which the KL weight rises towards kl_weight_max.
        """
        self.kl_weight_max = kl_weight_max
        self.total_epochs = total_epochs
        self.low_epochs = low_epochs
        self.warmup_epochs = warmup_epochs
        self.last_epoch = 0

    def get_kl_weight(self):
        """
        Calculate the KL weight using a modified sigmoid growth function.
        :return: The current KL weight value.
        """
        # Before the warm-up phase, keep the KL weight at zero
        if self.last_epoch < self.low_epochs:
            kl_weight = 0.0
        else:
            # Calculate the progress in the warm-up phase
            progress = (self.last_epoch - self.low_epochs) / self.warmup_epochs
            # Sigmoid function to generate S-shaped curve
            kl_weight = self.kl_weight_max / (1 + math.exp(-10 * (progress - 0.5)))

            # Ensure kl_weight does not exceed kl_weight_max
            kl_weight = min(kl_weight, self.kl_weight_max)

        return kl_weight

    def step(self):
        """
        Increment the epoch count for KL weight scheduling.
        """
        self.last_epoch += 1




def train_step(model: nn.Module, optimizer, train_loader: DataLoader, kl_weight:float=0.0):
    """
    Train the model for one epoch over the training dataset.
    For a VAE, the KL term is scaled by kl_weight, which the caller supplies (e.g. from KLWeightScheduler).
    :param model: The neural network model to be trained.
    :param optimizer: The optimizer used to update the model's parameters.
    :param train_loader: DataLoader for the training dataset, providing batches of data.
    :param kl_weight: Weight on the KL divergence term (ignored for the plain autoencoder).
    :return: The average loss over the training dataset.
    """
    model.train()
    total_loss = 0
    num_batches = len(train_loader)


    # Wrap the train_loader with tqdm to show the progress bar
    with tqdm(train_loader, unit="batch", desc=f"[Training]:") as pbar:
        for data, _ in pbar:
            optimizer.zero_grad()

            # Move data to the same device as the model
            data = data.to(device)

            # Forward pass
            if isinstance(model, VariationalAutoencoder):
                reconstructed, mu, log_var = model(data)
                loss = model.loss_function(reconstructed, data, mu, log_var, kl_weight=kl_weight)
            else:
                reconstructed = model(data)
                loss = model.loss_function(reconstructed, data)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            # Accumulate total loss
            total_loss += loss.item()

            # Show the current batch loss in the progress bar
            pbar.set_postfix(loss=loss.item())

            # Clean up
            del data, reconstructed, loss
            torch.cuda.empty_cache()

    # Calculate the average loss for the epoch
    avg_loss = total_loss / num_batches
    return avg_loss


def validate_step(model: nn.Module, val_loader: DataLoader, epoch: int, path:str):
  """
  Runs a validation step and plots a few images with their reconstructions.
  Saves the plot as an image file.

  :param model: Trained model.
  :param val_loader: DataLoader for the validation dataset.
  :param epoch: Current epoch number.
  :param path: Directory under which a validation_images folder is created for the plots.
  """
  root = os.path.join(path, 'validation_images')
  os.makedirs(root, exist_ok=True)

  model.eval()  # Set the model to evaluation mode
  with torch.no_grad():
      # Get a batch of validation data
      data, _ = next(iter(val_loader))
      data = data.to(device)  # Move data to the same device as the model

      # Forward pass to get reconstructions
      if isinstance(model, VariationalAutoencoder):
          reconstructed, _, _ = model(data)
      else:
          reconstructed = model(data)

      # Move data back to CPU for plotting
      data = data.cpu()
      reconstructed = reconstructed.cpu()

      # Plot original and reconstructed images
      fig, axes = plt.subplots(2, 5, figsize=(15, 6))
      for i in range(5):
          # Original images
          axes[0, i].imshow(data[i].permute(1, 2, 0).squeeze(), cmap='gray')
          axes[0, i].set_title("Original")
          axes[0, i].axis('off')

          # Reconstructed images
          axes[1, i].imshow(reconstructed[i].permute(1, 2, 0).squeeze(), cmap='gray')
          axes[1, i].set_title("Reconstructed")
          axes[1, i].axis('off')

      plt.tight_layout()
      impath = os.path.join(root, f'validation_epoch_{epoch}.png')
      plt.savefig(impath)
      plt.close(fig)
      print(f'Saved validation results for epoch {epoch} as validation_epoch_{epoch}.png')


def save_model(model, optimizer, scheduler, metric, epoch, path):
    """
    Saves the model, optimizer, and scheduler states along with a metric to a specified path.

    Args:
        model (nn.Module): Model to be saved.
        optimizer (Optimizer): Optimizer state to save.
        scheduler (Scheduler or None): Scheduler state to save.
        metric (tuple): Metric tuple (name, value) to be saved.
        epoch (int): Current epoch number.
        path (str): File path for saving.
    """
    # Ensure metric is provided as a tuple with correct structure
    if not (isinstance(metric, tuple) and len(metric) == 2):
        raise ValueError("metric must be a tuple in the form (name, value)")

    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else {},
            metric[0]: metric[1],  # Unpacks the metric name and value
            "epoch": epoch
        },
        path
    )


def load_model(model, optimizer, scheduler, path):
    """
    Loads the model, optimizer, and scheduler states along with a saved metric and epoch from a specified path.

    Args:
        model (nn.Module): Model instance to load the state into.
        optimizer (Optimizer or None): Optimizer instance to load the state into.
        scheduler (Scheduler or None): Scheduler instance to load the state into, if applicable.
        path (str): File path to load the checkpoint from.

    Returns:
        tuple: A tuple containing the metric (name, value) and the last saved epoch.
    """
    # Load the checkpoint from the specified path
    checkpoint = torch.load(path)

    # Load the model's state dictionary
    model.load_state_dict(checkpoint["model_state_dict"])

    # Load the optimizer's state dictionary
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # Load the scheduler's state dictionary, if provided
    if scheduler:
        scheduler_state = checkpoint.get("scheduler_state_dict", {})
        if scheduler_state:
            scheduler.load_state_dict(scheduler_state)

    # Retrieve the metric from the checkpoint
    # Identify the metric key (excluding reserved keys)
    metric_keys = [key for key in checkpoint.keys() if key not in {"model_state_dict", "optimizer_state_dict", "scheduler_state_dict", "epoch"}]
    if len(metric_keys) != 1:
        raise ValueError(f"Unexpected format: expected exactly one metric key in checkpoint, found {len(metric_keys)}.")
    metric_name = metric_keys[0]
    metric_value = checkpoint[metric_name]

    # Retrieve the last saved epoch
    epoch = checkpoint.get("epoch", 0)

    return (metric_name, metric_value), epoch
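To see what the S-shaped schedule produces, the sketch below restates `get_kl_weight` as a plain function (the name `sigmoid_kl_weight` is ours, for illustration only) and evaluates it with the settings used in the VAE training cell. The weight is exactly 0 for the first `low_epochs` epochs, reaches half of `kl_weight_max` midway through the warm-up, and approaches `kl_weight_max` after it.

```python
import math

def sigmoid_kl_weight(epoch: int, kl_weight_max: float, low_epochs: int, warmup_epochs: int) -> float:
    """Hypothetical standalone copy of KLWeightScheduler.get_kl_weight for a given epoch index."""
    if epoch < low_epochs:
        return 0.0  # KL term switched off during the initial period
    progress = (epoch - low_epochs) / warmup_epochs
    # Logistic curve centred at the middle of the warm-up window
    return min(kl_weight_max / (1 + math.exp(-10 * (progress - 0.5))), kl_weight_max)

# Same settings as the VAE training cell: kl_weight_max=2.0, low_epochs=10, warmup_epochs=10
schedule = [sigmoid_kl_weight(e, 2.0, 10, 10) for e in range(30)]
for e in (0, 9, 10, 15, 20, 29):
    print(e, round(schedule[e], 4))
```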



# Train Autoencoder
An `ae_checkpoints` directory is created to hold the model checkpoints; a checkpoint is saved whenever the training loss matches or beats the best loss so far. Every 10 epochs, 5 samples from the training loader are reconstructed and plotted next to the originals, and the plots are saved in `ae_checkpoints/validation_images`. Use them to monitor the performance of your autoencoder.

In [None]:
gc.collect()
torch.cuda.empty_cache()
checkpoint_root = os.path.join(os.getcwd(), 'ae_checkpoints')
os.makedirs(checkpoint_root, exist_ok=True)
checkpoint_filename = 'ae_model'

# Choose the model to train
model = AEModel

# Set the number of epochs, the optimizer, and the learning-rate scheduler
epochs = 100
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1E-8)
best_loss = float('inf')

for epoch in range(epochs):

    print("\nEpoch {}/{}".format(epoch+1, epochs))

    curr_lr = float(optimizer.param_groups[0]["lr"])

    train_loss = train_step(model, optimizer, mtcnn_preprocessed_loader)

    print("\nEpoch {}/{}: \nTrain Loss {:.04f}\t Learning Rate {:.06f}".format(
        epoch + 1, epochs, train_loss, curr_lr))

    scheduler.step()

    if (epoch+1) % 10 == 0:
        validate_step(model, mtcnn_preprocessed_loader, epoch+1, checkpoint_root)

    if best_loss >= train_loss:
        best_loss = train_loss
        epoch_model_path = os.path.join(checkpoint_root, (checkpoint_filename + str(epoch) + '.pth'))
        save_model(model, optimizer, scheduler, ('loss', train_loss), epoch, epoch_model_path)
        print("Saved best loss model")
#### ----------------------------------------------------------------------------------------------------------------------
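The training cells call `scheduler.step()` once per epoch, so with `T_max=epochs` the learning rate of `CosineAnnealingLR` follows $\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + \cos\frac{\pi t}{T_{max}}\right)$, starting at the base learning rate and decaying towards `eta_min`, which it reaches at $t = T_{max}$. The sketch below checks this with a dummy parameter in place of the model (the `demo_` names and the `cosine_lr` helper are illustrative, not part of the lab code):

```python
import math
import torch

# A single dummy parameter stands in for the model; names are prefixed so the training variables are untouched
demo_param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.Adam([demo_param], lr=1e-3)
demo_epochs = 100
demo_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(demo_optimizer, T_max=demo_epochs, eta_min=1e-8)

demo_lrs = []
for _ in range(demo_epochs):
    demo_lrs.append(demo_optimizer.param_groups[0]["lr"])  # LR used during this epoch
    demo_optimizer.step()  # no gradients, so Adam leaves the parameter unchanged
    demo_scheduler.step()  # advance the cosine schedule once per epoch, as in the training loops

def cosine_lr(t: int, eta_max: float = 1e-3, eta_min: float = 1e-8, t_max: int = demo_epochs) -> float:
    """Closed form of cosine annealing at epoch t (valid for 0 <= t <= t_max)."""
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / t_max))

print(demo_lrs[0], demo_lrs[50], demo_lrs[-1])
```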


# Train VAE
A `vae_checkpoints` directory is created to hold the model checkpoints; a checkpoint is saved whenever the training loss matches or beats the best loss so far. Every 10 epochs, 5 samples from the training loader are reconstructed and plotted next to the originals, and the plots are saved in `vae_checkpoints/validation_images`. Use them to monitor the performance of your VAE.

In [None]:
gc.collect()
torch.cuda.empty_cache()
checkpoint_root = os.path.join(os.getcwd(), 'vae_checkpoints')
os.makedirs(checkpoint_root, exist_ok=True)
checkpoint_filename = 'vae_model'

# Choose the model to train
model = VAEModel

# Set the number of epochs, the optimizer, and the learning-rate scheduler
epochs = 120
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1E-8)

# Define the KL weight scheduler: zero for the first 10 epochs, then a sigmoid (S-shaped) warm-up over 10 epochs
kl_weight_max = 2.0  # Maximum KL weight
kl_scheduler = KLWeightScheduler(kl_weight_max, epochs, low_epochs=10, warmup_epochs=10)

best_loss = float('inf')

for epoch in range(epochs):

    print("\nEpoch {}/{}".format(epoch+1, epochs))

    curr_lr = float(optimizer.param_groups[0]["lr"])
    kl_weight = kl_scheduler.get_kl_weight()

    train_loss = train_step(model, optimizer, mtcnn_preprocessed_loader, kl_weight)

    print("\nEpoch {}/{}: \nTrain Loss {:.04f}\t Learning Rate {:.06f}\t KL Weight {:.06f}".format(
        epoch + 1, epochs, train_loss, curr_lr, kl_weight))

    scheduler.step()
    kl_scheduler.step()

    if (epoch+1) % 10 == 0:
        validate_step(model, mtcnn_preprocessed_loader, epoch+1, checkpoint_root)

    if best_loss >= train_loss:
        best_loss = train_loss
        epoch_model_path = os.path.join(checkpoint_root, (checkpoint_filename + str(epoch) + '.pth'))
        save_model(model, optimizer, scheduler, ('loss', train_loss), epoch, epoch_model_path)
        print("Saved best loss model")
#### ----------------------------------------------------------------------------------------------------------------------


# Experiments

## Load checkpoints

In [None]:
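# Update these paths to the checkpoints you want to load; the training cells above save files named ae_model{epoch}.pth and vae_model{epoch}.pth
AEModel.load_state_dict(torch.load('/content/ae_checkpoints/final_ae_model99.pth')['model_state_dict'])
VAEModel.load_state_dict(torch.load('/content/vae_checkpoints/final_vae_model149.pth')['model_state_dict'])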

# Image Reconstruction Comparison

In [None]:
AEModel.eval()
VAEModel.eval()
with torch.no_grad():
  # Get a batch of validation data
  data, _ = next(iter(mtcnn_preprocessed_loader))
  data = data.to(device)  # Move data to the same device as the models
  reconstructed_ae        = AEModel(data)
  reconstructed_vae, _, _ = VAEModel(data)

  # Move data back to CPU for plotting
  data = data.cpu()
  reconstructed_ae = reconstructed_ae.cpu()
  reconstructed_vae = reconstructed_vae.cpu()

  # Plot original and reconstructed images
  fig, axes = plt.subplots(3, 5, figsize=(15, 6))
  for i in range(5):
      # Original images
      axes[0, i].imshow(data[i].permute(1, 2, 0).squeeze(), cmap='gray')
      axes[0, i].set_title("Original")
      axes[0, i].axis('off')

      # AE Reconstructed images
      axes[1, i].imshow(reconstructed_ae[i].permute(1, 2, 0).squeeze(), cmap='gray')
      axes[1, i].set_title("Reconstructed AE")
      axes[1, i].axis('off')

      # VAE Reconstructed images
      axes[2, i].imshow(reconstructed_vae[i].permute(1, 2, 0).squeeze(), cmap='gray')
      axes[2, i].set_title("Reconstructed VAE")
      axes[2, i].axis('off')

  plt.tight_layout()
  plt.show()

# Sampling w/ random noise

In [None]:

def sample_images(model: nn.Module, num_samples: int = 10):
    """
    Decodes latent vectors drawn from N(0, I) and plots the resulting images.
    For the VAE, N(0, I) is the prior that its KL term pulls the latent codes towards.
    The plain AE has no such prior, so decoding random noise is not guaranteed to produce realistic images.
    """
    model.eval()
    with torch.no_grad():
        # Draw latent vectors z ~ N(0, I) with latent_space_size dimensions
        z = torch.randn(num_samples, model.latent_space_size, device=device)
        # Both models expose decode(z); only the VAE was trained so that such samples decode well
        sampled_images = model.decode(z).cpu()

    # Plot the samples in a single row
    plt.figure(figsize=(2 * num_samples, 2))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(sampled_images[i].permute(1, 2, 0).numpy())
        plt.axis('off')
    plt.show()



In [None]:
sample_images(AEModel, num_samples=5)

In [None]:
sample_images(VAEModel, num_samples=5)

# Latent Space Exploration

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def visualize_latent_space(autoencoder, vae, dataloader, device='cuda'):
    """
    Visualize the latent space learned by Autoencoder and Variational Autoencoder using PCA in 2D.

    Parameters:
    - autoencoder: The trained Autoencoder model.
    - vae: The trained Variational Autoencoder model.
    - dataloader: DataLoader yielding (image, label) batches.
    - device: Device to run the models on, either 'cuda' or 'cpu'.
    """
    autoencoder.eval()
    vae.eval()

    # Collect the latent representations and labels
    ae_latents = []
    vae_latents = []
    labels = []

    with torch.no_grad():
        for i, (images, lbls) in enumerate(dataloader):
            images = images.to(device)

            # Get latent representations from AE and VAE
            ae_latent = autoencoder.encode(images)
            mu, log_var = vae.encode(images)
            vae_latent = vae.reparameterize(mu, log_var)

            ae_latents.append(ae_latent.view(ae_latent.size(0), -1))
            vae_latents.append(vae_latent.view(vae_latent.size(0), -1))
            labels.extend(lbls.cpu().numpy())

            if i > 1:  # Only collect the first three batches (i = 0, 1, 2) for visualization
                break

    # Concatenate the collected batches and convert labels to a numpy array
    ae_latents = torch.cat(ae_latents).cpu().numpy()
    vae_latents = torch.cat(vae_latents).cpu().numpy()
    labels = np.array(labels)

    # Apply PCA for dimensionality reduction to 2D
    pca = PCA(n_components=2)
    ae_latents_2d = pca.fit_transform(ae_latents)
    vae_latents_2d = pca.fit_transform(vae_latents)

    # Plot the 2D scatter plots
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Plot for Autoencoder
    scatter_ae = axes[0].scatter(ae_latents_2d[:, 0], ae_latents_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
    axes[0].set_title('Autoencoder Latent Space')
    axes[0].set_xlabel('Component 1')
    axes[0].set_ylabel('Component 2')
    legend1 = axes[0].legend(*scatter_ae.legend_elements(), title="Labels")
    axes[0].add_artist(legend1)

    # Plot for Variational Autoencoder
    scatter_vae = axes[1].scatter(vae_latents_2d[:, 0], vae_latents_2d[:, 1], c=labels, cmap='tab10', alpha=0.7)
    axes[1].set_title('Variational Autoencoder Latent Space')
    axes[1].set_xlabel('Component 1')
    axes[1].set_ylabel('Component 2')
    legend2 = axes[1].legend(*scatter_vae.legend_elements(), title="Labels")
    axes[1].add_artist(legend2)

    plt.tight_layout()
    plt.show()


In [None]:
visualize_latent_space(AEModel, VAEModel, mtcnn_preprocessed_loader)

# Latent Space Interpolation

In [None]:
def interpolate_latent_space(model: nn.Module, start_image, end_image, num_steps=32, device='cuda'):
    """
    Interpolates between two images in the latent space of the model and plots the result.

    Parameters:
    - model: The trained Autoencoder or Variational Autoencoder model.
    - start_image: The starting image for interpolation (should be in the range [0, 1] and shape (C, H, W)).
    - end_image: The ending image for interpolation (same shape as start_image).
    - num_steps: The number of interpolation steps.
    - device: Device to run the model on, either 'cuda' or 'cpu'.
    """
    model.eval()
    start_image = start_image.unsqueeze(0).to(device)
    end_image = end_image.unsqueeze(0).to(device)

    # Encode both images into their latent representations
    with torch.no_grad():
        if isinstance(model, VariationalAutoencoder):
            # For VAE, get mu and log_var, and reparameterize z
            mu_start, log_var_start = model.encode(start_image)
            mu_end, log_var_end = model.encode(end_image)
            # Reparameterize to get z (latent vectors)
            start_latent = model.reparameterize(mu_start, log_var_start)
            end_latent = model.reparameterize(mu_end, log_var_end)
        else:
            # For AE, directly use the encoded latent vector
            start_latent = model.encode(start_image)
            end_latent = model.encode(end_image)

        # Flatten each latent to shape [1, latent_size]
        start_latent = start_latent.view(start_latent.size(0), -1)
        end_latent = end_latent.view(end_latent.size(0), -1)

    # Create linear interpolation between the two latent representations
    latents = []
    for alpha in np.linspace(0, 1, num_steps):
        interpolated_latent = (1 - alpha) * start_latent + alpha * end_latent
        latents.append(interpolated_latent)

    # Concatenate the [1, latent_size] latents into a [num_steps, latent_size] batch
    latents = torch.cat(latents, dim=0).to(device)

    # Decode the interpolated latents (the same call works for AE and VAE)
    with torch.no_grad():
        interpolated_images = model.decode(latents)

    # Plot the results on a grid with 8 columns and as many rows as needed
    interpolated_images = interpolated_images.cpu()
    n_cols = 8
    n_rows = (num_steps + n_cols - 1) // n_cols
    plt.figure(figsize=(n_cols, n_rows))
    for i, img in enumerate(interpolated_images):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(img.permute(1, 2, 0).numpy())
        plt.axis('off')
    plt.tight_layout()
    plt.show()
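The interpolation loop above builds one latent per step and then combines them into a batch. An equivalent vectorized sketch (random vectors stand in for the encoder outputs; the sizes and `interp_` names are illustrative) broadcasts a column of interpolation weights against the `[1, latent_size]` start and end vectors, producing a `[num_steps, latent_size]` batch in one expression:

```python
import torch

torch.manual_seed(0)
# Illustrative sizes; random vectors stand in for the encoded start and end images
interp_latent_size, interp_steps = 16, 8
interp_start = torch.randn(1, interp_latent_size)
interp_end = torch.randn(1, interp_latent_size)

# One weight per step, shaped [steps, 1] so it broadcasts across the latent dimension
interp_alphas = torch.linspace(0, 1, interp_steps).unsqueeze(1)
interp_latents = (1 - interp_alphas) * interp_start + interp_alphas * interp_end  # [steps, latent_size]

print(interp_latents.shape)  # torch.Size([8, 16])
```

Passing such a batch to `model.decode` would replace the Python loop; the first and last rows equal the start and end latents, so the decoded sequence begins and ends at the two encoded images.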


In [None]:
random_index1 = torch.randint(0, images.size(0), (1,)).item()
random_index2 = torch.randint(0, images.size(0), (1,)).item()
image1 = images[random_index1]
image2 = images[random_index2]


In [None]:
interpolate_latent_space(VAEModel, image1, image2, num_steps=64, device='cuda')

In [None]:
interpolate_latent_space(AEModel, image1, image2, num_steps=64, device='cuda')