Learn to Code in TensorFlow2: Part 3
We have the CIFAR10 data prepared in Part 1 and the ResNet18 model ready from Part 2. In this article, we will understand the small building blocks involved in training with custom loops, and then train our model to 90% accuracy.
Contents:
- Brief understanding of Gradient Tape
- Tensor Slices from the dataset
- Custom learning rate
- Other variables like batch size, momentum and weight decay
- Actual training, combining the above concepts
Brief understanding of Gradient Tape
tf.GradientTape records the operations of the forward pass so that the loss can be back-propagated.
The gradient tells us which way to move each parameter to reduce the loss, i.e. towards a minimum.
Example from https://www.tensorflow.org/api_docs/python/tf/GradientTape:
import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
    g.watch(x)   # constants are not tracked automatically, so watch x explicitly
    y = x * x
    z = y * y    # z = x^4
dz_dx = g.gradient(z, x)  # 108.0 (4 * x^3 at x = 3)
dy_dx = g.gradient(y, x)  # 6.0 (2 * x at x = 3)
del g  # drop the reference to the tape
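To connect this with training: once we have a gradient, a single SGD step simply moves the variable against it. A tiny hand-rolled sketch (illustrative values only, not part of our actual pipeline):
w = tf.Variable(3.0)
lr = 0.1
with tf.GradientTape() as tape:
    loss = w * w                  # a stand-in "loss"
grad = tape.gradient(loss, w)     # 2 * w = 6.0 (variables are watched automatically)
w.assign_sub(lr * grad)           # w: 3.0 -> 2.4, and the loss drops from 9.0 to 5.76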
Our training step follows the same pattern. Open a tape and run the forward pass; here the model plays the role of y = x * x, returning loss and correct, both float tensors:
with tf.GradientTape() as tape:
    loss, correct = model(x, y)
Next, get the trainable variables from the model. This is a list of variables whose shapes depend on the layers defined in the model: for our ResNet18, var[0] is a tensor of shape (3, 3, 3, 64) while var[16] is a tensor of shape (512,).
var = model.trainable_variables
Compute the gradients of the loss with respect to these variables. grads is a list with the same structure as var, and the update will move each variable against its gradient, towards a minimum.
grads = tape.gradient(loss, var)
Finally, apply these gradients with the SGD optimizer:
opt.apply_gradients(zip(grads, var))
Tensor Slices from the dataset:
from_tensor_slices takes a numpy.ndarray and slices it along the first dimension, giving a dataset whose elements are single tensors.
Tutorial : https://www.geeksforgeeks.org/tensorflow-tf-data-dataset-from_tensor_slices/
batch groups consecutive elements of the dataset into batches of the given size.
prefetch fetches the data for the next iteration in the background while the current one is being processed: https://www.tensorflow.org/guide/data_performance#prefetching
See the example below to understand each of these:
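As a minimal sketch with toy data (not the CIFAR10 pipeline itself), the three calls compose like this:
import numpy as np
import tensorflow as tf

data = np.arange(10)                           # a plain numpy array
ds = tf.data.Dataset.from_tensor_slices(data)  # dataset of 10 scalar tensors
ds = ds.batch(4)                               # -> [0..3], [4..7], [8, 9]
ds = ds.prefetch(tf.data.AUTOTUNE)             # prepare the next batch while the current one is used
                                               # (tf.data.experimental.AUTOTUNE on older TF2 releases)
for batch in ds:
    print(batch.numpy())                       # [0 1 2 3], [4 5 6 7], [8 9]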
Custom learning rate:
Since we will be writing the loop for each epoch ourselves, we can fully control the learning rate. We create a piecewise-linear schedule with np.interp, as sketched below.
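A minimal sketch of such a schedule; the break points here (ramp up over the first 5 epochs, then decay linearly to zero) are assumed for illustration:
import numpy as np

PEAK_LR, RAMP_EPOCHS, TOTAL_EPOCHS = 0.4, 5, 50   # assumed values

# np.interp linearly interpolates between the (epoch, lr) break points.
lr_schedule = lambda t: np.interp([t], [0, RAMP_EPOCHS, TOTAL_EPOCHS], [0.0, PEAK_LR, 0.0])[0]

print(lr_schedule(0.0))    # 0.0  (start of training)
print(lr_schedule(2.5))    # 0.2  (halfway up the ramp)
print(lr_schedule(50.0))   # 0.0  (end of training)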
Other variables like batch size, momentum and weight decay
BATCH_SIZE = 512 #@param {type:"integer"}
MOMENTUM = 0.7 #@param {type:"number"}
LEARNING_RATE = 0.4 #@param {type:"number"}
WEIGHT_DECAY = 5e-4 #@param {type:"number"}
EPOCHS = 50 #@param {type:"integer"}
We are now all set to train a model.
Training:
First we will define the SGD optimizer:
batches_per_epoch = 50000 // BATCH_SIZE + 1   # CIFAR10 has 50,000 training images
global_step = 0                               # incremented once per training batch in the loop below
lr_func = lambda: lr_schedule(global_step / batches_per_epoch) / BATCH_SIZE
opt = tf.keras.optimizers.SGD(learning_rate=lr_func, momentum=MOMENTUM, nesterov=True)
Passing lr_func as a callable means the optimizer re-evaluates the learning rate on every step; the division by BATCH_SIZE makes sense when the model's loss is summed, rather than averaged, over the batch.
Then, we combine all of the above concepts to write the training loop.
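A minimal sketch of what that loop can look like. It assumes model(x, y) returns (summed loss, number of correct predictions) and that train_ds and test_ds are the batched tf.data pipelines built earlier; the names are illustrative rather than the exact source code:
for epoch in range(EPOCHS):
    train_loss = train_correct = 0.0
    for x, y in train_ds:
        with tf.GradientTape() as tape:
            loss, correct = model(x, y)
        var = model.trainable_variables
        grads = tape.gradient(loss, var)
        opt.apply_gradients(zip(grads, var))   # lr_func is re-evaluated at every step
        global_step += 1                       # advances the learning-rate schedule
        train_loss += loss.numpy()
        train_correct += correct.numpy()

    test_loss = test_correct = 0.0
    for x, y in test_ds:
        loss, correct = model(x, y)            # no tape needed for evaluation
        test_loss += loss.numpy()
        test_correct += correct.numpy()

    # CIFAR10 has 50,000 training and 10,000 test images
    print(f"epoch {epoch + 1}: "
          f"train acc {train_correct / 50000:.4f}, "
          f"test acc {test_correct / 10000:.4f}")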
After training for 150 epochs, we got 90% validation accuracy, and here are the graphs:
- Model Accuracy w.r.t Epochs
- Model Loss w.r.t Epochs
- Learning Rate w.r.t Epochs
Open the source code, run it once end to end, and then start making changes to understand each block with the help of this article.
We have seen how easily we can code in TensorFlow2. Now we can make custom changes to the model, parameters, and data augmentation, and try to achieve higher accuracy.