AlexNet trained on TinyImageNet (64x64)

Implementation based on the original paper with minimal differences. Whenever I take a different approach than the paper, I point it out.

Project Setup

In [1]:
import torch
import torchvision
import scorch
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import os
import wandb
In [2]:
config = {
    'batch_size': 64,
    'epochs': 50,
    'lr': 0.0025
}
In [3]:
wandb.login()
wandb.init(project='alexnet-tiny', config=config)

Data Loading

In [4]:
NUM_CLASSES = 200
DATASET_ROOT = os.path.join("datasets", "tiny-imagenet-200")
In [5]:
wnid_to_index = {}
wnid_to_name = {}
index_to_name = {}

identifiers_in_order = os.listdir(os.path.join(DATASET_ROOT, "train"))
identifiers_in_order = sorted(identifiers_in_order)
identifiers_in_order = [ident for ident in identifiers_in_order if ident.startswith('n')]

with open(os.path.join(DATASET_ROOT, "words.txt"), 'r') as f:
    line = True
    while line:
        line = f.readline().strip()
        if not line:
            break
        id_and_names = line.split("\t")
        identifier = id_and_names[0]
        name = id_and_names[1].split(',')[0]

        wnid_to_name[identifier] = name

for index, wnid in enumerate(identifiers_in_order):
    index_to_name[index] = wnid_to_name[wnid]
    wnid_to_index[wnid] = index
In [6]:
x_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomCrop((56, 56)),
    torchvision.transforms.RandomHorizontalFlip(),
])

y_transform = torchvision.transforms.Compose([
    lambda label: torch.tensor(label),
    lambda label: torch.nn.functional.one_hot(label, NUM_CLASSES).float()
])
In [7]:
train_dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_ROOT, "train"), transform=x_transform, target_transform=y_transform)
valid_dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_ROOT, "val"), transform=torchvision.transforms.ToTensor(), target_transform=y_transform)
In [8]:
train_dataloader = scorch.datasets.DataLoader(torch.utils.data.DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=11, prefetch_factor=2, pin_memory=True, persistent_workers=True), cuda_prefetch=True)
valid_dataloader = scorch.datasets.DataLoader(torch.utils.data.DataLoader(valid_dataset, batch_size=config['batch_size'], num_workers=2, prefetch_factor=4, pin_memory=True, persistent_workers=True), cuda_prefetch=True)

Architecture

The original AlexNet operates on 256x256 images. This network operates on 64x64, so whenever possible I try to make the network 4 times smaller, by e.g. using 4 times smaller convolution kernels. This makes trainig also quite a lot faster.

In [9]:
class AlexNet(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.k = 2
        self.n = 5
        self.alpha = 1e-4
        self.beta = 0.75

        self.relu = scorch.nn.ReLU()
        self.overlapping_poll = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv1 = scorch.nn.Conv2D(in_channels=3, out_channels=96, kernel_size=3, padding=1) # kernel 11 -> 3
        self.conv2_top = scorch.nn.Conv2D(in_channels=48, out_channels=128, kernel_size=3, padding=1) # kernel 5 -> 3
        self.conv2_bottom = scorch.nn.Conv2D(in_channels=48, out_channels=128, kernel_size=3, padding=1) # kernel 5 -> 3
        self.conv3 = scorch.nn.Conv2D(in_channels=256, out_channels=384, kernel_size=1) # kernel 3 -> 1
        self.conv4_top = scorch.nn.Conv2D(in_channels=192, out_channels=192, kernel_size=1) # kernel 3 -> 1
        self.conv4_bottom = scorch.nn.Conv2D(in_channels=192, out_channels=192, kernel_size=1) # kernel 3 -> 1
        self.conv5_top = scorch.nn.Conv2D(in_channels=192, out_channels=128, kernel_size=1) # kernel 3 -> 1
        self.conv5_bottom = scorch.nn.Conv2D(in_channels=192, out_channels=128, kernel_size=1) # kernel 3 -> 1
        self.local_response_normalization = torch.nn.LocalResponseNorm(size=self.n, alpha=self.alpha, beta=self.beta, k=self.k)

        self.dropout = scorch.nn.Dropout(p=0.5)
        self.fc6 = scorch.nn.Linear(in_features=128 * 2 * 6 * 6, out_features=1024)
        self.fc7 = scorch.nn.Linear(in_features=1024, out_features=1024)
        self.fc8 = scorch.nn.Linear(in_features=1024, out_features=NUM_CLASSES)

    def forward_patch(self, patch):
        # Layer 1: 3 -> 96 channels, with LRN and polling
        conv1 = self.relu(self.conv1(patch))
        top, bottom = torch.split(conv1, conv1.size(1) // 2, dim=1)
        top, bottom = self.local_response_normalization(top), self.local_response_normalization(bottom)
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)
        
        # Layer 2: 96 -> 256 channels, with LRN and polling
        top, bottom = self.relu(self.conv2_top(top)), self.relu(self.conv2_bottom(bottom))
        top, bottom = self.local_response_normalization(top), self.local_response_normalization(bottom)
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)

        # Layer 3: 256 -> 384 channels, only ReLU, all together
        stacked = torch.cat((top, bottom), dim=1)
        conv3 = self.relu(self.conv3(stacked))

        # Layer 4: 384 -> 384, only ReLU, per GPU
        top, bottom = torch.split(conv3, conv3.size(1) // 2, dim=1)
        top, bottom = self.relu(self.conv4_top(top)), self.relu(self.conv4_bottom(bottom))

        # Layer 5: 384 -> 256, with polling, per GPU
        top, bottom = self.relu(self.conv5_top(top)), self.relu(self.conv5_bottom(bottom))
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)

        # Layer 6: [BATCH, 256, 7, 7] -> [BATCH, 4096], with relu and dropout
        stacked = torch.cat((top, bottom), dim=1) # [BATCH, 256, 7, 7]
        stacked = torch.flatten(stacked, start_dim=1)
        stacked = self.dropout(self.relu(self.fc6(stacked)))

        # Layer 7: 4096 -> 4096, with relu and dropout
        stacked = self.dropout(self.relu(self.fc7(stacked)))

        # Layer 8: 4096 -> 200, with no activations so it outputs logits directly
        out = self.fc8(stacked)

        return out

    def forward(self, x):
        if self.training:
            return self.forward_patch(x)
        else:
            top_left_result = self.forward_patch(x[:,:,:-8,:-8])
            top_right_result = self.forward_patch(x[:,:,8:,:-8])
            bottom_left_result = self.forward_patch(x[:,:,:-8,:-8])
            bottom_right_result = self.forward_patch(x[:,:,8:,8:])
            center_result = self.forward_patch(x[:,:,4:-4,4:-4])

            results = torch.stack((
                top_left_result, top_right_result, bottom_left_result, bottom_right_result, center_result
            ), dim=1)
            mean_result = torch.mean(results, dim=1)

            return mean_result

Training

In [10]:
alexnet = AlexNet()
In [11]:
loss_logger = scorch.execution.LossLogger(
    epochs=config['epochs'],
    train_batches=len(train_dataloader),
    valid_batches=len(valid_dataloader),
    log_every_batch=10,
    log_wandb=True,
    log_console=False
)

classifier_logger = scorch.execution.ClassifierLogger(
    classes=NUM_CLASSES,
    index_to_name=index_to_name,
    epochs=config['epochs'],
    log_detailed_on_last_epoch=True,
    log_wandb=True,
    log_console=True,
    only_valid=False,
    skip_batches=10,
)

profile_logger = scorch.execution.ProfileLogger(
    log_every_batch=50,
    log_console=False,
    log_wandb=True
)

lr_logger = scorch.execution.LearningRateLogger(
    log_console=False,
    log_wandb=True
)

optimizer = torch.optim.Adam(alexnet.parameters())
In [12]:
PEAK_TRAINING_EPOCH = 5

runner = scorch.execution.Runner(
    model=alexnet,
    training_loader=train_dataloader,
    optimizer=optimizer,
    loss=torch.nn.BCEWithLogitsLoss(),
    epochs=config['epochs'],
    validation_loader=valid_dataloader,
    lr_scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['lr'], total_steps=config['epochs'], pct_start=PEAK_TRAINING_EPOCH/config['epochs']),
    loggers=(loss_logger, classifier_logger, lr_logger),
    profiling_loggers=(profile_logger,)
)
Using device: cuda
In [13]:
runner.run(train=True, validate=True)
Epoch 1/50 Error Rates
	Top-1 Train: 0.9960191082802548
	Top-3 Train: 0.987062101910828
	Top-5 Train: 0.9768113057324841
	Top-10 Train: 0.9535230891719745
	Top-1 Validation: 0.9949919871794872
	Top-3 Validation: 0.9849759615384616
	Top-5 Validation: 0.9749599358974359
	Top-10 Validation: 0.9499198717948718
Epoch 2/50 Error Rates
	Top-1 Train: 0.9931891025641025
	Top-3 Train: 0.9829727564102564
	Top-5 Train: 0.9694511217948718
	Top-10 Train: 0.9395032051282052
	Top-1 Validation: 0.9884815705128205
	Top-3 Validation: 0.9660456730769231
	Top-5 Validation: 0.9441105769230769
	Top-10 Validation: 0.8921274038461539

[...]

Epoch 49/50 Error Rates
	Top-1 Train: 0.7338741987179487
	Top-3 Train: 0.5584935897435898
	Top-5 Train: 0.4661458333333333
	Top-10 Train: 0.3414463141025641
	Top-1 Validation: 0.7038261217948718
	Top-3 Validation: 0.5205328525641025
	Top-5 Validation: 0.4276842948717949
	Top-10 Validation: 0.30859375
Epoch 50/50 Error Rates
	Top-1 Train: 0.7373798076923077
	Top-3 Train: 0.5614983974358975
	Top-5 Train: 0.46504407051282054
	Top-10 Train: 0.3414463141025641
	Top-1 Validation: 0.7037259615384616
	Top-3 Validation: 0.5206330128205128
	Top-5 Validation: 0.42758413461538464
	Top-10 Validation: 0.30859375

Quantitative Results

In [14]:
loss_logger.plot()
No description has been provided for this image
In [15]:
classifier_logger.plot_error_rates()
No description has been provided for this image

Qualitative Results

In [16]:
import random

predictions_by_confidence = sorted(classifier_logger.valid['predictions'][-1], key=lambda p: -p.confidence)
wrong_predictions = list(filter(lambda p: not p.correct, predictions_by_confidence))
correct_predictions = list(filter(lambda p: p.correct, predictions_by_confidence))
random_predictions = random.sample(predictions_by_confidence, k=200)
In [17]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[::-1][:5], title='Least Confident, Wrong Predictions', top_k=10, rows=1)
No description has been provided for this image
In [18]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[:5], title='Most Confident, Wrong Predictions', top_k=10)
No description has been provided for this image
In [19]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[:5], title='Most Confident, Correct Predictions', top_k=10)
No description has been provided for this image
In [20]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[::-1][:5], title='Least Confident, Correct Predictions', top_k=10)
No description has been provided for this image
In [21]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(random_predictions[:15], title='Randomly Sampled Predictions', rows=3, scale=1, top_k=10)
No description has been provided for this image
In [22]:
wandb.finish()


Run history:


batch▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▁▆▇▇▇▇▇▇▇▇█████
epoch▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▇▇▇█████
lr/group-0▁▁▂▅▇██████▇▇▇▇▆▆▆▆▆▅▅▅▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁
train/loop-time▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss_batch██▇▅▅▄▄▃▃▄▃▄▃▄▃▃▃▃▂▄▂▃▂▃▂▃▂▁▂▁▂▂▂▁▁▃▁▂▂▁
train/loss_epoch█▄▄▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/throughput██████▁█████████████████████████████████
train/timing-backward▂▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▁▁▂▂▂▁▂▅▂▂█
train/timing-batch_load▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▄▁▁█▁▁▁▁▁▁▁▁▁
train/timing-data_move▅▁▃▂▂▄▇▁▄▅▄▁▆▂▄▄▂▁▇▆▅▇▅▅▆▂▅▅▆▃▁▂▆█▂▂▂▄▁▂
+15...


Run summary:


batch7839
epoch49
lr/group-00.0
train/loop-time0.03445
train/loss_batch0.01963
train/loss_epoch0.02075
train/throughput1857.96404
train/timing-backward0.01396
train/timing-batch_load0.00021
train/timing-data_move1e-05
+15...