This notebook is part of the CIFAR-10 Classifiers series, which covers building classifiers with TensorFlow, PyTorch, PyTorch Lightning, etc.
Part 1: TensorFlow | Part 3: PyTorch
CIFAR-10 Classifier: PyTorch Lightning Edition
We use PyTorch Lightning in this notebook. It abstracts away the boilerplate code that every PyTorch training workflow otherwise needs.
Objective
Build a simple CIFAR-10 classification model that runs on a GPU, using the following:
- ResNet-50 as the backbone
- Minimal augmentation
- Experiment tracking using Comet ML
Expected Outcome
By the end of this post, you will be able to:
- use a backbone of your choice (VGG, EfficientNet, ResNeXt, …)
- use a dataset of your choice, or one of those available in torchvision.datasets
- use Comet ML in your projects
Most of the content is similar to the TensorFlow and PyTorch versions of this notebook.
Author: Katnoria | Created: 03-Sep-2020 | Updated: 16-Oct-2020
1. Imports & Setup
import os
from time import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models, utils
from torch.utils.data import DataLoader, random_split
# load from .env
from pathlib import Path
from dotenv import load_dotenv
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.loggers import CometLogger
pl.seed_everything(42)
42
# Load secrets
%load_ext dotenv
%dotenv ../../.env
COMET_ML_API_KEY = os.getenv("COMET_ML_API_KEY")
len(COMET_ML_API_KEY)
25
def version_info(cls):
    print(f"{cls.__name__}: {cls.__version__}")
# Print version info
version_info(torch)
version_info(pl)
torch: 1.6.0
pytorch_lightning: 0.9.0
2. Dataset
We will load the CIFAR-10 dataset from torchvision.datasets.
# Hyper params
BATCH_SIZE = 128
NUM_WORKERS = 12
# Convert images to tensors in [0, 1], then normalize each RGB channel
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])
# Load the dataset
train_ds = datasets.CIFAR10(
    root="./data", train=True,
    download=True, transform=tfms
)
# Create train and validation splits
train, val = random_split(train_ds, [45000, 5000])
# Create data loaders
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
Files already downloaded and verified
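To make the transform pipeline concrete, here is a plain-Python sketch of what ToTensor followed by Normalize does to a single pixel: ToTensor scales 8-bit values to [0, 1], and Normalize computes (x - mean) / std per channel. The mean and std are the ones passed to Normalize above (they appear to be the commonly used CIFAR-10 per-channel statistics); the helper normalize_pixel is hypothetical and only for illustration.

```python
# Per-channel statistics copied from the Normalize(...) call above.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.247, 0.243, 0.261)

def normalize_pixel(rgb):
    """Hypothetical helper: normalize one 8-bit (R, G, B) pixel the way
    ToTensor() + Normalize(MEAN, STD) would."""
    # ToTensor: [0, 255] -> [0.0, 1.0]; Normalize: (x - mean) / std per channel
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))

black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print([round(v, 3) for v in black])
print([round(v, 3) for v in white])
```

After normalization, channel values are roughly centred on zero, with a pure black pixel mapping to about -2 and a pure white pixel to about +2.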
len(train_loader.dataset), len(val_loader.dataset)
(45000, 5000)
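random_split shuffles the 50,000 training images once and cuts them into a 45,000-image training subset and a 5,000-image validation subset. The sketch below reproduces that idea with the standard library; random_split_indices is a hypothetical simplification (the real function uses a torch.Generator and returns Subset objects rather than index lists). Because pl.seed_everything(42) seeds PyTorch's global random number generator, the split in this notebook should be reproducible across runs.

```python
import random

def random_split_indices(n, lengths, seed=42):
    """Hypothetical stand-in for torch.utils.data.random_split:
    shuffle range(n) and cut it into consecutive chunks of the given lengths."""
    if sum(lengths) != n:
        # random_split also rejects lengths that do not sum to len(dataset)
        raise ValueError("lengths must sum to the dataset size")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits

train_idx, val_idx = random_split_indices(50_000, [45_000, 5_000])
print(len(train_idx), len(val_idx))        # 45000 5000
print(set(train_idx).isdisjoint(val_idx))  # True
```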
test_ds = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=tfms
)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
len(test_loader.dataset)
Files already downloaded and verified
10000
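With BATCH_SIZE = 128, each loader yields ceil(len(dataset) / 128) batches per pass, since DataLoader keeps the final partial batch by default (drop_last=False). The arithmetic below is a small sketch of that calculation; batches_per_epoch is a hypothetical helper, not a PyTorch function.

```python
import math

BATCH_SIZE = 128

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Hypothetical helper: number of batches a DataLoader yields for a
    map-style dataset of num_samples items."""
    if drop_last:
        return num_samples // batch_size          # partial last batch discarded
    return math.ceil(num_samples / batch_size)    # partial last batch kept

counts = {name: batches_per_epoch(size, BATCH_SIZE)
          for name, size in [("train", 45_000), ("val", 5_000), ("test", 10_000)]}
print(counts)  # {'train': 352, 'val': 40, 'test': 79}
```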
3. Build Model
We will use an ImageNet pre-trained ResNet-50 model. You can swap the base model for others such as ResNet-18 or ResNet-110; just make sure the input features of the new final linear layer match the output features of your base model (2048 for ResNet-50).
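For the ResNets that ship with torchvision, the final layer is nn.Linear(512 * block.expansion, num_classes), where the BasicBlock used by ResNet-18/34 has expansion 1 and the Bottleneck used by ResNet-50/101/152 has expansion 4. The lookup below sketches that rule (it only covers torchvision's ResNets; ResNet-110 is not part of torchvision.models). In real code, a less brittle option than hard-coding 2048 is to read the width from the loaded model, e.g. nn.Linear(backbone.fc.in_features, 256).

```python
# Block expansion factors for torchvision's ResNet variants:
# BasicBlock -> 1, Bottleneck -> 4.
BLOCK_EXPANSION = {
    "resnet18": 1,
    "resnet34": 1,
    "resnet50": 4,
    "resnet101": 4,
    "resnet152": 4,
}

def fc_in_features(arch):
    """Width of the feature vector a torchvision ResNet feeds into its fc layer."""
    return 512 * BLOCK_EXPANSION[arch]

for arch in BLOCK_EXPANSION:
    print(arch, fc_in_features(arch))
```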
class CIFARTenLitModelV2(pl.LightningModule):
    """CIFAR10 Model"""
    def __init__(self, backbone, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.backbone = backbone
        # Replace ResNet-50's 1000-class ImageNet head with a 2048 -> 256 layer
        self.backbone.fc = nn.Linear(2048, 256)
        # Final classifier: 256 features -> 10 CIFAR-10 classes
        self.fc1 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.backbone(x)
        x = F.relu(x)
        out = self.fc1(x)  # raw logits; cross_entropy applies log-softmax itself
        return out

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        result = pl.TrainResult(loss)  # the loss to minimize
        # result.log("train_loss", loss, prog_bar=True)
        result.log("train_acc", acc, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        # result.log("val_loss", loss)
        result.log("val_acc", acc)
        return result

    def test_step(self, batch, batch_index):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        result = pl.EvalResult(checkpoint_on=loss)
        result.log("test_acc", acc)
        return result

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
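Each training_step computes a cross-entropy loss and an accuracy from the model's raw logits. The plain-Python sketch below spells out both quantities, assuming accuracy here means top-1 accuracy (arg-max of the logits compared with the label); it is an illustration, not Lightning's or PyTorch's code. cross_entropy mirrors F.cross_entropy with its default mean reduction, and the batch of logits is made-up data with 3 classes instead of CIFAR-10's 10.

```python
import math

def cross_entropy(logits, targets):
    """Batch mean of -log(softmax(logits)[target]), i.e. logsumexp(row) - row[target]."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)

def top1_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

# Hypothetical batch: 3 samples, 3 classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
targets = [0, 2, 0]
print(round(cross_entropy(logits, targets), 4))
print(top1_accuracy(logits, targets))  # 0.6666666666666666
```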
4. Train
4.1: Track Experiments
You can track your experiments with TensorBoard, Weights & Biases (W&B) or any other tool. I am using Comet ML in this notebook; I quite like its overall experience.
comet_logger = CometLogger(api_key=COMET_ML_API_KEY, save_dir='.', project_name="cf10-pl", workspace="katnoria")
# add a tag
comet_logger.experiment.add_tag("R50+Dense")
CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
Load the ImageNet pre-trained ResNet-50 model from torchvision and freeze its weights, so that only the new classification layers are trained.
backbone = models.resnet50(pretrained=True)
# Freeze all pre-trained weights
for param in backbone.parameters():
    param.requires_grad = False
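Because this freezing loop runs before CIFARTenLitModelV2 replaces backbone.fc, the new fc layer and fc1 are freshly created (their parameters default to requires_grad=True) and are the only parameters the optimizer updates. The arithmetic below counts them, using the fact that nn.Linear holds an in x out weight matrix plus an out-sized bias; linear_params is a hypothetical helper. On a real model you could confirm the figure with sum(p.numel() for p in model.parameters() if p.requires_grad).

```python
def linear_params(in_features, out_features, bias=True):
    """Hypothetical helper: parameter count of nn.Linear (weights + optional bias)."""
    return in_features * out_features + (out_features if bias else 0)

fc_params = linear_params(2048, 256)   # backbone.fc: 2048 -> 256
fc1_params = linear_params(256, 10)    # fc1: 256 -> 10 classes
trainable = fc_params + fc1_params

print(f"backbone.fc: {fc_params:,}")   # backbone.fc: 524,544
print(f"fc1: {fc1_params:,}")          # fc1: 2,570
print(f"trainable: {trainable:,}")     # trainable: 527,114
```

That is a small fraction of the roughly 24 M parameters the model summary reports for the backbone, which keeps each training step cheap.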
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)
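EarlyStopping with mode='min' and patience=3 stops training once the monitored value has failed to improve on its best value for 3 consecutive checks. The function below is a simplified sketch of that rule, assuming min_delta=0; it is not Lightning's implementation, and the loss sequence is made up. Also worth checking: the callback can only act on val_loss if that metric is actually logged, yet the result.log("val_loss", loss) line in validation_step above is commented out, and with strict=False a missing metric does not raise an error.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Sketch of a 'min'-mode patience rule: return the epoch index at which
    training would stop, or None if it never triggers."""
    best = float("inf")
    wait = 0  # consecutive checks without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best, wait = loss, 0  # improvement: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation losses, one per epoch.
losses = [1.20, 1.05, 0.98, 0.99, 1.01, 0.97, 0.98, 1.00, 1.02]
print(early_stopping_epoch(losses, patience=3))  # 8
```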
With PyTorch Lightning, the Trainer sets up the training, validation and test loops for you. It also provides other goodies such as:
- fast dev run (fast_dev_run=True runs a single batch of training, validation and test to catch bugs early)
- metric logging
- switching between accelerators (GPU/TPU) without any code change 💯
EPOCHS = 50
learning_rate = 1e-3
trainer = pl.Trainer(
    fast_dev_run=False,
    gpus=1,
    early_stop_callback=early_stop,
    max_epochs=EPOCHS,
    logger=comet_logger
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
model = CIFARTenLitModelV2(backbone, learning_rate)
start = time()
print(f"start: {datetime.fromtimestamp(start)}")
# train
trainer.fit(model, train_loader, val_dataloaders=val_loader)
stop = time()
| Name | Type | Params
------------------------------------
0 | backbone | ResNet | 24 M
1 | fc1 | Linear | 2 K
start: 2020-10-16 18:30:03.064633
Saving latest checkpoint..
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO: Data:
COMET INFO: display_summary_level : 1
COMET INFO: url : https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
COMET INFO: Metrics [count] (min, max):
COMET INFO: epoch [800] : (0, 99)
COMET INFO: train_acc [700] : (0.359375, 0.8906250596046448)
COMET INFO: val_acc [100] : (0.500781238079071, 0.550976574420929)
COMET INFO: Uploads:
COMET INFO: code : 1 (12 KB)
COMET INFO: environment details : 1
COMET INFO: filename : 1
COMET INFO: git metadata : 1
COMET INFO: git-patch (uncompressed) : 1 (797 KB)
COMET INFO: installed packages : 1
COMET INFO: notebook : 1
COMET INFO: os packages : 1
COMET INFO: ---------------------------
COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
took = stop - start
print(f"Total training time: {took//60 : .0f}m {took%60:.0f}s")
Total training time: 16m 14s
trainer.test(model, test_dataloaders=test_loader)
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.5159, device='cuda:0')}
--------------------------------------------------------------------------------
[{'test_acc': 0.5159216523170471}]
hyper_params = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "learning_rate": learning_rate,
    "num_epochs": EPOCHS,
}
comet_logger.experiment.log_parameters(hyper_params)
comet_logger.experiment.end()
COMET INFO: -----------------------------------
COMET INFO: Comet.ml ExistingExperiment Summary
COMET INFO: -----------------------------------
COMET INFO: Data:
COMET INFO: display_summary_level : 1
COMET INFO: url : https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
COMET INFO: Metrics:
COMET INFO: epoch : 99
COMET INFO: test_acc : 0.5159216523170471
COMET INFO: Parameters:
COMET INFO: batch_size : 128
COMET INFO: learning_rate : 0.001
COMET INFO: num_epochs : 100
COMET INFO: num_workers : 12
COMET INFO: -----------------------------------
COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
Next Steps
This is a very simple example of training a CIFAR-10 classifier with a pre-trained network. It's your turn to turn the knobs and see if you can get the model to generalise better. Some ideas:
- make the model overfit your training data
- regularize the model to generalize better
- increase/decrease model capacity based on what you find in above steps
- add image augmentation
- use a hyperparameter tuning library to find the best combination of hyperparameters
- roll out your own model from scratch; you can use the tuning library to help design the network too
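Tuning libraries such as Ray Tune or Optuna (used in later parts of this series) automate the search over hyperparameters. As a rough idea of what they do in the simplest case, here is a plain random-search sketch. The search space, toy_objective and every other name in it are hypothetical; in practice the objective would train the model with the sampled configuration and return its validation accuracy.

```python
import math
import random

# Hypothetical search space; the ranges are illustrative, not tuned.
SEARCH_SPACE = {
    "learning_rate": (1e-4, 1e-2),  # sampled log-uniformly
    "batch_size": [64, 128, 256],
}

def sample_config(rng):
    lo, hi = SEARCH_SPACE["learning_rate"]
    lr = 10 ** rng.uniform(math.log10(lo), math.log10(hi))
    return {"learning_rate": lr, "batch_size": rng.choice(SEARCH_SPACE["batch_size"])}

def toy_objective(config):
    """Hypothetical stand-in for 'train, then return validation accuracy'.
    Peaks at lr = 1e-3 and batch_size = 128 purely for illustration."""
    lr_penalty = abs(math.log10(config["learning_rate"]) + 3)
    bs_penalty = abs(math.log2(config["batch_size"]) - 7) * 0.1
    return 1.0 - lr_penalty - bs_penalty

def random_search(objective, n_trials=20, seed=42):
    """Evaluate n_trials random configurations and keep the best-scoring one."""
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_trials):
        config = sample_config(rng)
        score = objective(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score

best, score = random_search(toy_objective)
print(best, round(score, 3))
```

Sampling the learning rate on a log scale spreads trials evenly across orders of magnitude, which tends to matter more than fine steps within one magnitude.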
See also
- Exploration Log of Contrastive Language-Image Pre-training
- 📙 CIFAR-10 Classifiers: Part 7 - Speed up PyTorch hyperparameter search using Ray Tune
- 📙 CIFAR-10 Classifiers: Part 6 - Speed up TensorFlow hyperparameter search using Optuna
- 📙 CIFAR-10 Classifiers: Part 5 - Speed up TensorFlow hyperparameter search using Ray Tune
- 🍻 💊 Number of People Searching Hangover Cure