Lenet-5 for MNIST Digit Recognition

Implementation based on the original paper with minimal differences. Whenever I take a different approach than the paper, I point it out.

Setting up the project

In [1]:
import importlib
import torch
import torchvision
import scorch
import wandb
In [2]:
wandb.login()
Out[2]:
True
In [3]:
config = {
    'batch_size': 128,
    'epochs': 20,
    'lr': 0.001,
    'lr_gamma': 0.9
}
In [4]:
wandb.init(project="lenet-5", config=config)

Data Normalization

The paper describes that data is normalized so that background corresponds to -0.1 and foreground to 1.175. I calculate the assumed mean and std from these values.

In [5]:
dataset_path = "./datasets/mnist"

# background and foreground values from the paper
# out = (in - mean) / std
# -0.1 = (0 - mean) / std => -0.1 = -mean / std => mean = 0.1std
# 1.175 = (1 - mean) / std => 1.175std = 1 - mean => 1.175std - 1 = -mean => mean = 1 - 1.175std
# => 0.1std = 1 - 1.175std => 1.275std = 1 => std = 0.78 => mean = 0.078

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Pad(2, fill=0),
    torchvision.transforms.Normalize(mean=0.078, std=0.78)
])
label_transform = torchvision.transforms.Compose([
    lambda label: torch.tensor([label]),
    lambda label: torch.nn.functional.one_hot(label, num_classes=10).squeeze().float()
])

train_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=transforms, target_transform=label_transform, train=True, download=True)
test_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=transforms, target_transform=label_transform, train=False, download=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset_base, batch_size=config['batch_size'], shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset_base, batch_size=config['batch_size'])

Architecture

Replicates the model described in the paper, but I do not encode each digit with a hardcoded 7x12 bitmap so it can be used in an Euclidean Radial Basis Function. Instead, I use Softmax to directly compute probabilities over classes

In [6]:
class SubSampling(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.coefficient = torch.nn.Parameter(torch.rand(size=(1,)) * 2 - 1)
        self.bias = torch.nn.Parameter(torch.rand(size=(1,)) * 2 - 1)
        self.tanh = scorch.nn.Tanh()

    def forward(self, x):
        polled = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return self.tanh(self.coefficient * polled + self.bias)
In [7]:
class SparseConv2D(torch.nn.Module):
    def __init__(self, connectivity_table):
        super().__init__()

        self.conv = scorch.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5)
        mask = torch.zeros_like(self.conv.weights)

        for in_ch, out_ch in connectivity_table:
            mask[out_ch, in_ch] = 1

        self.register_buffer('mask', mask)

    def forward(self, x):
        self.conv.weights *= self.mask
        return self.conv(x)
In [8]:
connectivity_table = [
    (0, 0), (0, 4), (0, 5), (0, 6), (0, 9), (0, 10), (0, 11), (0, 12), (0, 14), (0, 15),
    (1, 0), (1, 1), (1, 5), (1, 6), (1, 7), (1, 10), (1, 11), (1, 12), (1, 13), (1, 15),
    (2, 0), (2, 1), (2, 2), (2, 6), (2, 7), (2, 8), (2, 11), (2, 13), (2, 14), (2, 15),
    (3, 1), (3, 2), (3, 3), (3, 6), (3, 7), (3, 8), (3, 9), (3, 12), (3, 14), (3, 15),
    (4, 2), (4, 3), (4, 4), (4, 7), (4, 8), (4, 9), (4, 10), (4, 12), (4, 13), (4, 15),
    (5, 3), (5, 4), (5, 5), (5, 8), (5, 9), (5, 10), (5, 11), (5, 13), (5, 14), (5, 15)
]

lenet = torch.nn.Sequential(
    # C1
    scorch.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5),
    scorch.nn.Tanh(),
    # S2
    SubSampling(),
    # C3
    scorch.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5),
    scorch.nn.Tanh(),
    # S4
    SubSampling(),
    # C5
    scorch.nn.Conv2D(in_channels=16, out_channels=120, kernel_size=5),
    scorch.nn.Flatten(),
    scorch.nn.Tanh(),
    # F6
    scorch.nn.Linear(120, 84),
    scorch.nn.Tanh(),
    # OUTPUT
    scorch.nn.Linear(84, 10), # paper has RBF connections, but I'll skip that
    scorch.nn.Softmax()
)
In [9]:
loss_logger = scorch.execution.LossLogger(
    epochs=config['epochs'],
    train_batches=len(train_dataloader),
    valid_batches=len(test_dataloader),
    log_wandb=True,
    log_console=True
)

classification_logger = scorch.execution.ClassifierLogger(
    classes=10,
    index_to_name=[str(i) for i in range(10)],
    epochs=config['epochs'],
    log_detailed_on_last_epoch=True,
    log_wandb=True,
    log_console=True
)

lr_logger = scorch.execution.LearningRateLogger(
    log_console=False,
    log_wandb=True
)
In [10]:
optimizer = torch.optim.Adam(lenet.parameters(), lr=config['lr'])
In [11]:
trainer = scorch.execution.Runner(
    model=lenet,
    training_loader=train_dataloader,
    optimizer=optimizer,
    lr_scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['lr_gamma']),
    loss=scorch.nn.BCELoss,
    epochs=config['epochs'],
    validation_loader=test_dataloader,
    loggers=(
        loss_logger,
        classification_logger,
        lr_logger
    )
)
Using device: cpu
In [12]:
stats = trainer.run(train=True, validate=True)
Epoch 1/20
	Epoch Train Loss: 0.11362675577402115
	Epoch Validation Loss: 0.02596954070031643
Epoch 1/20 Error Rates
	Top-1 Train: 0.25806236673773986
	Top-3 Train: 0.13621068763326227
	Top-5 Train: 0.08480477078891258
	Top-10 Train: 0.0
	Top-1 Validation: 0.044501582278481014
	Top-3 Validation: 0.005834651898734177
	Top-5 Validation: 0.0004944620253164557
	Top-10 Validation: 0.0
Epoch 2/20
	Epoch Train Loss: 0.019073758274316788
	Epoch Validation Loss: 0.01291697844862938
Epoch 2/20 Error Rates
	Top-1 Train: 0.03243825515280739
	Top-3 Train: 0.00396455223880597
	Top-5 Train: 0.000882862473347548
	Top-10 Train: 0.0
	Top-1 Validation: 0.02254746835443038
	Top-3 Validation: 0.0020767405063291137
	Top-5 Validation: 9.889240506329114e-05
	Top-10 Validation: 0.0

[...]

Epoch 19/20
	Epoch Train Loss: 0.0011879552621394396
	Epoch Validation Loss: 0.006232853978872299
Epoch 19/20 Error Rates
	Top-1 Train: 0.001199360341151386
	Top-3 Train: 0.0001166044776119403
	Top-5 Train: 3.331556503198294e-05
	Top-10 Train: 0.0
	Top-1 Validation: 0.010977056962025316
	Top-3 Validation: 0.0004944620253164557
	Top-5 Validation: 9.889240506329114e-05
	Top-10 Validation: 0.0

Quantitative Results

In [13]:
loss_logger.plot(ylabel='BCE Loss')
No description has been provided for this image
In [14]:
classification_logger.plot_error_rates(ks=(1, 3))
No description has been provided for this image
In [15]:
classification_logger.plot_confusion_matrix()
No description has been provided for this image

Qualitative Results

In [16]:
predictions_by_confidence = sorted(classification_logger.valid['predictions'][-1], key=lambda p: -p.confidence)
wrong_predictions = list(filter(lambda p: not p.correct, predictions_by_confidence))
correct_predictions = list(filter(lambda p: p.correct, predictions_by_confidence))
In [17]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[::-1][:5], title='Least Confident, Wrong Predictions')
No description has been provided for this image
In [18]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[:5], title='Most Confident, Wrong Predictions')
No description has been provided for this image
In [19]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[:5], title='Most Confident, Correct Predictions')
No description has been provided for this image
In [20]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[::-1][:5], title='Least Confident, Correct Predictions')
No description has been provided for this image
In [21]:
wandb.finish()


Run history:


batch
epoch▁▁▁▁▁▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
lr/group-0█▇▆▆▅▅▄▄▃▃▃▂▂▂▂▂▁▁▁▁
train/loss_epoch█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/top-1█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/top-10▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/top-3█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/top-5█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/loss_epoch█▃▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁
valid/top-1█▃▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁
valid/top-10▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/top-3█▃▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁
valid/top-5█▂▂▂▂▂▄▂▁▂▄▂▂▂▁▂▂▂▂


Run summary:


batch0
epoch19
lr/group-00.00014
train/loss_epoch0.00119
train/top-10.0012
train/top-100
train/top-30.00012
train/top-53e-05
valid/loss_epoch0.00623
valid/top-10.01098
valid/top-100
valid/top-30.00049
valid/top-50.0001