AlexNet trained on TinyImageNet (64x64)¶
Implementation based on the original paper with minimal differences. Whenever I take a different approach than the paper, I point it out.
Project Setup¶
In [1]:
import torch
import torchvision
import scorch
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import os
import wandb
In [2]:
# Run hyperparameters, logged to wandb so every run is reproducible.
config = dict(
    batch_size=64,
    epochs=50,
    lr=0.0025,
)
In [3]:
# Authenticate with Weights & Biases and open a run under the
# 'alexnet-tiny' project; `config` is stored with the run for comparison.
wandb.login()
wandb.init(project='alexnet-tiny', config=config)
Data Loading¶
In [4]:
# TinyImageNet has 200 classes.
NUM_CLASSES = 200
# Root folder of the extracted tiny-imagenet-200 archive.
DATASET_ROOT = os.path.join("datasets", "tiny-imagenet-200")
In [5]:
# Build the three label-mapping dicts:
#   wnid_to_index: WordNet id -> class index
#   wnid_to_name:  WordNet id -> human-readable name (first synonym only)
#   index_to_name: class index -> human-readable name
wnid_to_index = {}
wnid_to_name = {}
index_to_name = {}
# Sort class directories alphabetically so the indices match the ordering
# torchvision.datasets.ImageFolder assigns (it also sorts class folders).
identifiers_in_order = sorted(
    ident
    for ident in os.listdir(os.path.join(DATASET_ROOT, "train"))
    if ident.startswith('n')  # keep only WordNet-id folders (e.g. 'n01443537')
)
with open(os.path.join(DATASET_ROOT, "words.txt"), 'r') as f:
    # words.txt is tab-separated: "<wnid>\t<name1>,<name2>,...".
    # Iterate the file directly instead of a manual readline() loop; skip
    # blank lines rather than stopping at the first one, so a stray empty
    # line mid-file can no longer silently truncate the mapping.
    for line in f:
        line = line.strip()
        if not line:
            continue
        id_and_names = line.split("\t")
        identifier = id_and_names[0]
        name = id_and_names[1].split(',')[0]  # keep only the first synonym
        wnid_to_name[identifier] = name
for index, wnid in enumerate(identifiers_in_order):
    index_to_name[index] = wnid_to_name[wnid]
    wnid_to_index[wnid] = index
In [6]:
# Input pipeline for training images: convert to a tensor, then apply the
# paper's augmentations scaled to 64x64 inputs — a random 56x56 crop and
# random horizontal mirroring.
x_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomCrop((56, 56)),
    torchvision.transforms.RandomHorizontalFlip(),
])


def _label_to_one_hot(label):
    # Integer class index -> one-hot float vector of length NUM_CLASSES.
    return torch.nn.functional.one_hot(torch.tensor(label), NUM_CLASSES).float()


# Target pipeline: one-hot float targets, as required by BCEWithLogitsLoss.
y_transform = torchvision.transforms.Compose([_label_to_one_hot])
In [7]:
# Training set gets the augmenting transform (random 56x56 crop + flip);
# validation images stay at the full 64x64 — the model's eval path crops
# multiple 56x56 patches itself. Both use one-hot targets.
train_dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_ROOT, "train"), transform=x_transform, target_transform=y_transform)
valid_dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_ROOT, "val"), transform=torchvision.transforms.ToTensor(), target_transform=y_transform)
In [8]:
# Wrap torch DataLoaders in scorch's loader with cuda_prefetch so batches are
# staged on the GPU ahead of time. Training uses more workers and shuffling;
# validation uses fewer workers with a deeper prefetch queue.
train_dataloader = scorch.datasets.DataLoader(torch.utils.data.DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=11, prefetch_factor=2, pin_memory=True, persistent_workers=True), cuda_prefetch=True)
valid_dataloader = scorch.datasets.DataLoader(torch.utils.data.DataLoader(valid_dataset, batch_size=config['batch_size'], num_workers=2, prefetch_factor=4, pin_memory=True, persistent_workers=True), cuda_prefetch=True)
Architecture¶
The original AlexNet operates on 256x256 images. This network operates on 64x64, so whenever possible I make the network roughly 4 times smaller, e.g. by using 4 times smaller convolution kernels. This also makes training quite a lot faster.
In [9]:
class AlexNet(torch.nn.Module):
    """AlexNet (Krizhevsky et al., 2012) scaled down for 56x56 crops of
    64x64 TinyImageNet images.

    The paper's two-tower structure (one tower per GPU) is kept: layers
    2, 4 and 5 run as independent top/bottom halves, and only layers 1,
    3 and 6-8 see both towers. Training consumes a single random 56x56
    crop (done by the dataloader transform); evaluation averages logits
    over the four 56x56 corner crops and the center crop of the 64x64
    input, as in the paper.
    """

    def __init__(self):
        super().__init__()
        # Local response normalization hyperparameters from the paper.
        self.k = 2
        self.n = 5
        self.alpha = 1e-4
        self.beta = 0.75
        self.relu = scorch.nn.ReLU()
        # Overlapping max-pooling (kernel 3 > stride 2), as in the paper.
        # NOTE(review): "poll" is a typo for "pool"; the attribute name is
        # kept so any external code referencing it keeps working.
        self.overlapping_poll = torch.nn.MaxPool2d(kernel_size=3, stride=2)
        # Kernel sizes are shrunk relative to the paper (noted per line)
        # because inputs are ~4x smaller on each side.
        self.conv1 = scorch.nn.Conv2D(in_channels=3, out_channels=96, kernel_size=3, padding=1) # kernel 11 -> 3
        self.conv2_top = scorch.nn.Conv2D(in_channels=48, out_channels=128, kernel_size=3, padding=1) # kernel 5 -> 3
        self.conv2_bottom = scorch.nn.Conv2D(in_channels=48, out_channels=128, kernel_size=3, padding=1) # kernel 5 -> 3
        self.conv3 = scorch.nn.Conv2D(in_channels=256, out_channels=384, kernel_size=1) # kernel 3 -> 1
        self.conv4_top = scorch.nn.Conv2D(in_channels=192, out_channels=192, kernel_size=1) # kernel 3 -> 1
        self.conv4_bottom = scorch.nn.Conv2D(in_channels=192, out_channels=192, kernel_size=1) # kernel 3 -> 1
        self.conv5_top = scorch.nn.Conv2D(in_channels=192, out_channels=128, kernel_size=1) # kernel 3 -> 1
        self.conv5_bottom = scorch.nn.Conv2D(in_channels=192, out_channels=128, kernel_size=1) # kernel 3 -> 1
        self.local_response_normalization = torch.nn.LocalResponseNorm(size=self.n, alpha=self.alpha, beta=self.beta, k=self.k)
        self.dropout = scorch.nn.Dropout(p=0.5)
        # A 56x56 input shrinks to 6x6 after the three overlapping poolings
        # (56 -> 27 -> 13 -> 6), with 128 channels per tower: 128*2*6*6 = 9216.
        self.fc6 = scorch.nn.Linear(in_features=128 * 2 * 6 * 6, out_features=1024)
        self.fc7 = scorch.nn.Linear(in_features=1024, out_features=1024)
        self.fc8 = scorch.nn.Linear(in_features=1024, out_features=NUM_CLASSES)

    def forward_patch(self, patch):
        """Run one 56x56 patch through the network; returns raw logits."""
        # Layer 1: 3 -> 96 channels, with LRN and pooling
        conv1 = self.relu(self.conv1(patch))
        top, bottom = torch.split(conv1, conv1.size(1) // 2, dim=1)
        top, bottom = self.local_response_normalization(top), self.local_response_normalization(bottom)
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)
        # Layer 2: 96 -> 256 channels, with LRN and pooling, per tower
        top, bottom = self.relu(self.conv2_top(top)), self.relu(self.conv2_bottom(bottom))
        top, bottom = self.local_response_normalization(top), self.local_response_normalization(bottom)
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)
        # Layer 3: 256 -> 384 channels, only ReLU; the one conv layer where
        # the two towers exchange information (input is their concatenation)
        stacked = torch.cat((top, bottom), dim=1)
        conv3 = self.relu(self.conv3(stacked))
        # Layer 4: 384 -> 384, only ReLU, per tower
        top, bottom = torch.split(conv3, conv3.size(1) // 2, dim=1)
        top, bottom = self.relu(self.conv4_top(top)), self.relu(self.conv4_bottom(bottom))
        # Layer 5: 384 -> 256, with pooling, per tower
        top, bottom = self.relu(self.conv5_top(top)), self.relu(self.conv5_bottom(bottom))
        top, bottom = self.overlapping_poll(top), self.overlapping_poll(bottom)
        # Layer 6: [BATCH, 256, 6, 6] -> [BATCH, 1024], with ReLU and dropout
        stacked = torch.cat((top, bottom), dim=1)
        stacked = torch.flatten(stacked, start_dim=1)
        stacked = self.dropout(self.relu(self.fc6(stacked)))
        # Layer 7: 1024 -> 1024, with ReLU and dropout
        stacked = self.dropout(self.relu(self.fc7(stacked)))
        # Layer 8: 1024 -> NUM_CLASSES, no activation so it outputs logits
        # directly (BCEWithLogitsLoss applies the sigmoid itself)
        out = self.fc8(stacked)
        return out

    def forward(self, x):
        if self.training:
            # Training batches are already random 56x56 crops.
            return self.forward_patch(x)
        else:
            # Evaluation: average logits over the four 56x56 corner crops and
            # the center crop of the 64x64 image (the paper's test-time scheme).
            # BUGFIX: the original computed the top-left crop twice (its
            # "bottom_left" reused [:-8, :-8]) and mislabeled [8:, :-8] as
            # top-right, so the true top-right corner was never evaluated.
            # Dim 2 is height, dim 3 is width.
            top_left_result = self.forward_patch(x[:, :, :-8, :-8])
            top_right_result = self.forward_patch(x[:, :, :-8, 8:])
            bottom_left_result = self.forward_patch(x[:, :, 8:, :-8])
            bottom_right_result = self.forward_patch(x[:, :, 8:, 8:])
            center_result = self.forward_patch(x[:, :, 4:-4, 4:-4])
            results = torch.stack((
                top_left_result, top_right_result, bottom_left_result, bottom_right_result, center_result
            ), dim=1)
            mean_result = torch.mean(results, dim=1)
            return mean_result
Training¶
In [10]:
# Instantiate the model (moved to GPU implicitly by the scorch Runner below).
alexnet = AlexNet()
In [11]:
# Per-batch/per-epoch loss curves, streamed to wandb.
loss_logger = scorch.execution.LossLogger(
    epochs=config['epochs'],
    train_batches=len(train_dataloader),
    valid_batches=len(valid_dataloader),
    log_every_batch=10,
    log_wandb=True,
    log_console=False
)
# Classification metrics (error rates, per-class breakdown on the last epoch)
# plus the per-image predictions consumed in the qualitative section below.
classifier_logger = scorch.execution.ClassifierLogger(
    classes=NUM_CLASSES,
    index_to_name=index_to_name,
    epochs=config['epochs'],
    log_detailed_on_last_epoch=True,
    log_wandb=True,
    log_console=True,
    only_valid=False,
    skip_batches=10,
)
# Throughput/timing profile every 50 batches.
profile_logger = scorch.execution.ProfileLogger(
    log_every_batch=50,
    log_console=False,
    log_wandb=True
)
# Tracks the OneCycleLR schedule in wandb.
lr_logger = scorch.execution.LearningRateLogger(
    log_console=False,
    log_wandb=True
)
# NOTE(review): Adam instead of the paper's SGD with momentum/weight decay —
# a deliberate deviation (the intro says differences are pointed out).
optimizer = torch.optim.Adam(alexnet.parameters())
In [12]:
# Epoch at which the one-cycle schedule reaches its peak learning rate.
PEAK_TRAINING_EPOCH = 5
runner = scorch.execution.Runner(
    model=alexnet,
    training_loader=train_dataloader,
    optimizer=optimizer,
    # BCE over one-hot targets (matches y_transform), not the usual
    # cross-entropy — another deliberate deviation from the paper.
    loss=torch.nn.BCEWithLogitsLoss(),
    epochs=config['epochs'],
    validation_loader=valid_dataloader,
    # total_steps equals the number of epochs, so the Runner presumably steps
    # the scheduler once per epoch, not per batch — TODO confirm in scorch.
    lr_scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=config['lr'], total_steps=config['epochs'], pct_start=PEAK_TRAINING_EPOCH/config['epochs']),
    loggers=(loss_logger, classifier_logger, lr_logger),
    profiling_loggers=(profile_logger,)
)
In [13]:
# Full training + validation loop for config['epochs'] epochs.
runner.run(train=True, validate=True)
Quantitative Results¶
In [14]:
# Training/validation loss curves recorded by the loss logger.
loss_logger.plot()
In [15]:
# Top-1 (and related) error rates over epochs.
classifier_logger.plot_error_rates()
Qualitative Results¶
In [16]:
import random

# Final validation epoch's predictions, ranked most- to least-confident.
predictions_by_confidence = sorted(
    classifier_logger.valid['predictions'][-1],
    key=lambda p: p.confidence,
    reverse=True,
)
# Split by correctness (order preserved), plus an unbiased random sample.
wrong_predictions = [p for p in predictions_by_confidence if not p.correct]
correct_predictions = [p for p in predictions_by_confidence if p.correct]
random_predictions = random.sample(predictions_by_confidence, k=200)
In [17]:
# Wrong predictions the model was least sure about (list reversed: ascending confidence).
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[::-1][:5], title='Least Confident, Wrong Predictions', top_k=10, rows=1)
In [18]:
# Confidently wrong predictions — the most interesting failure cases.
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[:5], title='Most Confident, Wrong Predictions', top_k=10)
In [19]:
# Correct predictions with the highest confidence.
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[:5], title='Most Confident, Correct Predictions', top_k=10)
In [20]:
# Correct predictions the model was barely sure about (list reversed: ascending confidence).
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[::-1][:5], title='Least Confident, Correct Predictions', top_k=10)
In [21]:
# 15 of the 200 randomly sampled predictions, for an unbiased look.
scorch.execution.ClassifierLogger.plot_predictions_for_images(random_predictions[:15], title='Randomly Sampled Predictions', rows=3, scale=1, top_k=10)
In [22]:
# Close the wandb run so all buffered metrics are flushed.
wandb.finish()