# Source code for vissl.models.heads.mlp

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

import torch
import torch.nn as nn
from vissl.models.heads import register_model_head
from vissl.utils.hydra_config import AttrDict

@register_model_head("mlp")
class MLP(nn.Module):
    """
    This module can be used to attach a combination of {Linear, BatchNorm, ReLU,
    Dropout} layers, fully configurable from the config file. The module also
    supports stacking multiple MLPs.

    Examples:
        Linear
        Linear -> BN
        Linear -> ReLU
        Linear -> Dropout
        Linear -> BN -> ReLU -> Dropout
        Linear -> ReLU -> Dropout
        Linear -> ReLU -> Linear -> ReLU -> ...
        Linear -> Linear -> ...
        ...

    Accepts a 2D input tensor. Also accepts a 4D input tensor of shape
    `N x C x 1 x 1`.
    """

    def __init__(
        self,
        model_config: AttrDict,
        dims: List[int],
        use_bn: bool = False,
        use_relu: bool = False,
        use_dropout: bool = False,
        use_bias: bool = True,
    ):
        """
        Args:
            model_config (AttrDict): dictionary config.MODEL in the config file.
                Reads HEAD.PARAMS_MULTIPLIER, and HEAD.BATCHNORM_EPS /
                HEAD.BATCHNORM_MOMENTUM when `use_bn` is set.
            dims (List[int]): dimensions of the linear layers. For
                `[d0, d1, ..., dk]` this stacks k Linear layers
                `d0 -> d1 -> ... -> dk`. Example: `[8192, 1000]` attaches
                `nn.Linear(8192, 1000, bias=True)`.
            use_bn (bool): whether to attach BatchNorm after each Linear layer
            use_relu (bool): whether to attach ReLU after (Linear (-> BN optional))
            use_dropout (bool): whether to attach Dropout after
                (Linear (-> BN -> ReLU optional))
            use_bias (bool): whether the Linear layers should have a bias or not
        """
        super().__init__()
        layers = []
        last_dim = dims[0]
        for dim in dims[1:]:
            layers.append(nn.Linear(last_dim, dim, bias=use_bias))
            if use_bn:
                layers.append(
                    nn.BatchNorm1d(
                        dim,
                        eps=model_config.HEAD.BATCHNORM_EPS,
                        momentum=model_config.HEAD.BATCHNORM_MOMENTUM,
                    )
                )
            if use_relu:
                layers.append(nn.ReLU(inplace=True))
            if use_dropout:
                layers.append(nn.Dropout())
            # The next Linear always consumes this layer's output, whether or
            # not a ReLU was attached; updating this only under `use_relu`
            # would mis-size every later Linear in a ReLU-free stack.
            last_dim = dim
        self.clf = nn.Sequential(*layers)
        # we use the default normal or uniform initialization for the layers
        # and allow users to scale the initialization.
        self.scale_weights(model_config)

    def scale_weights(self, model_config):
        """
        Multiply the weight and (if present) bias of every `nn.Linear` in this
        module by `model_config.HEAD.PARAMS_MULTIPLIER`, in place.

        Args:
            model_config (AttrDict): dictionary config.MODEL in the config file
        """
        params_multiplier = model_config.HEAD.PARAMS_MULTIPLIER
        # In-place edits of leaf parameters must not be recorded by autograd.
        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    m.weight.mul_(params_multiplier)
                    if m.bias is not None:
                        m.bias.mul_(params_multiplier)

    def forward(self, batch: torch.Tensor):
        """
        Args:
            batch (torch.Tensor): 2D torch tensor or 4D tensor of shape
                `N x C x 1 x 1` (or a list containing exactly one such tensor)

        Returns:
            out (torch.Tensor): 2D output torch tensor
        """
        if isinstance(batch, list):
            assert (
                len(batch) == 1
            ), "MLP input should be either a tensor (2D, 4D) or list containing 1 tensor."
            batch = batch[0]
        squeezed = torch.squeeze(batch)
        assert (
            len(squeezed.shape) <= 2
        ), f"MLP expected 2D input tensor or 4D tensor of shape NxCx1x1. got: {squeezed.shape}"
        if squeezed.dim() < 2 and batch.dim() >= 2:
            # torch.squeeze also removes a size-1 batch dim (N == 1) or channel
            # dim (C == 1); restore the N x C layout so BatchNorm1d accepts the
            # input and the output stays 2D.
            squeezed = batch.flatten(start_dim=1)
        out = self.clf(squeezed)
        return out