Source code for vissl.optimizers.optimizer_helper

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
from typing import Any, List

import torch.nn as nn
from vissl.utils.misc import is_apex_available


_CONV_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)

_BN_TYPES = (
    nn.BatchNorm1d,
    nn.BatchNorm2d,
    nn.BatchNorm3d,
    nn.SyncBatchNorm,  # pytorch SyncBN
)

if is_apex_available():
    import apex

    _BN_TYPES += (apex.parallel.SyncBatchNorm,)


def _get_bn_optimizer_params(
    module, regularized_params, unregularized_params, optimizer_config
):
    """
    Given a (Sync)BatchNorm module in the model, we separate the module params
    into regularized and unregularized (weight_decay=0) groups.
    """
    # this is called by get_optimizer_param_groups for BN-specific layers only
    if module.weight is not None:
        if optimizer_config["regularize_bn"]:
            regularized_params.append(module.weight)
        else:
            unregularized_params.append(module.weight)
    if module.bias is not None:
        if optimizer_config["regularize_bn"] and optimizer_config["regularize_bias"]:
            regularized_params.append(module.bias)
        else:
            unregularized_params.append(module.bias)
    return regularized_params, unregularized_params
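

# Illustrative sketch (not part of the original VISSL source): a minimal check of
# how _get_bn_optimizer_params splits a single BatchNorm layer. The plain dict
# below is a hypothetical stand-in for the real optimizer config object.
def _example_bn_param_split():
    bn = nn.BatchNorm2d(16)
    regularized, unregularized = _get_bn_optimizer_params(
        module=bn,
        regularized_params=[],
        unregularized_params=[],
        optimizer_config={"regularize_bn": False, "regularize_bias": True},
    )
    # with regularize_bn=False, both gamma (weight) and beta (bias) end up in the
    # unregularized list, i.e. they will receive weight_decay=0
    assert len(regularized) == 0 and len(unregularized) == 2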


def _filter_trainable(param_list: List[Any]) -> List[Any]:
    """
    Keep only the trainable parameters of the model and return the list of
    trainable params.
    """
    # Keep only the trainable params
    return list(filter(lambda x: x.requires_grad, param_list))
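

# Illustrative sketch (not part of the original VISSL source): freezing a parameter
# via requires_grad=False removes it from the filtered list.
def _example_filter_trainable():
    layer = nn.Linear(4, 2)
    layer.bias.requires_grad = False
    trainable = _filter_trainable(list(layer.parameters()))
    # only the weight survives the filter; the frozen bias is dropped
    assert len(trainable) == 1 and trainable[0] is layer.weight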


def get_optimizer_param_groups(
    model, model_config, optimizer_config, optimizer_schedulers
):
    """
    Go through all the layers, sort out which parameters should be regularized
    and which should not, and apply the optimization settings for the head/trunk.
    We filter the trainable params only and add them to the param_groups.

    Returns:
        param_groups (List[Dict]): [
            {
                "params": trunk_regularized_params,
                "lr": lr_value,
                "weight_decay": wd_value,
            },
            {
                "params": trunk_unregularized_params,
                "lr": lr_value,
                "weight_decay": 0.0,
            },
            {
                "params": head_regularized_params,
                "lr": head_lr_value,
                "weight_decay": head_weight_decay,
            },
            {
                "params": head_unregularized_params,
                "lr": head_lr_value,
                "weight_decay": 0.0,
            },
            {
                "params": remaining_regularized_params,
                "lr": lr_value,
            },
        ]
    """
    # if a different LR or weight decay value for the head is not specified, we use
    # the same LR/wd as the trunk.
    if not optimizer_config.head_optimizer_params.use_different_lr:
        assert "lr_head" in optimizer_schedulers

    # we create 4 param groups: trunk regularized, trunk unregularized, head
    # regularized and head unregularized. Unregularized can contain BN layers.
    trunk_regularized_params, trunk_unregularized_params = [], []
    head_regularized_params, head_unregularized_params = [], []
    # for anything else
    regularized_params = []
    for name, module in model.named_modules():
        # head, Linear/Conv layer
        if "head" in name and (
            isinstance(module, nn.Linear) or isinstance(module, _CONV_TYPES)
        ):
            head_regularized_params.append(module.weight)
            if module.bias is not None:
                if optimizer_config["regularize_bias"]:
                    head_regularized_params.append(module.bias)
                else:
                    head_unregularized_params.append(module.bias)
        # head, BN layer
        elif "head" in name and isinstance(module, _BN_TYPES):
            (
                head_regularized_params,
                head_unregularized_params,
            ) = _get_bn_optimizer_params(
                module,
                head_regularized_params,
                head_unregularized_params,
                optimizer_config,
            )
        # trunk, Linear/Conv
        elif isinstance(module, nn.Linear) or isinstance(module, _CONV_TYPES):
            trunk_regularized_params.append(module.weight)
            if module.bias is not None:
                if optimizer_config["regularize_bias"]:
                    trunk_regularized_params.append(module.bias)
                else:
                    trunk_unregularized_params.append(module.bias)
        # trunk, BN layer
        elif isinstance(module, _BN_TYPES):
            (
                trunk_regularized_params,
                trunk_unregularized_params,
            ) = _get_bn_optimizer_params(
                module,
                trunk_regularized_params,
                trunk_unregularized_params,
                optimizer_config,
            )
        elif len(list(module.children())) >= 0:
            # for any other layers that are not bn_types, conv_types or nn.Linear:
            # if the layers are leaf nodes and have parameters, we regularize them.
            # Similarly, if they are non-leaf nodes but have their own parameters,
            # we regularize them (set recurse=False).
            for params in module.parameters(recurse=False):
                regularized_params.append(params)

    # for non-trainable params, set requires_grad to False
    non_trainable_params = []
    for name, param in model.named_parameters():
        if name in model_config.NON_TRAINABLE_PARAMS:
            param.requires_grad = False
            non_trainable_params.append(param)

    trainable_params = _filter_trainable(model.parameters())
    trunk_regularized_params = _filter_trainable(trunk_regularized_params)
    trunk_unregularized_params = _filter_trainable(trunk_unregularized_params)
    head_regularized_params = _filter_trainable(head_regularized_params)
    head_unregularized_params = _filter_trainable(head_unregularized_params)
    regularized_params = _filter_trainable(regularized_params)
    logging.info(
        f"\nTrainable params: {len(trainable_params)}, \n"
        f"Non-Trainable params: {len(non_trainable_params)}, \n"
        f"Trunk Regularized Parameters: {len(trunk_regularized_params)}, \n"
        f"Trunk Unregularized Parameters: {len(trunk_unregularized_params)}, \n"
        f"Head Regularized Parameters: {len(head_regularized_params)}, \n"
        f"Head Unregularized Parameters: {len(head_unregularized_params)}, \n"
        f"Remaining Regularized Parameters: {len(regularized_params)}"
    )

    param_groups = [
        {
            "params": trunk_regularized_params,
            "lr": optimizer_schedulers["lr"],
            "weight_decay": optimizer_config.weight_decay,
        },
        {
            "params": trunk_unregularized_params,
            "lr": optimizer_schedulers["lr"],
            "weight_decay": 0.0,
        },
        {
            "params": head_regularized_params,
            "lr": optimizer_schedulers["lr_head"],
            "weight_decay": optimizer_config.head_optimizer_params.weight_decay,
        },
        {
            "params": head_unregularized_params,
            "lr": optimizer_schedulers["lr_head"],
            "weight_decay": 0.0,
        },
    ]
    if len(regularized_params) > 0:
        param_groups.append(
            {"params": regularized_params, "lr": optimizer_schedulers["lr"]}
        )

    return param_groups
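

# Illustrative sketch (not part of the original VISSL source): the returned
# param_groups plug directly into a torch optimizer. _ToyConfig is a hypothetical
# stand-in for VISSL's AttrDict-style config (it supports both attribute and item
# access, as get_optimizer_param_groups expects), and plain floats stand in for
# whatever LR values/schedulers the caller provides.
def _example_param_groups_usage():
    import torch

    class _ToyConfig(dict):
        __getattr__ = dict.__getitem__

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))
    optimizer_config = _ToyConfig(
        regularize_bn=False,
        regularize_bias=True,
        weight_decay=1e-4,
        head_optimizer_params=_ToyConfig(use_different_lr=False, weight_decay=1e-4),
    )
    model_config = _ToyConfig(NON_TRAINABLE_PARAMS=[])
    param_groups = get_optimizer_param_groups(
        model=model,
        model_config=model_config,
        optimizer_config=optimizer_config,
        optimizer_schedulers={"lr": 0.1, "lr_head": 0.1},
    )
    # every group carries its own lr/weight_decay, so no global lr is needed here
    optimizer = torch.optim.SGD(param_groups, momentum=0.9)
    return optimizer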