📙 CIFAR-10 Classifiers: Part 4 - Build a Simple Image Classifier using PyTorch Lightning


This notebook is part of the CIFAR-10 Classifiers post, which covers building classifiers using TensorFlow, PyTorch, PyTorch Lightning, and other frameworks.


Part 1: Tensorflow | Part 3: PyTorch

CIFAR 10 Classifier: PyTorch Lightning Edition

We use PyTorch Lightning in this notebook. It abstracts away the boilerplate code that we would otherwise add to every PyTorch training workflow.

Objective

Build a simple CIFAR10 classification model that runs on a GPU using the following:

  • ResNet-50 as the backbone
  • Minimal augmentation
  • Experiment tracking using Comet ML

Expected Outcome

By the end of this post, you will be able to:

  • use a backbone (VGG, EfficientNet, ResNeXt…) of your choice
  • use the dataset of your choice or the ones available in torchvision datasets
  • use Comet ML in your projects

Most of the content is similar to the TensorFlow and PyTorch versions of the notebook.


Author: Katnoria | Created: 03-Sep-2020 | Updated: 16-Oct-2020

1. Imports & Setup

import os
from time import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets, models, utils
from torch.utils.data import DataLoader, random_split

# load from .env
from pathlib import Path
from dotenv import load_dotenv


import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.metrics.functional import accuracy

from pytorch_lightning.loggers import CometLogger
pl.seed_everything(42)
42
# Load secrets
%load_ext dotenv
%dotenv ../../.env
The dotenv extension is already loaded. To reload it, use:
  %reload_ext dotenv
COMET_ML_API_KEY = os.getenv("COMET_ML_API_KEY")
len(COMET_ML_API_KEY)
25
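If the key is missing from the environment, `os.getenv` returns `None` and the `len(...)` check above fails with a `TypeError` that does not say what went wrong. A minimal sketch of a fail-fast lookup is shown below. The helper name `require_env` and the `DEMO_API_KEY` variable are hypothetical and used only for illustration.

```python
import os

def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or raise a clear error."""
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; check your .env file"
        )
    return value

# Demonstration with a throwaway variable (hypothetical name and value).
os.environ["DEMO_API_KEY"] = "not-a-real-key"
demo_key = require_env("DEMO_API_KEY")
print(len(demo_key))
```

In this notebook the same idea would be applied to `COMET_ML_API_KEY` before creating the logger.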
def version_info(cls):
    print(f"{cls.__name__}: {cls.__version__}")
# Print version info
version_info(torch)
version_info(pl)
torch: 1.6.0
pytorch_lightning: 0.9.0

2. Dataset

We will load the CIFAR10 dataset from torchvision datasets.

# Hyper params
BATCH_SIZE=128
NUM_WORKERS=12
tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])
# Load the dataset
train_ds = datasets.CIFAR10(
    root="./data", train=True, 
    download=True, transform=tfms
)
# Create train and validation splits
train, val = random_split(train_ds, [45000, 5000])
# Create data loaders
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
Files already downloaded and verified
len(train_loader.dataset), len(val_loader.dataset)
(45000, 5000)
test_ds = datasets.CIFAR10(
    root="./data", 
    train=False,
    download=True,
    transform=tfms
)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
len(test_loader.dataset)
Files already downloaded and verified





10000
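To make the `Normalize` step concrete, here is a small worked example in plain Python. It assumes the pixel has already been scaled to the [0, 1] range by `ToTensor()` and reuses the per-channel mean and standard deviation from the transform above. The helper names are hypothetical, and the inverse function is included only because it is handy when plotting images.

```python
# Per-channel statistics copied from the transform used in this notebook.
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.247, 0.243, 0.261]

def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel already scaled to [0, 1]."""
    return [(v - m) / s for v, m, s in zip(rgb, MEAN, STD)]

def denormalize_pixel(rgb):
    """Invert the normalization, e.g. before displaying an image."""
    return [v * s + m for v, m, s in zip(rgb, MEAN, STD)]

pixel = [0.5, 0.5, 0.5]  # a mid-grey pixel after ToTensor()
normalized = normalize_pixel(pixel)
restored = denormalize_pixel(normalized)
print(normalized)
print(restored)
```

A pixel equal to the channel means maps to zero in every channel, which is the point of the transform: the inputs end up roughly centred with unit scale.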

3. Build Model

We will use an ImageNet pre-trained ResNet-50 model. You can swap the base model for others such as ResNet-18 or ResNet-101. Just make sure the input features of the new final linear layer match the output features of your base model (2048 for ResNet-50).
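One way to avoid hard-coding the feature size is to read it from the layer being replaced. The sketch below assumes the backbone exposes its classification head as an `fc` attribute, as torchvision's ResNet models do. `TinyBackbone` and `replace_head` are hypothetical names, and the tiny network only stands in for a real ResNet so the example runs without downloading weights.

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a torchvision ResNet: features, then `fc`."""
    def __init__(self, feature_dim=512, num_classes=1000):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, feature_dim, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

def replace_head(backbone, hidden_dim=256):
    """Swap `backbone.fc` for a new layer sized from the old layer's in_features."""
    in_features = backbone.fc.in_features
    backbone.fc = nn.Linear(in_features, hidden_dim)
    return backbone

backbone = replace_head(TinyBackbone(feature_dim=512))
out = backbone(torch.randn(2, 3, 32, 32))
print(backbone.fc.in_features, tuple(out.shape))
```

Other model families may name the head differently (torchvision's VGG uses `classifier`), so the attribute lookup would need adjusting for those.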

class CIFARTenLitModelV2(pl.LightningModule):
    """CIFAR10 Model"""
    def __init__(self, backbone, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.backbone = backbone
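        # Replace the ImageNet head; ResNet-50 feeds 2048 features into its final layer
        self.backbone.fc = nn.Linear(2048, 256)
        # Classification layer for the 10 CIFAR-10 classes
        self.fc1 = nn.Linear(256, 10)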

    def forward(self, x):
        x = self.backbone(x)
        x = F.relu(x)
        out = self.fc1(x)
        return out
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)        
        result = pl.TrainResult(loss)
#         result.log("train_loss", loss, prog_bar=True)
        result.log("train_acc", acc, prog_bar=True)        
        return result
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)        
        result = pl.EvalResult(checkpoint_on=loss)
#         result.log("val_loss", loss)
        result.log("val_acc", acc)
        return result
    
    def test_step(self, batch, batch_index):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)        
        result = pl.EvalResult(checkpoint_on=loss)
        result.log("test_acc", acc)
        return result
        
    
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
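The model returns raw logits, which is what `F.cross_entropy` expects, and the accuracy metric compares the highest-scoring class with the label. As a rough illustration, not the library's implementation, the same quantity can be computed by hand for a batch of multiclass logits. The function name `batch_accuracy` and the toy tensors are assumptions made for this sketch.

```python
import torch

def batch_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class equals the target."""
    preds = logits.argmax(dim=1)
    return (preds == targets).float().mean()

# Four samples, three classes; the third sample is misclassified.
logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 1.5],
                       [0.9, 1.1, 0.0],
                       [0.3, 0.2, 0.1]])
targets = torch.tensor([0, 2, 0, 0])
acc = batch_accuracy(logits, targets)
print(acc.item())
```

Three of the four predictions match their targets here, so the batch accuracy is 0.75.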

4. Train

4.1: Track Experiments

You can track your experiments using TensorBoard, Weights & Biases, or any other tool. I am using Comet in this notebook, and I quite like the overall experience with Comet ML.

comet_logger = CometLogger(api_key=COMET_ML_API_KEY,save_dir='.',project_name="cf10-pl", workspace="katnoria")
# add a tag
comet_logger.experiment.add_tag("R50+Dense")
CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d

Load the pre-trained ResNet-50 model from torchvision and freeze its weights

backbone = models.resnet50(pretrained=True)
for param in backbone.parameters():
    param.requires_grad = False
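# Stop when the monitored metric has not improved for 3 consecutive validation runs.
# Note: 'val_loss' has to be logged in validation_step for the monitor to find it;
# with strict=False a missing metric does not raise an error.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)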
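To show what `patience=3` with `mode='min'` means, here is a simplified sketch of the stopping rule in plain Python. It is not Lightning's implementation; the function name `early_stop_epoch` and the loss values are made up for illustration.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 0-based epoch at which a min-mode early stop fires, or None."""
    best = float("inf")
    wait = 0  # validation runs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# The best loss (0.98) is reached at epoch 2; the next three epochs do not beat it.
losses = [1.20, 1.05, 0.98, 0.99, 1.01, 1.00, 0.97]
print(early_stop_epoch(losses))
```

With these values the rule fires at epoch 5, so the later improvement at epoch 6 is never seen. A larger `patience` trades longer training for tolerance of such plateaus.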

With PyTorch Lightning, the trainer takes care of setting up the training, validation, and test loops. It also provides other goodies such as:

  • fast dev run (a quick pass through the training loop to check the model for errors)
  • metric logging
  • switching between accelerators (GPU/TPU) without any code change 💯
EPOCHS=100
learning_rate=1e-3
trainer = pl.Trainer(
    fast_dev_run=False, 
    gpus=1, 
    early_stop_callback=early_stop, 
    max_epochs=EPOCHS,
    logger=comet_logger
)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]
model = CIFARTenLitModelV2(backbone, learning_rate)
start = time()
print(f"start: {datetime.fromtimestamp(start)}")
# train
trainer.fit(model, train_loader, val_dataloaders=val_loader)
stop = time()
  | Name     | Type   | Params
------------------------------------
0 | backbone | ResNet | 24 M  
1 | fc1      | Linear | 2 K   


start: 2020-10-16 18:30:03.064633





Saving latest checkpoint..
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     epoch [800]     : (0, 99)
COMET INFO:     train_acc [700] : (0.359375, 0.8906250596046448)
COMET INFO:     val_acc [100]   : (0.500781238079071, 0.550976574420929)
COMET INFO:   Uploads:
COMET INFO:     code                     : 1 (12 KB)
COMET INFO:     environment details      : 1
COMET INFO:     filename                 : 1
COMET INFO:     git metadata             : 1
COMET INFO:     git-patch (uncompressed) : 1 (797 KB)
COMET INFO:     installed packages       : 1
COMET INFO:     notebook                 : 1
COMET INFO:     os packages              : 1
COMET INFO: ---------------------------





COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
took = stop - start
print(f"Total training time: {took//60 : .0f}m {took%60:.0f}s")
Total training time:  16m 14s
trainer.test(model, test_dataloaders=test_loader)
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d






--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.5159, device='cuda:0')}
--------------------------------------------------------------------------------






[{'test_acc': 0.5159216523170471}]
hyper_params = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "learning_rate": learning_rate,
    "num_epochs": EPOCHS,
}
comet_logger.experiment.log_parameters(hyper_params)
comet_logger.experiment.end()
COMET INFO: -----------------------------------
COMET INFO: Comet.ml ExistingExperiment Summary
COMET INFO: -----------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/katnoria/cf10-pl/8bfe9941b27b41339274f8ff23e9f79d
COMET INFO:   Metrics:
COMET INFO:     epoch    : 99
COMET INFO:     test_acc : 0.5159216523170471
COMET INFO:   Parameters:
COMET INFO:     batch_size    : 128
COMET INFO:     learning_rate : 0.001
COMET INFO:     num_epochs    : 100
COMET INFO:     num_workers   : 12
COMET INFO: -----------------------------------
COMET INFO: Uploading stats to Comet before program termination (may take several seconds)

Next Steps

This is a very simple example of training a CIFAR10 classifier using a pre-trained network. It's your turn to turn the knobs and see if you can get the model to generalise better. Some ideas:

  • make the model overfit your training data
  • regularize the model to generalize better
  • increase/decrease model capacity based on what you find in the above steps
  • add image augmentation
  • use a hyperparameter tuning library to find the best combination of settings
  • roll your own model from scratch; you can use the tuning library to help design the network too
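As a starting point for the image augmentation idea, the sketch below hand-rolls a random crop with padding plus a random horizontal flip on a single image tensor. In practice torchvision's `transforms.RandomCrop(32, padding=4)` and `transforms.RandomHorizontalFlip()` would be the natural choice inside the training transform; this version only illustrates what those operations do. The function name `random_crop_flip` and its parameters are assumptions for this example.

```python
import torch
import torch.nn.functional as F

def random_crop_flip(img, padding=4, generator=None):
    """Pad a (C, H, W) image, take a random H x W crop, then flip it half the time."""
    _, h, w = img.shape
    # Zero-pad the width and height on all four sides.
    padded = F.pad(img, (padding, padding, padding, padding))
    # Pick the top-left corner of the crop inside the padded image.
    top = int(torch.randint(0, 2 * padding + 1, (1,), generator=generator))
    left = int(torch.randint(0, 2 * padding + 1, (1,), generator=generator))
    crop = padded[:, top:top + h, left:left + w]
    if torch.rand(1, generator=generator).item() < 0.5:
        crop = torch.flip(crop, dims=[2])  # horizontal flip along the width axis
    return crop

gen = torch.Generator().manual_seed(0)
image = torch.randn(3, 32, 32)  # stand-in for one normalized CIFAR10 image
augmented = random_crop_flip(image, generator=gen)
print(tuple(augmented.shape))
```

Augmentation like this should only be applied to the training split; the validation and test loaders would keep the plain transform.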

See also