πŸ“ PyTorch Snippets


Let me start by saying that the PyTorch website is awesome. The front page quickly gets you started with the installation, the tutorials, and the documentation.

This post mainly covers the things that I wish to find quickly and serves as something I can refer to regularly.

How to use numpy data in DataLoaders?

If you ever find yourself testing a small dataset available as a numpy array and wish to use it in PyTorch, you could create a simple Dataset class and load it via dataloader.


# Fetch scikit-learn's small 8x8 digits dataset as plain numpy arrays
from sklearn.datasets import load_digits

X_digits, y_digits = load_digits(return_X_y=True)
# Sanity-check the array shapes before wrapping them in a Dataset
print(X_digits.shape, y_digits.shape)

# Dataset
import torch
from torch.utils.data import Dataset

class NumpyDataset(Dataset):
    """Dataset wrapping an in-memory numpy array.

    References: https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class
    """

    def __init__(self, data, transform=None):
        # BUG FIX: in the original, this docstring was indented at the `def`
        # level, which is a SyntaxError — a function body must be indented
        # one level deeper than its `def` line.
        """
        parameters
        ----------
        data: numpy array
        transform: list of transforms to apply. e.g transforms.Compose([transforms.ToTensor()])
        """
        self.data = data
        self.transform = transform

    def __len__(self):
        # Number of samples = length along the first axis.
        return len(self.data)

    def __getitem__(self, idx):
        # Samplers may hand us a tensor of indices; convert to plain ints.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = self.data[idx]

        if self.transform:
            sample = self.transform(sample)

        return sample

You can now load the numpy data through a DataLoader:

# Convert samples to tensors, then scale a single channel to [-1, 1].
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# NOTE(review): `datadict` and `workers` are assumed to be defined by the
# surrounding notebook/script — confirm before copy-pasting.
train_loader = DataLoader(
    NumpyDataset(datadict['train'], transform=preprocess),
    batch_size=128,
    num_workers=workers,
    shuffle=True,
)


Stumped by the loss functions in torch.nn?

It happens every now and then, where I keep forgetting to instantiate the loss and pass params.

Use:

torch.nn.MSELoss()(x_hat, x)

Wrong usage:

torch.nn.MSELoss(x_hat, x)

DataLoader boilerplate

import torch
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader, random_split

# Fix the RNG seed so the random_split below is reproducible.
torch.manual_seed(42)

BATCH_SIZE = 128
NUM_WORKERS = 12

# Map PIL images to tensors and scale each of the 3 RGB channels to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download the CIFAR-10 training split into ./data (no-op if cached).
train_ds = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Hold out 5000 images for validation; the rest is the training set.
train, val = random_split(train_ds, [45000, 5000])

# Batched, shuffled iterator over the training subset.
train_loader = DataLoader(
    train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

# Display (CV datasets)
# Display images
# BUG FIX: Python 3 iterators have no `.next()` method (that was Python 2);
# use the built-in `next()` on the iterator instead.
images, labels = next(iter(train_loader))

# see: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
def imshow(img):
    """Display a (C, H, W) image tensor that was normalized with mean/std 0.5.

    NOTE(review): relies on `plt` (matplotlib.pyplot) and `np` (numpy) being
    imported by the surrounding script — confirm before copy-pasting.
    """
    img = img / 2 + 0.5     # unnormalize: undo Normalize([0.5], [0.5])
    npimg = img.numpy()
    plt.figure(figsize=(12, 12))
    # matplotlib expects (H, W, C), tensors are (C, H, W) — transpose axes
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

imshow(utils.make_grid(images))

Trainer boilerplate

import torch
from tqdm.notebook import tqdm

# Define the training loop
def train(epochs, net, metric_logger=None):
    """Train `net` for `epochs` passes over the global `train_loader`.

    Relies on globals defined elsewhere in the script: `train_loader`,
    `val_loader`, `device`, `optimizer`, `criterion`, `test_model`, `np`.

    parameters
    ----------
    epochs: int, number of passes over the training set
    net: the model to train (assumed already moved to `device`)
    metric_logger: unused placeholder for a logging hook
    """
    pbar = tqdm(range(epochs))
    best_loss = np.inf
    # `pbar` already yields 0..epochs-1, so no enumerate needed (the original
    # also shadowed the outer `idx` with the inner batch loop's `idx`).
    for epoch in pbar:
        # set the network into training mode (dropout/batchnorm behavior)
        net.train()
        epoch_loss = 0
        running_corrects = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            # zero grad the optimizer before each backward pass
            optimizer.zero_grad()

            # gradients are enabled by default in training; the original's
            # `with torch.set_grad_enabled(True):` was a no-op wrapper
            preds = net(images)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()

            # track losses and the number of correct predictions
            epoch_loss += loss.item()
            _, predicted = torch.max(preds.data, 1)
            running_corrects += torch.sum(predicted == labels).item()

        train_loss = epoch_loss / len(train_loader)  # mean loss per batch
        # BUG FIX: accuracy must be normalized by the number of SAMPLES;
        # len(train_loader) is the number of batches, which inflated the
        # metric by a factor of batch_size.
        train_acc = running_corrects / len(train_loader.dataset)
        # BUG FIX: `{train_loss:4f}` set a width of 4, not a precision of 4
        pbar.set_description(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f}")
        if epoch_loss < best_loss:
            # save(net, epoch, epoch_loss)
            best_loss = epoch_loss

        # Validation accuracy (comment, if you don't need this)
        val_acc = test_model(net, val_loader)
        print(f"{epoch+1}/{epochs}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {val_acc:.4f}")

Save Model

import os
from datetime import datetime

def save(model, epoch, loss, fname="pytorch_model.pt", dirname="models", debug=False):
    """Checkpoint `model` (plus optimizer state) under `<dirname>/<YYYYMMDD>/<fname>`.

    parameters
    ----------
    model: network whose state_dict is checkpointed
    epoch: epoch number stored alongside the weights
    loss: loss value stored alongside the weights
    fname: checkpoint file name
    dirname: base directory; a date-stamped subfolder is created inside it
    debug: when True, print the destination path

    NOTE(review): reads the global `optimizer` defined elsewhere in the script.
    """
    today = datetime.now().strftime("%Y%m%d")
    # BUG FIX: honor the `dirname` argument — the original overwrote it with
    # the hard-coded prefix "models", making the parameter dead.
    dirname = os.path.join(dirname, today)
    filename = os.path.join(dirname, fname)
    # exist_ok avoids a race between the exists() check and makedirs()
    os.makedirs(dirname, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, filename)  # BUG FIX: save to the computed path (was a garbled literal)

    if debug:
        print(f"Model saved to {filename}")

See also