How to stop training a neural network model using callbacks? Simple answer

It is a useful hack with TensorFlow and Keras.

Introduction

Often, when you're training a neural network in TensorFlow, the model reaches a high accuracy well before the final epoch. If you want to stop training as soon as the model reaches the accuracy or loss you are aiming for, this tutorial is for you. Stopping at that point gives us what we want (good model weights) and avoids wasting resources (time and computation power).

In this brief tutorial, let's learn how to achieve this in TensorFlow and Keras, using the callback approach, in 4 simple steps.

Let's do it.

# Import tensorflow
import tensorflow as tf

1. First, set the accuracy threshold up to which you want to train your model.

ACCURACY_THRESHOLD = 0.95

2. Now, implement the callback class and function to stop training when accuracy reaches ACCURACY_THRESHOLD.


# Custom callback class that stops training
# when accuracy exceeds ACCURACY_THRESHOLD (e.g. 0.95)
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # The key is 'accuracy' in TF 2.x with metrics=['accuracy'];
        # older Keras versions reported it as 'acc'
        accuracy = logs.get('accuracy', logs.get('acc'))
        if accuracy is not None and accuracy > ACCURACY_THRESHOLD:
            print("\nReached %2.2f%% accuracy, so stopping training!!" % (ACCURACY_THRESHOLD * 100))
            self.model.stop_training = True

What exactly is going on here? We create a new class by extending tf.keras.callbacks.Callback and implement its on_epoch_end() method, which Keras invokes at the end of each epoch. In that method we fetch the accuracy for the epoch from logs and, if it is greater than our threshold, set the model's stop_training attribute to True, which ends training.

Reference - https://faroit.com/keras-docs/2.1.3/callbacks/
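The name under which accuracy appears in logs depends on how the model was compiled and on the Keras version, so it can help to print the available keys once before relying on one. The sketch below is only an illustration: the LogKeysCallback class, the tiny random dataset, and the one-layer model are all assumptions made so the snippet runs on its own.

```python
import numpy as np
import tensorflow as tf


class LogKeysCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: records the metric names Keras reports each epoch."""

    def __init__(self):
        super().__init__()
        self.seen_keys = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.seen_keys = sorted(logs.keys())
        print("Epoch %d logs keys: %s" % (epoch + 1, self.seen_keys))


# Tiny synthetic dataset, used only to make the sketch runnable
x = np.random.rand(32, 4).astype("float32")
y = (x.sum(axis=1) > 2.0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

key_logger = LogKeysCallback()
model.fit(x, y, epochs=1, verbose=0, callbacks=[key_logger])
```

Whatever key shows up for accuracy here is the one to read inside on_epoch_end().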

3. Instantiate an object of the myCallback class.

callbacks = myCallback()

Next, build a DNN or Conv-Net model following the normal steps of TensorFlow or Keras. The callback that we have built above will be used while training the model using the fit() method.

4. Simply pass callbacks=[<the newly instantiated object of the myCallback class>] as an argument to the fit() method.

model.fit(x_train, y_train, epochs=20, callbacks=[callbacks])

And that’s all! While training, as soon as accuracy reaches the value set in ACCURACY_THRESHOLD, training will be stopped.

To tie it all together, here’s a complete code snippet.
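What follows is a minimal sketch of how the four steps might fit together, not a definitive implementation. The synthetic dataset and the small two-layer model are assumptions chosen so that the snippet runs without downloading anything; in practice x_train, y_train, and the architecture would be your own.

```python
import numpy as np
import tensorflow as tf

# Step 1: accuracy threshold
ACCURACY_THRESHOLD = 0.95


# Step 2: callback that stops training once accuracy exceeds the threshold
class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # 'accuracy' in TF 2.x; older Keras versions used 'acc'
        accuracy = logs.get("accuracy", logs.get("acc"))
        if accuracy is not None and accuracy > ACCURACY_THRESHOLD:
            print("\nReached %2.2f%% accuracy, so stopping training!!"
                  % (ACCURACY_THRESHOLD * 100))
            self.model.stop_training = True


# Assumed data: random 2-D points labelled by which coordinate is larger
rng = np.random.default_rng(0)
x_train = rng.random((1000, 2)).astype("float32")
y_train = (x_train[:, 0] > x_train[:, 1]).astype("float32")

# Assumed model: a small dense network for binary classification
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Step 3: instantiate the callback
callbacks = myCallback()

# Step 4: pass it to fit()
history = model.fit(x_train, y_train, epochs=20, callbacks=[callbacks], verbose=0)
print("Epochs actually run:", len(history.epoch))
```

If the threshold is reached before epoch 20, fit() returns early and the number of epochs printed at the end is smaller than 20.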

Bonus

You can also access the loss and other metrics through the logs dictionary, override other methods such as on_epoch_begin, and more!

Reference - https://faroit.com/keras-docs/2.1.3/callbacks/
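As a rough illustration of the bonus ideas, the sketch below stops on a loss threshold instead of accuracy and uses on_epoch_begin to time each epoch. The class name LossThresholdCallback and the value LOSS_THRESHOLD = 0.1 are assumptions picked for the example.

```python
import time
import tensorflow as tf

LOSS_THRESHOLD = 0.1  # assumed value, tune it for your own problem


class LossThresholdCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: stops training once loss drops below LOSS_THRESHOLD."""

    def on_epoch_begin(self, epoch, logs=None):
        # Remember when the epoch started so its duration can be reported
        self._epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        elapsed = time.time() - self._epoch_start
        loss = logs.get("loss")
        print("Epoch %d took %.2f s" % (epoch + 1, elapsed))
        if loss is not None and loss < LOSS_THRESHOLD:
            print("\nLoss fell below %.2f, so stopping training!!" % LOSS_THRESHOLD)
            self.model.stop_training = True


# Usage, with a compiled model and training data of your own:
# model.fit(x_train, y_train, epochs=20, callbacks=[LossThresholdCallback()])
```

The same pattern works for validation metrics such as val_loss, provided validation data is passed to fit() so that those keys appear in logs.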

Conclusion

With a little imagination, this approach can be used in many creative ways, especially when we want to run quick PoCs to test and validate multiple DNN architectures. What other interesting uses can you think of? Please share your thoughts in the comments section below.

# Details

Published on January 31, 2024 2 min read

Machine Learning

Data science