CIFAR-10 Classifier: TensorFlow + Optuna Edition
Our objective is similar to the Keras-Tuner and Ray Tune notebooks:
- Explore the Optuna optimization library for hyperparameter tuning
- Find out if we can beat the test accuracy of a hand-tuned model (69.5%, PyTorch)
Author: Katnoria | Created: 18-Oct-2020
1. Imports & Setup
import pickle
from time import time
from datetime import datetime
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout
from tensorflow.keras.layers import BatchNormalization, Input, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras import Model
import IPython
import optuna
from optuna.integration.tensorboard import TensorBoardCallback
def version_info(cls):
print(f"{cls.__name__}: {cls.__version__}")
print("Version Used in this Notebook:")
version_info(tf)
version_info(tfds)
version_info(optuna)
Version Used in this Notebook:
tensorflow: 2.3.0
tensorflow_datasets: 3.2.1
optuna: 2.2.0
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
Num GPUs Available: 1
EPOCHS = 25
# EPOCHS = 2
BATCH_SIZE=128
IMG_SIZE=32
NUM_CLASSES=10
2. Dataset
TensorFlow Datasets already provides CIFAR-10 in a format that we can use out of the box. See also the official Optuna TensorFlow (eager) example:
https://github.com/optuna/optuna/blob/master/examples/tensorflow_eager_simple.py
def get_dataset():
(ds_train, ds_test), metadata = tfds.load(
'cifar10', split=['train', 'test'], shuffle_files=True,
with_info=True, as_supervised=True
)
train_ds = ds_train \
.cache() \
.batch(BATCH_SIZE, drop_remainder=True) \
.prefetch(tf.data.experimental.AUTOTUNE)
test_ds = ds_test \
.cache() \
.batch(BATCH_SIZE, drop_remainder=True) \
.prefetch(tf.data.experimental.AUTOTUNE)
return (train_ds, test_ds)
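A quick sanity check of the input pipeline: pull one batch and confirm its shape and dtype (the expected values follow from BATCH_SIZE and IMG_SIZE defined above).
# Sanity check (illustrative): fetch a single batch from the training pipeline
train_ds, test_ds = get_dataset()
images, labels = next(iter(train_ds))
print(images.shape, labels.shape)  # expected: (128, 32, 32, 3) (128,)
print(images.dtype, labels.dtype)  # TFDS yields uint8 images and int64 labels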
3. Model
We will use the same transforms that we used in the hand-tuned TensorFlow notebooks.
3.1 Create Model
transforms = tf.keras.Sequential([
tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])
## Create Model
def create_model(trial):
"""
Create a simple CIFAR-10 model that uses ResNet50 as its backbone.
Params:
-------
trial: optuna Trial object
"""
inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = transforms(inputs)
x = tf.keras.applications.resnet.preprocess_input(x)
x = tf.keras.applications.ResNet50(input_shape=(IMG_SIZE, IMG_SIZE, 3), include_top=False)(x, training=False)
# Flatten or GAP
use_gap = trial.suggest_categorical('use_gap', [True, False])
if use_gap:
x = GlobalAveragePooling2D()(x)
else:
x = Flatten()(x)
# Dense Layer Units
num_hidden = trial.suggest_int('dense_1', 32, 128)
# Activation
activation = trial.suggest_categorical('activation', ['relu', 'selu', 'elu'])
x = Dense(num_hidden, activation=activation)(x)
# Dropout rate
drop_rate = trial.suggest_float('drop_rate', 0.0, 0.8)
x = Dropout(drop_rate)(x)
outputs = Dense(NUM_CLASSES)(x)
model = tf.keras.Model(inputs, outputs)
return model
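Before wiring this into a study, we can smoke-test the builder with an optuna.trial.FixedTrial, which replays a fixed set of hyperparameter values (the values below are arbitrary and only for illustration).
## Smoke test the model builder (illustrative values)
fixed_trial = optuna.trial.FixedTrial({
    'use_gap': True,
    'dense_1': 64,
    'activation': 'relu',
    'drop_rate': 0.3,
})
create_model(fixed_trial).summary()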
3.2 Optimizers
We could add various optimizers to the search space. I’ll leave that for you to try.
## Create Optimizer
def create_optimizer(trial):
# LR
lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
return optimizer
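As mentioned above, the optimizer type itself could be part of the search space. Here is a minimal sketch of one way to do it; the candidate list and the momentum range are my own choices, not part of the original notebook.
## Variant: also search over the optimizer type (sketch)
def create_optimizer_v2(trial):
    lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
    name = trial.suggest_categorical('optimizer', ['Adam', 'RMSprop', 'SGD'])
    if name == 'Adam':
        return tf.keras.optimizers.Adam(learning_rate=lr)
    if name == 'RMSprop':
        return tf.keras.optimizers.RMSprop(learning_rate=lr)
    # SGD with a tunable momentum term
    momentum = trial.suggest_float('momentum', 0.0, 0.99)
    return tf.keras.optimizers.SGD(learning_rate=lr, momentum=momentum)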
3.3 Training Loop
## Define Objective
def train(model, optimizer, dataset, mode="eval"):
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
mean_loss = tf.keras.metrics.Mean(name="loss")
for images, labels in dataset:
with tf.GradientTape() as tape:
predictions = model(images, training=(mode=='train'))
loss = loss_object(labels, predictions)
if mode == "train":
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
accuracy(labels, predictions)
mean_loss(loss)
return accuracy.result(), mean_loss.result()
4. Trials
4.1 Setup Objective
We define the objective function that Optuna should optimize. In our case, it's the test accuracy after training for EPOCHS epochs.
You can improve search efficiency by letting Optuna prune unpromising trials.
def objective(trial):
# dataset
train_ds, test_ds = get_dataset()
# model
model = create_model(trial)
# optimizer
optimizer = create_optimizer(trial)
# train
for epoch in range(EPOCHS):
train_acc, train_loss = train(model, optimizer, train_ds, "train")
test_acc, test_loss = train(model, optimizer, test_ds, "eval")
trial.report(test_acc, epoch)
print(f"train_accuracy:{train_acc:.4f}, train_loss: {train_loss:.4f}, test_acc: {test_acc:.4f}")
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
return test_acc
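The trial.report and trial.should_prune calls above only take effect through the study's pruner. optuna.create_study uses a MedianPruner by default; if you want to control it explicitly, a small sketch follows (the warm-up numbers are arbitrary):
## Optional: configure the pruner explicitly instead of relying on the default
# MedianPruner stops a trial whose intermediate accuracy falls below the median
# of earlier trials at the same epoch, once the warm-up period has passed.
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=5)
# study = optuna.create_study(direction="maximize", pruner=pruner)  # pass it in when creating the study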
4.2 Run Trials
We are now ready to run Optuna and find the best set of hyperparameters.
# Track using Tensorboard
tensorboard_cb = TensorBoardCallback("./logs/", metric_name="accuracy")
start = time()
# Run
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=1, timeout=600, callbacks=[tensorboard_cb])
stop = time()
/home/ashish/miniconda3/envs/tf2_3/lib/python3.6/site-packages/ipykernel_launcher.py:2: ExperimentalWarning: TensorBoardCallback is experimental (supported from v2.0.0). The interface can change in the future.
[I 2020-10-21 11:22:21,476] A new study created in memory with name: no-name-0fb6a9b3-81e4-4b13-a494-974ec9cea169
train_accuracy:0.5205, train_loss: 1.3534, test_acc: 0.6755
train_accuracy:0.6663, train_loss: 0.9624, test_acc: 0.7300
train_accuracy:0.7090, train_loss: 0.8409, test_acc: 0.7423
train_accuracy:0.7325, train_loss: 0.7766, test_acc: 0.7612
train_accuracy:0.7474, train_loss: 0.7279, test_acc: 0.7724
train_accuracy:0.7648, train_loss: 0.6791, test_acc: 0.7628
train_accuracy:0.7727, train_loss: 0.6483, test_acc: 0.7711
train_accuracy:0.7870, train_loss: 0.6148, test_acc: 0.7734
train_accuracy:0.7959, train_loss: 0.5919, test_acc: 0.7776
train_accuracy:0.8052, train_loss: 0.5624, test_acc: 0.7837
train_accuracy:0.8098, train_loss: 0.5452, test_acc: 0.7851
train_accuracy:0.8191, train_loss: 0.5196, test_acc: 0.7915
train_accuracy:0.8255, train_loss: 0.4976, test_acc: 0.7678
train_accuracy:0.8311, train_loss: 0.4801, test_acc: 0.7833
train_accuracy:0.8385, train_loss: 0.4645, test_acc: 0.7758
train_accuracy:0.8429, train_loss: 0.4510, test_acc: 0.7774
train_accuracy:0.8511, train_loss: 0.4288, test_acc: 0.7824
train_accuracy:0.8561, train_loss: 0.4120, test_acc: 0.7764
train_accuracy:0.8599, train_loss: 0.3994, test_acc: 0.7807
train_accuracy:0.8658, train_loss: 0.3825, test_acc: 0.7890
train_accuracy:0.8727, train_loss: 0.3672, test_acc: 0.7817
train_accuracy:0.8773, train_loss: 0.3537, test_acc: 0.7930
train_accuracy:0.8793, train_loss: 0.3460, test_acc: 0.7887
train_accuracy:0.8819, train_loss: 0.3397, test_acc: 0.7755
train_accuracy:0.8877, train_loss: 0.3200, test_acc: 0.7919
[I 2020-10-21 11:38:29,271] Trial 0 finished with value: 0.7918669581413269 and parameters: {'use_gap': False, 'dense_1': 69, 'activation': 'relu', 'drop_rate': 0.07607458016208418, 'learning_rate': 8.362974203545105e-05}. Best is trial 0 with value: 0.7918669581413269.
took = stop - start
print(f"Total time: {took//60 : .0f}m {took%60:.0f}s")
4.3 Inspect
Print out the information about the trials.
pruned_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
complete_trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
print(f"Finished Trials: {len(study.trials)}")
print(f"Pruned Trials: {len(pruned_trials)}")
print(f"Completed Trials: {len(complete_trials)}")
Finished Trials: 1
Pruned Trials: 0
Completed Trials: 1
trial = study.best_trial
print(trial.value)
0.7918669581413269
for k,v in trial.params.items():
print(f"{k}: {v}")
use_gap: False
dense_1: 69
activation: relu
drop_rate: 0.07607458016208418
learning_rate: 8.362974203545105e-05
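Optuna also ships plotting helpers for inspecting a study; they become much more informative with more than one trial. A short sketch (requires plotly; parameter importances need several completed trials):
## Optional: built-in study visualizations
from optuna.visualization import plot_optimization_history, plot_param_importances
plot_optimization_history(study).show()   # objective value vs. trial number
# plot_param_importances(study).show()    # meaningful once several trials have completed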
5. Conclusion
With a single trial, the Optuna-tuned ResNet50 model reached about 79.2% test accuracy, comfortably ahead of the 69.5% hand-tuned PyTorch baseline. Running more trials, with pruning enabled, is the obvious next step.
See also
- Exploration Log of Contrastive Language-Image Pre-training
- 📙 CIFAR-10 Classifiers: Part 7 - Speed up PyTorch hyperparameter search using Ray Tune
- 📙 CIFAR-10 Classifiers: Part 5 - Speed up TensorFlow hyperparameter search using Ray Tune
- 📙 CIFAR-10 Classifiers: Part 4 - Build a Simple Image Classifier using PyTorch Lightning