LeNet-5 for MNIST Digit Recognition
An implementation based on the original paper, with minimal differences. Wherever I take a different approach from the paper, I point it out.
Setting up the project
In [1]:
import importlib
import torch
import torchvision
import scorch
import wandb
In [2]:
wandb.login()
In [3]:
config = {
'batch_size': 128,
'epochs': 20,
'lr': 0.001,
'lr_gamma': 0.9
}
In [4]:
wandb.init(project="lenet-5", config=config)
Data Normalization
The paper states that input pixels are normalized so that the background corresponds to -0.1 and the foreground to 1.175. I derive the implied mean and std from these two values.
In [5]:
dataset_path = "./datasets/mnist"
# background and foreground values from the paper
# out = (in - mean) / std
# -0.1 = (0 - mean) / std => -0.1 = -mean / std => mean = 0.1std
# 1.175 = (1 - mean) / std => 1.175std = 1 - mean => 1.175std - 1 = -mean => mean = 1 - 1.175std
# => 0.1std = 1 - 1.175std => 1.275std = 1 => std = 0.78 => mean = 0.078
transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
    torchvision.transforms.Pad(2, fill=0),  # pad the 28x28 digits to the 32x32 input size used in the paper
torchvision.transforms.Normalize(mean=0.078, std=0.78)
])
# one-hot encode labels as float vectors, as expected by BCELoss
label_transform = torchvision.transforms.Compose([
lambda label: torch.tensor([label]),
lambda label: torch.nn.functional.one_hot(label, num_classes=10).squeeze().float()
])
train_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=transforms, target_transform=label_transform, train=True, download=True)
test_dataset_base = torchvision.datasets.MNIST(dataset_path, transform=transforms, target_transform=label_transform, train=False, download=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset_base, batch_size=config['batch_size'], shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset_base, batch_size=config['batch_size'])
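As a quick sanity check, the value range of a transformed sample should roughly reproduce the paper's target range (this snippet is purely illustrative):

image, _ = train_dataset_base[0]
print(image.shape)         # torch.Size([1, 32, 32]) after padding
print(image.min().item())  # background: approximately -0.1
print(image.max().item())  # foreground: approximately 1.18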
Architecture
This replicates the model described in the paper, except for the output layer: instead of encoding each digit as a hardcoded 7x12 bitmap and scoring it with a Euclidean Radial Basis Function, I use a linear layer followed by Softmax to compute class probabilities directly.
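For reference, each RBF output unit in the paper computes the squared Euclidean distance between the F6 activation vector and a fixed 84-dimensional prototype derived from a 7x12 bitmap of the corresponding digit. A minimal sketch of such a layer, with random placeholder prototypes standing in for the actual bitmaps:

class EuclideanRBF(torch.nn.Module):
    def __init__(self, in_features=84, classes=10):
        super().__init__()
        # placeholder +1/-1 prototypes; the paper uses fixed, stylized 7x12 digit bitmaps
        prototypes = torch.randint(0, 2, (classes, in_features)).float() * 2 - 1
        self.register_buffer('prototypes', prototypes)
    def forward(self, x):
        # squared Euclidean distance to each class prototype; smaller means closer
        return torch.cdist(x, self.prototypes).pow(2)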
In [6]:
class SubSampling(torch.nn.Module):
def __init__(self):
super().__init__()
self.coefficient = torch.nn.Parameter(torch.rand(size=(1,)) * 2 - 1)
self.bias = torch.nn.Parameter(torch.rand(size=(1,)) * 2 - 1)
self.tanh = scorch.nn.Tanh()
    def forward(self, x):
        # average-pool 2x2 neighborhoods (stride 2), then apply the trainable
        # coefficient and bias followed by tanh, as in the paper's S2/S4 layers
        pooled = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return self.tanh(self.coefficient * pooled + self.bias)
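For example, applied to a dummy batch shaped like C1's output (six 28x28 feature maps), the layer halves the spatial resolution:

s2 = SubSampling()
out = s2(torch.randn(4, 6, 28, 28))  # dummy batch of 4 images, shaped like C1's output
print(out.shape)                     # torch.Size([4, 6, 14, 14])

Note that this uses a single trainable coefficient and bias shared by all feature maps, whereas the paper gives each map its own pair.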
In [7]:
# C3 in the paper is not fully connected: each of its 16 feature maps sees only a
# subset of the 6 S2 maps, as specified by a connectivity table
class SparseConv2D(torch.nn.Module):
    def __init__(self, connectivity_table):
        super().__init__()
        self.conv = scorch.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5)
        # binary mask with a 1 for every (input map, output map) pair in the table
        mask = torch.zeros_like(self.conv.weights)
        for in_ch, out_ch in connectivity_table:
            mask[out_ch, in_ch] = 1
        self.register_buffer('mask', mask)
    def forward(self, x):
        # zero out the weights of unconnected map pairs before each forward pass
        self.conv.weights *= self.mask
        return self.conv(x)
In [8]:
connectivity_table = [
(0, 0), (0, 4), (0, 5), (0, 6), (0, 9), (0, 10), (0, 11), (0, 12), (0, 14), (0, 15),
(1, 0), (1, 1), (1, 5), (1, 6), (1, 7), (1, 10), (1, 11), (1, 12), (1, 13), (1, 15),
(2, 0), (2, 1), (2, 2), (2, 6), (2, 7), (2, 8), (2, 11), (2, 13), (2, 14), (2, 15),
(3, 1), (3, 2), (3, 3), (3, 6), (3, 7), (3, 8), (3, 9), (3, 12), (3, 14), (3, 15),
(4, 2), (4, 3), (4, 4), (4, 7), (4, 8), (4, 9), (4, 10), (4, 12), (4, 13), (4, 15),
(5, 3), (5, 4), (5, 5), (5, 8), (5, 9), (5, 10), (5, 11), (5, 13), (5, 14), (5, 15)
]
lenet = torch.nn.Sequential(
# C1
scorch.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5),
scorch.nn.Tanh(),
# S2
SubSampling(),
# C3
    SparseConv2D(connectivity_table),  # sparse connectivity between S2 and C3, per the paper's Table I
scorch.nn.Tanh(),
# S4
SubSampling(),
# C5
scorch.nn.Conv2D(in_channels=16, out_channels=120, kernel_size=5),
scorch.nn.Flatten(),
scorch.nn.Tanh(),
# F6
scorch.nn.Linear(120, 84),
scorch.nn.Tanh(),
# OUTPUT
scorch.nn.Linear(84, 10), # paper has RBF connections, but I'll skip that
scorch.nn.Softmax()
)
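The connectivity table can be sanity-checked against the paper's Table I: the first six C3 maps take 3 inputs each, the next nine take 4 each, and the last one takes all 6, for 60 connections in total. A quick check:

from collections import Counter
inputs_per_map = Counter(out_ch for _, out_ch in connectivity_table)
print(sorted(inputs_per_map.items()))  # maps 0-5 -> 3 inputs, 6-14 -> 4, 15 -> 6
print(sum(inputs_per_map.values()))    # 60 connections in total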
In [9]:
loss_logger = scorch.execution.LossLogger(
epochs=config['epochs'],
train_batches=len(train_dataloader),
valid_batches=len(test_dataloader),
log_wandb=True,
log_console=True
)
classification_logger = scorch.execution.ClassifierLogger(
classes=10,
index_to_name=[str(i) for i in range(10)],
epochs=config['epochs'],
log_detailed_on_last_epoch=True,
log_wandb=True,
log_console=True
)
lr_logger = scorch.execution.LearningRateLogger(
log_console=False,
log_wandb=True
)
In [10]:
optimizer = torch.optim.Adam(lenet.parameters(), lr=config['lr'])
In [11]:
trainer = scorch.execution.Runner(
model=lenet,
training_loader=train_dataloader,
optimizer=optimizer,
lr_scheduler=torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=config['lr_gamma']),
loss=scorch.nn.BCELoss,
epochs=config['epochs'],
validation_loader=test_dataloader,
loggers=(
loss_logger,
classification_logger,
lr_logger
)
)
In [12]:
stats = trainer.run(train=True, validate=True)
Quantitative Results
In [13]:
loss_logger.plot(ylabel='BCE Loss')
In [14]:
classification_logger.plot_error_rates(ks=(1, 3))
In [15]:
classification_logger.plot_confusion_matrix()
Qualitative Results
In [16]:
# predictions from the final validation epoch, sorted by descending confidence
predictions_by_confidence = sorted(classification_logger.valid['predictions'][-1], key=lambda p: -p.confidence)
wrong_predictions = list(filter(lambda p: not p.correct, predictions_by_confidence))
correct_predictions = list(filter(lambda p: p.correct, predictions_by_confidence))
In [17]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[::-1][:5], title='Least Confident, Wrong Predictions')
In [18]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(wrong_predictions[:5], title='Most Confident, Wrong Predictions')
In [19]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[:5], title='Most Confident, Correct Predictions')
In [20]:
scorch.execution.ClassifierLogger.plot_predictions_for_images(correct_predictions[::-1][:5], title='Least Confident, Correct Predictions')
In [21]:
wandb.finish()