Variational Auto-Encoder on MNIST¶

In [1]:
import scorch
import torch
import torchvision
import wandb

Data¶

In [3]:
dataset_path = "./datasets/mnist"

train_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=torchvision.transforms.ToTensor(), train=True, download=True)
test_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=torchvision.transforms.ToTensor(), train=False, download=True)
train_dataset = scorch.datasets.AutoEncoderDataset(train_dataset_base)
test_dataset = scorch.datasets.AutoEncoderDataset(test_dataset_base)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16)
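AutoEncoderDataset comes from scorch and its source isn't shown here; presumably it just pairs each image with itself, so the loaders yield (x, x) instead of (x, label). A minimal sketch under that assumption (hypothetical, not the actual scorch implementation):

class AutoEncoderDatasetSketch(torch.utils.data.Dataset):
    # Hypothetical sketch: wrap a labeled dataset so the target equals the input.
    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        x, _ = self.base_dataset[idx]  # drop the class label
        return x, x                    # reconstruction target == input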

Architecture¶

In [4]:
LATENT_SIZE = 32
In [5]:
class VAE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            scorch.nn.Flatten(),
            scorch.nn.Linear(28 * 28, 512),
            scorch.nn.ReLU(),
            scorch.nn.Linear(512, 128),
            scorch.nn.ReLU(),
            scorch.nn.Linear(128, LATENT_SIZE * 2))

        self.decoder = torch.nn.Sequential(
            scorch.nn.Linear(LATENT_SIZE, 128),
            scorch.nn.ReLU(),
            scorch.nn.Linear(128, 512),
            scorch.nn.ReLU(),
            scorch.nn.Linear(512, 28*28),
            scorch.nn.Sigmoid(),
            scorch.nn.Unflatten(-1, (1, 28, 28))
        )

    def encode(self, x):
        latent = self.encoder(x)
        mu, logvar = torch.chunk(latent, 2, -1)
        return mu, logvar
        
    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        sampled = mu + std * torch.randn_like(mu)
        return sampled
 
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.sample(mu, logvar)
        out = self.decoder(z)
        return out, mu, logvar

vae = VAE()
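The encoder produces `2 * LATENT_SIZE` numbers per image, which `encode` splits into the mean $\mu$ and log-variance $\log\sigma^2$ of a diagonal Gaussian. `sample` implements the reparameterization trick: it draws $\varepsilon \sim \mathcal{N}(0, I)$ and returns

$$z = \mu + \sigma \odot \varepsilon, \qquad \sigma = \exp\!\big(\tfrac{1}{2}\log\sigma^2\big),$$

which keeps the sampling step differentiable with respect to $\mu$ and $\log\sigma^2$.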

Training¶

In [6]:
def VAELoss(pred, target):
    out, mu, logvar = pred
    kld = scorch.utils.math.kl_normal(mu, logvar) / torch.numel(out)
    l2 = scorch.nn.losses.MSE(out, target)
    return l2 + kld
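Assuming `scorch.utils.math.kl_normal` returns the usual closed-form KL divergence from the approximate posterior to the unit Gaussian prior (inferred from its name, not its source),

$$D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\middle\|\, \mathcal{N}(0, I)\right) = -\tfrac{1}{2}\sum_{d}\left(1 + \log\sigma_d^2 - \mu_d^2 - \sigma_d^2\right),$$

then dividing by `torch.numel(out)` puts the KL term on the same per-pixel scale as the mean-squared reconstruction error, which acts as a small implicit weight on the KL term.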
In [7]:
loss_logger = scorch.execution.LossLogger(
    epochs=50,
    train_batches=len(train_loader),
    valid_batches=len(test_loader),
    log_every_batch=10,
    log_console=False,
    log_wandb=True
)

profiling_logger = scorch.execution.ProfileLogger(
    log_every_batch=50,
    log_console=False,
    log_wandb=True
)
In [8]:
optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
trainer = scorch.execution.Runner(
    vae,
    train_loader,
    optimizer,
    VAELoss,
    epochs=50,
    validation_loader=test_loader,
    loggers=(loss_logger,),
    profiling_loggers=(profiling_logger,)
)
Using device: cpu
In [9]:
stats = trainer.run(train=True, validate=True)
In [10]:
loss_logger.plot()
[Image: training and validation loss curves from loss_logger.plot()]
In [ ]:
xs, _ = next(iter(test_loader))

vae.eval()
with torch.no_grad():
    pred, _, _ = vae(xs)

scorch.utils.plot_images(xs, pred.view(-1, 28, 28), title='Reconstructions')
[Image: test digits alongside their VAE reconstructions]
In [12]:
import torch
import matplotlib.pyplot as plt

def bilinear_interpolate(latents, grid_size):
    # Bilinearly interpolate between four corner latents so that TL, TR, BL, BR
    # land at the top-left, top-right, bottom-left and bottom-right of the grid.
    TL, TR, BL, BR = latents
    grid = []

    for i in torch.linspace(0, 1, grid_size):       # vertical position (top -> bottom)
        left = TL * (1 - i) + BL * i                # interpolate down the left edge
        right = TR * (1 - i) + BR * i               # interpolate down the right edge
        row = [left * (1 - j) + right * j           # mix left -> right across the row
               for j in torch.linspace(0, 1, grid_size)]
        grid.append(torch.stack(row))

    return torch.stack(grid).view(-1, TL.shape[-1])

def show_latent_grid(vae, dataloader, grid_size=10):
    vae.eval()

    images, _ = next(iter(dataloader))
    x1, x2, x3, x4 = images[7], images[1], images[2], images[3]

    with torch.no_grad():
        z1, _ = vae.encode(x1.unsqueeze(0))  
        z2, _ = vae.encode(x2.unsqueeze(0))  
        z3, _ = vae.encode(x3.unsqueeze(0))  
        z4, _ = vae.encode(x4.unsqueeze(0))  

        interpolated_z = bilinear_interpolate([z1, z2, z3, z4], grid_size)
        decoded = vae.decoder(interpolated_z).cpu()

    fig, axes = plt.subplots(grid_size, grid_size, figsize=(grid_size, grid_size))
    for i in range(grid_size):
        for j in range(grid_size):
            idx = i * grid_size + j
            axes[i, j].imshow(decoded[idx][0], cmap='gray')
            axes[i, j].axis('off')
    plt.tight_layout()
    plt.show()

show_latent_grid(vae, test_loader, grid_size=20)
[Image: 20×20 grid of decoded latents bilinearly interpolated between four encoded test digits]
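The grid above only decodes latents interpolated between encoded test digits; the same decoder can also be fed samples drawn straight from the prior $z \sim \mathcal{N}(0, I)$ to generate unconditioned digits. A sketch, not run here:

vae.eval()
with torch.no_grad():
    z = torch.randn(16, LATENT_SIZE)   # samples from the unit Gaussian prior
    generated = vae.decoder(z)         # shape (16, 1, 28, 28)

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for ax, img in zip(axes.flat, generated):
    ax.imshow(img[0], cmap='gray')
    ax.axis('off')
plt.show()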
In [13]:
import numpy as np

labeled_test_dataloader = torch.utils.data.DataLoader(test_dataset_base, batch_size=1)

all_labels = []
all_points = np.zeros((len(test_dataset_base), LATENT_SIZE), dtype=np.float32)

vae.eval()
with torch.no_grad():
    for i, (image, label) in enumerate(labeled_test_dataloader):
        mu, logvar = vae.encode(image)
        all_points[i] = mu[0].numpy()
        all_labels.append(str(int(label[0])))
In [14]:
import sklearn.decomposition
import sklearn.manifold  # used for t-SNE below

all_points_compressed = sklearn.decomposition.PCA(n_components=2).fit_transform(all_points)

pts_per_digit = {str(i): {'x': [], 'y': []} for i in range(10)}

for pt, l in zip(all_points_compressed, all_labels):
    pts_per_digit[l]['x'].append(pt[0])
    pts_per_digit[l]['y'].append(pt[1])

plt.grid()
for digit, points in pts_per_digit.items():
    plt.scatter(points['x'], points['y'], label=digit, s=1.0, alpha=0.5)
plt.suptitle('PCA')
plt.legend()
Out[14]:
<matplotlib.legend.Legend at 0x32a5d95d0>
[Image: 2-D PCA projection of the latent means, colored by digit]
In [15]:
all_points_compressed_tsne = sklearn.manifold.TSNE(n_components=2).fit_transform(all_points)
In [16]:
pts_per_digit_tsne = {str(i): {'x': [], 'y': []} for i in range(10)}

for pt, l in zip(all_points_compressed_tsne, all_labels):
    pts_per_digit_tsne[l]['x'].append(pt[0])
    pts_per_digit_tsne[l]['y'].append(pt[1])

plt.grid()
for digit, points in pts_per_digit_tsne.items():
    plt.scatter(points['x'], points['y'], label=digit, s=2.0, alpha=0.5)
plt.suptitle('t-SNE')
plt.legend()
Out[16]:
<matplotlib.legend.Legend at 0x3345205d0>
[Image: 2-D t-SNE embedding of the latent means, colored by digit]