Hooks
Hooks are helper functions that are executed at specific points of the training process, as described below:
- on_start: executed before training starts.
- on_phase_start: executed at the beginning of every epoch (both train and test epochs).
- on_forward: executed after every forward pass.
- on_loss_and_meter: executed after the loss and meters are calculated.
- on_backward: executed after every backward pass of the model.
- on_update: executed after the model parameters are updated by the optimizer.
- on_step: executed after a single training (or test) iteration finishes.
- on_phase_end: executed after an epoch (train or test) finishes.
- on_end: executed at the very end of training.
Hooks are executed through calls to task.run_hooks(SSLClassyHookFunctions.<type>.name) placed at the corresponding steps of the training loop.
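The dispatch mechanism above can be sketched in plain Python. This is a minimal, hypothetical illustration, not VISSL's implementation: HookFunctions stands in for SSLClassyHookFunctions, and Task, train, RecordingHook and StepCounterHook are names invented here. It assumes a hook only needs to define the hook points it cares about, and it treats every phase as a training phase (a real test phase would typically skip the backward and update steps).

```python
from enum import Enum, auto


class HookFunctions(Enum):
    """Hypothetical stand-in for SSLClassyHookFunctions: one member per hook point."""

    on_start = auto()
    on_phase_start = auto()
    on_forward = auto()
    on_loss_and_meter = auto()
    on_backward = auto()
    on_update = auto()
    on_step = auto()
    on_phase_end = auto()
    on_end = auto()


class Task:
    def __init__(self, hooks):
        self.hooks = hooks

    def run_hooks(self, hook_fn_name):
        # Call the method named hook_fn_name on every hook that defines it.
        for hook in self.hooks:
            fn = getattr(hook, hook_fn_name, None)
            if fn is not None:
                fn(self)


def train(task, num_phases=1, iters_per_phase=2):
    task.run_hooks(HookFunctions.on_start.name)
    for _ in range(num_phases):
        task.run_hooks(HookFunctions.on_phase_start.name)
        for _ in range(iters_per_phase):
            # the model forward pass would run here
            task.run_hooks(HookFunctions.on_forward.name)
            # the loss and meters would be computed here
            task.run_hooks(HookFunctions.on_loss_and_meter.name)
            # the backward pass would run here
            task.run_hooks(HookFunctions.on_backward.name)
            # the optimizer step would run here
            task.run_hooks(HookFunctions.on_update.name)
            task.run_hooks(HookFunctions.on_step.name)
        task.run_hooks(HookFunctions.on_phase_end.name)
    task.run_hooks(HookFunctions.on_end.name)


class RecordingHook:
    """Records the name of every hook point at which it is called."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes that are not defined normally.
        if name.startswith("on_"):
            return lambda task: self.calls.append(name)
        raise AttributeError(name)


class StepCounterHook:
    """Implements only on_step; run_hooks skips it at every other hook point."""

    def __init__(self):
        self.steps = 0

    def on_step(self, task):
        self.steps += 1


recorder, counter = RecordingHook(), StepCounterHook()
train(Task([recorder, counter]), num_phases=1, iters_per_phase=2)
print(counter.steps)  # 2
```

Because run_hooks looks methods up by name, a hook that implements only one hook point (like StepCounterHook) needs no empty stubs for the others.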
How to enable certain hooks in VISSL
VISSL supports many hooks. Users can configure which hooks to use from simple configuration files. The hooks in VISSL can be categorized into the following buckets:
- Tensorboard hook: to enable this hook, set TENSORBOARD_SETUP.USE_TENSORBOARD=true and configure the tensorboard settings.
- Model Complexity hook: this hook performs a single forward pass of the model on a synthetic input and computes the #FLOPs, #params and #activations in the model. To enable this hook, set MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY=true and configure it.
- Self-supervised Loss hooks: VISSL has hooks specific to self-supervised approaches like MoCo, SwAV, etc. These hooks are handy for performing intermediate operations required in self-supervision. For example, MoCoHook is called after every forward pass of the model and updates the momentum encoder network. Users don't need to do anything special to use these hooks: if the user configuration file specifies the loss function for an approach, VISSL automatically enables the hooks for that approach.
- Logging, checkpoint, training variable update hooks: these hooks are used by default in VISSL and perform operations such as logging the training progress (loss, LR, ETA, etc.) to stdout and saving checkpoints.
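The configuration-driven selection described in the list above might look roughly like this. It is a hypothetical sketch, not VISSL's code: the config is a plain nested dict rather than VISSL's config object, the returned hook labels are invented strings, and reading the loss name from LOSS.name and matching it against "moco_loss" are assumptions about the config layout.

```python
def get_nested(cfg, dotted_key, default=None):
    """Look up a dotted key such as 'TENSORBOARD_SETUP.USE_TENSORBOARD' in nested dicts."""
    node = cfg
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


# Hypothetical: which approach-specific hooks each loss name turns on.
LOSS_TO_HOOKS = {"moco_loss": ["moco"]}


def select_hooks(cfg):
    """Return hypothetical labels for the hooks a config would enable."""
    # Logging, checkpoint and training-variable-update hooks are on by default.
    hooks = ["logging", "checkpoint", "training_variable_update"]
    if get_nested(cfg, "TENSORBOARD_SETUP.USE_TENSORBOARD", False):
        hooks.append("tensorboard")
    if get_nested(cfg, "MODEL.MODEL_COMPLEXITY.COMPUTE_COMPLEXITY", False):
        hooks.append("model_complexity")
    # Self-supervised loss hooks follow from the loss name alone; no extra flag.
    hooks.extend(LOSS_TO_HOOKS.get(get_nested(cfg, "LOSS.name"), []))
    return hooks


cfg = {
    "TENSORBOARD_SETUP": {"USE_TENSORBOARD": True},
    "MODEL": {"MODEL_COMPLEXITY": {"COMPUTE_COMPLEXITY": False}},
    "LOSS": {"name": "moco_loss"},
}
print(select_hooks(cfg))
```

Keying the approach-specific hooks off the loss name is what lets users get them without setting any additional option.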
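The momentum encoder update mentioned for MoCoHook is, in the MoCo method, an exponential moving average of the query encoder's parameters: θ_k ← m·θ_k + (1 − m)·θ_q. The sketch below is a simplified illustration, not VISSL's MoCoHook: parameters are flat lists of floats instead of tensors, the task is a stand-in object, and MomentumEncoderHook and momentum_update are hypothetical names.

```python
from types import SimpleNamespace


def momentum_update(key_params, query_params, momentum=0.999):
    """In-place EMA update: key <- momentum * key + (1 - momentum) * query."""
    if len(key_params) != len(query_params):
        raise ValueError("both encoders must have the same number of parameters")
    for i, (k, q) in enumerate(zip(key_params, query_params)):
        key_params[i] = momentum * k + (1.0 - momentum) * q
    return key_params


class MomentumEncoderHook:
    """Hypothetical hook: after every forward pass, move the key encoder toward the query encoder."""

    def __init__(self, momentum=0.999):  # 0.999 is the value used in the MoCo paper
        self.momentum = momentum

    def on_forward(self, task):
        momentum_update(task.key_params, task.query_params, self.momentum)


# Usage: a stand-in task carrying both encoders' (flattened) parameters.
task = SimpleNamespace(key_params=[0.0, 1.0], query_params=[1.0, 1.0])
hook = MomentumEncoderHook(momentum=0.5)
hook.on_forward(task)
print(task.key_params)  # [0.5, 1.0]
```

A momentum close to 1 makes the key encoder change slowly, which is why the update can safely run after every forward pass.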