📙 CIFAR-10 Classifiers: Part 6 - Speed up TensorFlow hyperparameter search using Optuna


CIFAR-10 Classifier: TensorFlow + Optuna Edition

Our objective is similar to the Keras-Tuner and Ray Tune notebooks:

  • Explore the Optuna optimization library for hyperparameter tuning
  • Find out if we can beat the test accuracy of a hand-tuned model: 69.5% (PyTorch)

Author: Katnoria | Created: 18-Oct-2020

1. Imports & Setup

import pickle
from time import time
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout
from tensorflow.keras.layers import BatchNormalization, Input, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import Model
import IPython
import optuna
from optuna.integration.tensorboard import TensorBoardCallback
def version_info(cls):
    print(f"{cls.__name__}: {cls.__version__}")
print("Version Used in this Notebook:")
version_info(tf)
version_info(tfds)
version_info(optuna)
Version Used in this Notebook:
tensorflow: 2.3.0
tensorflow_datasets: 3.2.1
optuna: 2.2.0
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
Num GPUs Available:  1
EPOCHS = 25
# EPOCHS = 2
BATCH_SIZE=128
IMG_SIZE=32
NUM_CLASSES=10

2. Dataset

TensorFlow Datasets already provides this dataset in a format that we can use out of the box.

https://github.com/optuna/optuna/blob/master/examples/tensorflow_eager_simple.py

def get_dataset():
    (ds_train, ds_test), metadata = tfds.load(
        'cifar10', split=['train', 'test'], shuffle_files=True, 
        with_info=True, as_supervised=True
    )
    
    train_ds = ds_train \
        .cache() \
        .batch(BATCH_SIZE, drop_remainder=True) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    
    test_ds = ds_test \
        .cache() \
        .batch(BATCH_SIZE, drop_remainder=True) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    return (train_ds, test_ds)
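
A quick sanity check (not part of the original notebook, just an optional step you might run) is to pull a single batch and confirm the shapes the model will see.

# Hypothetical sanity check: peek at one batch to confirm the shapes
train_ds, test_ds = get_dataset()
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)  # expected: (128, 32, 32, 3) (128,)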

3. Model

We will use the same transforms that we used in the hand-tuned TensorFlow notebooks.

3.1 Create Model

transforms = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
## Create Model
def create_model(trial):
    """
    Create a simple CIFAR-10 model that uses ResNet50 as its backbone.
    
    Params:
    -------
    trial: optuna Trial object
    """
    inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
    x = transforms(inputs)
    x = tf.keras.applications.resnet.preprocess_input(x)
    x = tf.keras.applications.ResNet50(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False)(x, training=False)
    # Flatten or GAP
    use_gap = trial.suggest_categorical('use_gap', [True, False])
    if use_gap:
        x = GlobalAveragePooling2D()(x)
    else:
        x = Flatten()(x)
    # Dense Layer Units
    num_hidden = trial.suggest_int('dense_1', 32, 128)
    # Activation
    activation = trial.suggest_categorical('activation', ['relu', 'selu', 'elu'])
    x = Dense(num_hidden, activation=activation)(x)
    # Dropout rate    
    drop_rate = trial.suggest_float('drop_rate', 0.0, 0.8)
    x = Dropout(drop_rate)(x)
    outputs = Dense(NUM_CLASSES)(x)
    model = tf.keras.Model(inputs, outputs)
    return model
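
If you want to sanity-check create_model without launching a study, Optuna's optuna.trial.FixedTrial can replay hand-picked values for every suggested parameter. The values below are arbitrary placeholders, not tuned results.

# Sketch: build the model once with fixed hyperparameters (values are arbitrary)
fixed_trial = optuna.trial.FixedTrial({
    'use_gap': True,
    'dense_1': 64,
    'activation': 'relu',
    'drop_rate': 0.2,
})
sanity_model = create_model(fixed_trial)
sanity_model.summary()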

3.2 Optimizers

We could also add the choice of optimizer to the search space. I’ll leave that for you to try; a sketch follows the cell below.

## Create Optimizer
def create_optimizer(trial):
    # LR
    lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    return optimizer
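
As a starting point, here is one way that could look (a sketch, not used in the runs below): suggest the optimizer name as a categorical parameter and map it to the matching Keras optimizer.

# Sketch only: extend the search space with the optimizer choice itself
def create_optimizer_v2(trial):
    lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    name = trial.suggest_categorical('optimizer', ['adam', 'rmsprop', 'sgd'])
    if name == 'adam':
        return tf.keras.optimizers.Adam(learning_rate=lr)
    if name == 'rmsprop':
        return tf.keras.optimizers.RMSprop(learning_rate=lr)
    return tf.keras.optimizers.SGD(learning_rate=lr, momentum=0.9)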

3.3 Training Loop

## Training / evaluation loop
def train(model, optimizer, dataset, mode="eval"):
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
    mean_loss = tf.keras.metrics.Mean(name="loss")
    for images, labels in dataset:
        # Record the forward pass on the tape; gradients are only used in train mode
        with tf.GradientTape() as tape:
            predictions = model(images, training=(mode == "train"))
            loss = loss_object(labels, predictions)
        if mode == "train":
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        # Update the running metrics for this pass over the dataset
        accuracy(labels, predictions)
        mean_loss(loss)
    return accuracy.result(), mean_loss.result()

4. Trials

4.1 Setup Objective

We define the objective function that Optuna should optimize. In our case, it's the test accuracy after EPOCHS epochs of training. You can improve search efficiency by letting Optuna prune unpromising trials.

def objective(trial):
    # dataset
    train_ds, test_ds = get_dataset()
    # model
    model = create_model(trial)
    # optimizer
    optimizer = create_optimizer(trial)
    # train
    for epoch in range(EPOCHS):
        train_acc, train_loss = train(model, optimizer, train_ds, "train")
        test_acc, test_loss = train(model, optimizer, test_ds, "eval")
        trial.report(test_acc, epoch)
        print(f"train_accuracy:{train_acc:.4f}, train_loss: {train_loss:.4f}, test_acc: {test_acc:.4f}")
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
        
    return test_acc
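
Note that trial.should_prune() only has an effect if the study has a pruner. optuna.create_study uses a MedianPruner by default; the snippet below is a sketch of passing one explicitly with tweaked settings (the numbers are assumptions, not tuned values).

# Sketch: a more patient MedianPruner - wait for 2 complete trials and
# 5 reported epochs before pruning anything
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=5)
# study = optuna.create_study(direction="maximize", pruner=pruner)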

4.2 Run Trials

We are now ready to run Optuna and find the best set of hyperparameters.

# Track using Tensorboard
tensorboard_cb = TensorBoardCallback("./logs/", metric_name="accuracy")

start = time()
# Run
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=1, timeout=600, callbacks=[tensorboard_cb])
stop = time()
/home/ashish/miniconda3/envs/tf2_3/lib/python3.6/site-packages/ipykernel_launcher.py:2: ExperimentalWarning: TensorBoardCallback is experimental (supported from v2.0.0). The interface can change in the future.
  
[I 2020-10-21 11:22:21,476] A new study created in memory with name: no-name-0fb6a9b3-81e4-4b13-a494-974ec9cea169


train_accuracy:0.5205, train_loss: 1.3534, test_acc: 0.6755
train_accuracy:0.6663, train_loss: 0.9624, test_acc: 0.7300
train_accuracy:0.7090, train_loss: 0.8409, test_acc: 0.7423
train_accuracy:0.7325, train_loss: 0.7766, test_acc: 0.7612
train_accuracy:0.7474, train_loss: 0.7279, test_acc: 0.7724
train_accuracy:0.7648, train_loss: 0.6791, test_acc: 0.7628
train_accuracy:0.7727, train_loss: 0.6483, test_acc: 0.7711
train_accuracy:0.7870, train_loss: 0.6148, test_acc: 0.7734
train_accuracy:0.7959, train_loss: 0.5919, test_acc: 0.7776
train_accuracy:0.8052, train_loss: 0.5624, test_acc: 0.7837
train_accuracy:0.8098, train_loss: 0.5452, test_acc: 0.7851
train_accuracy:0.8191, train_loss: 0.5196, test_acc: 0.7915
train_accuracy:0.8255, train_loss: 0.4976, test_acc: 0.7678
train_accuracy:0.8311, train_loss: 0.4801, test_acc: 0.7833
train_accuracy:0.8385, train_loss: 0.4645, test_acc: 0.7758
train_accuracy:0.8429, train_loss: 0.4510, test_acc: 0.7774
train_accuracy:0.8511, train_loss: 0.4288, test_acc: 0.7824
train_accuracy:0.8561, train_loss: 0.4120, test_acc: 0.7764
train_accuracy:0.8599, train_loss: 0.3994, test_acc: 0.7807
train_accuracy:0.8658, train_loss: 0.3825, test_acc: 0.7890
train_accuracy:0.8727, train_loss: 0.3672, test_acc: 0.7817
train_accuracy:0.8773, train_loss: 0.3537, test_acc: 0.7930
train_accuracy:0.8793, train_loss: 0.3460, test_acc: 0.7887
train_accuracy:0.8819, train_loss: 0.3397, test_acc: 0.7755


train_accuracy:0.8877, train_loss: 0.3200, test_acc: 0.7919

[I 2020-10-21 11:38:29,271] Trial 0 finished with value: 0.7918669581413269 and parameters: {'use_gap': False, 'dense_1': 69, 'activation': 'relu', 'drop_rate': 0.07607458016208418, 'learning_rate': 8.362974203545105e-05}. Best is trial 0 with value: 0.7918669581413269.

took = stop - start
print(f"Total time: {took//60 : .0f}m {took%60:.0f}s")

4.3 Inspect

Print out the information about the trials.

pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

print(f"Finished Trials: {len(study.trials)}")
print(f"Pruned Trials: {len(pruned_trials)}")
print(f"Completed Trials: {len(complete_trials)}")
Finished Trials: 1
Pruned Trials: 0
Completed Trials: 1
trial = study.best_trial
print(trial.value)
0.7918669581413269
for k,v in trial.params.items():
    print(f"{k}: {v}")
use_gap: False
dense_1: 69
activation: relu
drop_rate: 0.07607458016208418
learning_rate: 8.362974203545105e-05
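
Once a few more trials have finished, Optuna's built-in visualizations (they require plotly) are a quick way to inspect the search; a sketch:

# Sketch: interactive plots over the study history (most useful with >1 trial)
# optuna.visualization.plot_optimization_history(study).show()
# optuna.visualization.plot_param_importances(study).show()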

5. Conclusion

A single Optuna trial with a ResNet50 backbone reached 79.2% test accuracy within 25 epochs, comfortably beating the 69.5% hand-tuned PyTorch baseline we set out to match. Running more trials, and letting the pruner cut the unpromising ones early, should improve on this further.