Define custom callbacks

Many things can go wrong during ML training (incorrect driver versions, plateaued metrics, etc) that could result in wasted GPUs and time. The Aim callbacks API helps to define custom callbacks to be executed at any point during ML training - a programmable way to guard ML training from wasting resources.

Callbacks can actually encompass any programmable functionality, such as logging messages and sending notifications, or killing the training process when the given condition is met.

Callbacks

Terms:

  • callback - python function that implements a custom logic to be executed at a certain point during the training.

  • callbacks class - python class to group callback functions. Can be used to share state between different callbacks (example below).

  • event - represents an event to be bound to the training.

The callbacks API:

  • TrainingFlow - defines the training flow/events.

  • events.on.* - decorators to define when the callback function is executed.

The list of currently available training events:

  • events.on.training_started - called after the training start.

  • events.on.training_successfully_finished - called after the training is successfully finished, meaning no unexpected exceptions are raised, even a manual keyboard interruption (ctrl+C). Please note that programmatic early stopping is considered as a successful finish.

  • events.on.training_metrics_collected - called after the training metrics are calculated and ready to be logged. Typically called at each N batches.

  • events.on.validation_metrics_collected - called after the validation metrics are calculated and ready to be logged. Mostly called only once, after the validation loop is finished.

  • events.on.init - automatically called after the callbacks class initialization and before all the other events. Must not be called manually. Typical use-case can be initializing a shared state for callback functions (example below).

Example

The below example demonstrates how to implement custom callbacks to check and notify, when:

  • wrong driver versions are installed.

  • gnorm metrics explode.

  • model starts to overfit.

Defining the callbacks

from aim.sdk.callbacks import events

class MyCallbacks:
    @events.on.init  # Called when initializing the TrainingFlow object
    def init_gnorm_accumulators(self, **kwargs):
        # Initialize a state to collect gnorm values over training process
        self.gnorm_sum = 0
        self.gnorm_len = 0

    @events.on.init
    def init_ppl_accumulators(self, **kwargs):
        # Initialize a state to collect ppl values over training process
        self.ppl_sum = 0
        self.ppl_len = 0

    @events.on.init
    def init_metrics_accumulators(self, **kwargs):
        # Collect only the last 100 appended values
        self.last_train_metrics = deque(maxlen=100)

    # NOTE: all the above methods can be merged into one,
    #        but are separated for readability reasons

    @events.on.training_started
    def check_cuda_version(self, run: aim.Run, **kwargs):
        if run['__system_params', 'cuda_version'] != '11.6':
            run.log_warning("Wrong CUDA version is installed!")

    @events.on.training_metrics_collected
    def check_gnorm_and_notify(
        self,
        metrics: Dict[str, Any],
        step: int,
        # always denotes the number of *training* steps
        # `1 step per 4 batches` can be in case of gradient accumulation
        epoch: int,
        run: aim.Run,
        **kwargs
    ):
        current = metrics['gnorm'] # notice that it's the last one
        # thus we need to use self.* to collect gnorm values
        self.gnorm_sum += current
        self.gnorm_len += 1
        mean = self.gnorm_sum / self.gnorm_len

        if current > 1.15 * mean:
            run.log_warning(f'gnorms have exploded. mean: {mean}, '
                             'step {step}, epoch {epoch} ...')

    @events.on.training_metrics_collected
    def check_ppl_and_notify(
        self,
        metrics: Dict[str, Any],
        step: int,
        epoch: int,
        run: aim.Run,
        **kwargs
    ):
        current = metrics['ppl'] # notice that it's the last one
        # thus we need to use self.* to collect ppl values
        self.ppl_sum += current
        self.ppl_len += 1
        mean = self.ppl_sum / self.ppl_len

        if current > 1.15 * mean:
            run.log_warning(f'ppl have exploded. mean: {mean}, '
                             'step {step}, epoch {epoch} ...')

    @events.on.training_metrics_collected
    def store_last_train_metrics(
        self,
        metrics: Dict[str, Any],
        step: int,
        epoch: int,
        **kwargs,
    ):
        self.last_train_metrics.append(metrics)

    @events.on.validation_metrics_collected
    def check_overfitting(
        self,
        metrics: Dict[str, Any],
        epoch: int = None,
        run: aim.Run,
        **kwargs,
    ):
        mean_train_ppl = sum(
            metrics['ppl'] for metrics
            in self.last_train_metrics
        ) / len(self.last_train_metrics)

        if mean_train_ppl > 1.15 * metrics['ppl']:
            run.log_warning(f'I think we are overfitting on epoch={epoch}')

Registering the callbacks

from aim import TrainingFlow, Run

aim_run = Run()

training_flow = TrainingFlow(run=aim_run, callbacks=[MyCallbacks()])
# or
training_flow = TrainingFlow(run=aim_run)
training_flow.register(MyCallbacks())