Hooks are helper functions that are executed at several points of the training process, as described below:
on_start: executed before training starts
on_phase_start: executed at the beginning of every epoch (both train and test epochs)
on_forward: executed after every forward pass of the model
on_loss_and_meter: executed after the loss and meters are calculated
on_backward: executed after every backward pass of the model
on_update: executed after the model parameters are updated by the optimizer
on_step: executed after a single training (or test) iteration finishes
on_phase_end: executed after an epoch (train or test) finishes
on_end: executed at the very end of training
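To make the ordering concrete, here is a minimal sketch of a training loop that marks the point at which each hook fires. The `train` function and its `fire` callback are hypothetical names for this illustration, not VISSL's training loop. The model computation is omitted, and the sketch assumes that the backward and update steps run only during train phases.

```python
from typing import Callable, Sequence


def train(num_epochs: int, iters_per_epoch: int, fire: Callable[[str], None],
          phases: Sequence[str] = ("train", "test")) -> None:
    """Hypothetical toy loop: calls fire(<hook name>) wherever a hook would run."""
    fire("on_start")
    for _ in range(num_epochs):
        for phase in phases:
            fire("on_phase_start")
            for _ in range(iters_per_epoch):
                # The forward pass of the model would happen here.
                fire("on_forward")
                # The loss and the meters would be computed here.
                fire("on_loss_and_meter")
                if phase == "train":
                    # Backward pass (assumed to run in train phases only).
                    fire("on_backward")
                    # Optimizer step (assumed to run in train phases only).
                    fire("on_update")
                fire("on_step")
            fire("on_phase_end")
    fire("on_end")


events = []
train(num_epochs=1, iters_per_epoch=1, fire=events.append)
print(events)
```

With one epoch and one iteration per phase, the recorded sequence covers every hook point once for the train phase, then repeats the phase-level and forward/loss hooks for the test phase.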
Hooks are executed by calling task.run_hooks(SSLClassyHookFunctions.<type>.name) at the corresponding steps of training, where <type> is one of the hook points listed above.
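The dispatch mechanism can be sketched as follows. This is a simplified illustration, not VISSL's code. `HookFunctions` is a hypothetical stand-in for `SSLClassyHookFunctions`, and `Hook`, `CountStepsHook`, and `Task` are hypothetical classes. The key idea is that `run_hooks` receives a hook-point name (the enum member's `.name`) and calls the method with that name on every registered hook.

```python
from enum import Enum, auto
from typing import List


class HookFunctions(Enum):
    """Hypothetical stand-in for SSLClassyHookFunctions: one member per hook point."""
    on_start = auto()
    on_phase_start = auto()
    on_forward = auto()
    on_loss_and_meter = auto()
    on_backward = auto()
    on_update = auto()
    on_step = auto()
    on_phase_end = auto()
    on_end = auto()


class Hook:
    """Hypothetical base hook: every hook point defaults to a no-op."""

    def _noop(self, task) -> None:
        return None

    on_start = on_phase_start = on_forward = on_loss_and_meter = _noop
    on_backward = on_update = on_step = on_phase_end = on_end = _noop


class CountStepsHook(Hook):
    """Example hook that only reacts to on_step."""

    def __init__(self) -> None:
        self.steps = 0

    def on_step(self, task) -> None:
        self.steps += 1


class Task:
    """Hypothetical task object that owns the registered hooks."""

    def __init__(self, hooks: List[Hook]) -> None:
        self.hooks = hooks

    def run_hooks(self, hook_function_name: str) -> None:
        # Look up the method with that name on each hook and call it with the task.
        for hook in self.hooks:
            getattr(hook, hook_function_name)(self)


counter = CountStepsHook()
task = Task([counter])
task.run_hooks(HookFunctions.on_forward.name)  # inherited no-op
task.run_hooks(HookFunctions.on_step.name)     # increments counter.steps
print(counter.steps)
```

Giving the base class a no-op for every hook point means a concrete hook only needs to override the points it cares about, and the training loop can call `run_hooks` unconditionally.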
How to enable certain hooks in VISSL
VISSL supports many hooks, and users can choose which ones to use through simple configuration files. The hooks in VISSL fall into the following buckets:
TensorBoard hook: to enable this hook, set TENSORBOARD_SETUP.USE_TENSORBOARD=true and configure the TensorBoard settings.
Model Complexity hook: this hook performs a single forward pass of the model on a synthetic input and computes the number of FLOPs, parameters, and activations in the model. To enable this hook, set MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY=true and configure its settings.
Self-supervised Loss hooks: VISSL has hooks specific to self-supervised approaches such as MoCo and SwAV. These hooks perform intermediate operations that the self-supervised approach requires. For example:
MoCoHook is called after every forward pass of the model and updates the momentum encoder network. Users don't need to do anything special to use these hooks: if the user configuration file specifies the loss function for an approach, VISSL automatically enables the hooks for that approach.
Logging, checkpoint, and training variable update hooks: these hooks are used by default in VISSL and perform operations such as logging training progress (loss, LR, ETA, etc.) to stdout and saving checkpoints.
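The idea behind the Model Complexity hook can be illustrated on a toy network. The sketch below is not VISSL's implementation. It assumes a hypothetical model made only of fully connected layers followed by ReLU, runs one forward pass on a random synthetic input, and counts parameters, multiply-accumulate operations (MACs), and output activations layer by layer. `make_layer` and `profile` are illustrative names.

```python
import random
from typing import List, Tuple

# A layer is (weights, bias) with weights[out][in]; this layout is an assumption.
Layer = Tuple[List[List[float]], List[float]]


def make_layer(n_in: int, n_out: int, rng: random.Random) -> Layer:
    """Build a fully connected layer with random weights and zero bias."""
    weights = [[rng.uniform(-1.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def profile(layers: List[Layer], n_in: int, seed: int = 0) -> dict:
    """Run one forward pass on a synthetic input and count params, MACs, activations."""
    rng = random.Random(seed)
    x = [rng.uniform(-1.0, 1.0) for _ in range(n_in)]  # synthetic input
    params = macs = activations = 0
    for weights, bias in layers:
        params += sum(len(row) for row in weights) + len(bias)
        macs += len(weights) * len(x)  # one MAC per (output, input) pair
        # Linear layer followed by ReLU.
        x = [max(0.0, sum(w * v for w, v in zip(row, x)) + b)
             for row, b in zip(weights, bias)]
        activations += len(x)  # every output element of the layer
    return {"params": params, "macs": macs, "activations": activations}


rng = random.Random(0)
model = [make_layer(4, 8, rng), make_layer(8, 2, rng)]
print(profile(model, n_in=4))
```

Tools differ on whether they report FLOPs as MACs or as 2 × MACs, so complexity numbers are only comparable when they come from the same counter. Running a real forward pass, rather than reading a static layer list, is what lets such a counter discover the actual tensor shapes in an arbitrary model.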
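As an illustration of what MoCoHook does after each forward pass, here is a minimal sketch of the MoCo momentum update. In this update, each key-encoder parameter becomes an exponential moving average of the matching query-encoder parameter: θ_k ← m·θ_k + (1 − m)·θ_q. The sketch stores parameters as plain Python lists and uses a hypothetical task object. `MomentumEncoderHook`, `HOOKS_FOR_LOSS`, and `default_hooks` are illustrative names, not VISSL's API. The registry only mimics the idea of enabling hooks based on the configured loss, and its `"moco_loss"` key is an assumed example.

```python
from types import SimpleNamespace
from typing import Dict, List, Type


class MomentumEncoderHook:
    """Hypothetical hook: EMA update of the momentum (key) encoder after each forward."""

    def __init__(self, momentum: float = 0.999) -> None:
        self.momentum = momentum

    def on_forward(self, task) -> None:
        m = self.momentum
        for name, query_values in task.query_params.items():
            key_values = task.key_params[name]
            # theta_k <- m * theta_k + (1 - m) * theta_q, element by element
            task.key_params[name] = [
                m * k + (1.0 - m) * q for k, q in zip(key_values, query_values)
            ]


# Hypothetical registry: approach-specific hooks keyed by the configured loss name.
HOOKS_FOR_LOSS: Dict[str, List[Type]] = {"moco_loss": [MomentumEncoderHook]}


def default_hooks(loss_name: str) -> list:
    """Instantiate the hooks implied by the loss in the user's config (if any)."""
    return [hook_cls() for hook_cls in HOOKS_FOR_LOSS.get(loss_name, [])]


task = SimpleNamespace(query_params={"w": [1.0, 2.0]}, key_params={"w": [0.0, 0.0]})
hook = MomentumEncoderHook(momentum=0.9)
hook.on_forward(task)
print(task.key_params["w"])  # roughly [0.1, 0.2]
```

Only the key encoder is modified. The query encoder is left to the optimizer, which is why this update fits naturally into a hook rather than into the loss itself.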
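The default logging and checkpoint hooks can be sketched in the same style. Both classes below are hypothetical, not VISSL's hooks. `LogProgressHook` estimates the ETA as the average time per finished iteration multiplied by the number of remaining iterations. `CheckpointHook` writes an assumed JSON-serializable `task.state` to disk every `period` phases.

```python
import json
import os
import time
from types import SimpleNamespace
from typing import Optional


def format_eta(seconds: float) -> str:
    """Render a duration in seconds as H:MM:SS."""
    total = int(round(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


class LogProgressHook:
    """Hypothetical hook: print iteration, loss, LR and ETA after every step."""

    def __init__(self, total_iters: int) -> None:
        self.total_iters = total_iters
        self.start_time: Optional[float] = None
        self.done = 0
        self.last_line = ""

    def on_start(self, task) -> None:
        self.start_time = time.perf_counter()

    def on_step(self, task) -> None:
        self.done += 1
        elapsed = time.perf_counter() - self.start_time
        # ETA = average time per finished iteration x iterations still to run.
        eta = elapsed / self.done * (self.total_iters - self.done)
        self.last_line = (f"iter {self.done}/{self.total_iters} loss={task.loss:.4f} "
                          f"lr={task.lr:g} eta={format_eta(eta)}")
        print(self.last_line)


class CheckpointHook:
    """Hypothetical hook: dump task.state to JSON every `period` phases."""

    def __init__(self, checkpoint_dir: str, period: int = 1) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.period = period
        self.phases_done = 0

    def on_phase_end(self, task) -> None:
        self.phases_done += 1
        if self.phases_done % self.period == 0:
            name = f"checkpoint_phase{self.phases_done}.json"
            with open(os.path.join(self.checkpoint_dir, name), "w") as f:
                json.dump(task.state, f)


task = SimpleNamespace(loss=0.5, lr=0.1, state={"phase": 1})
logger = LogProgressHook(total_iters=10)
logger.on_start(task)
logger.on_step(task)
```

Tying checkpointing to `on_phase_end` rather than `on_step` keeps disk writes to one per phase (or one every `period` phases), which matches the coarse granularity at which checkpoints are usually needed.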