Source code for vissl.models.heads.linear_eval_mlp

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

import torch
import torch.nn as nn
from vissl.models.heads import register_model_head
from vissl.models.heads.mlp import MLP
from vissl.utils.hydra_config import AttrDict


@register_model_head("eval_mlp")
class LinearEvalMLP(nn.Module):
    """
    A standard Linear classification module that can be attached to several
    layers of the model to evaluate the representation quality of features.

    The layers attached are:
        BatchNorm2d -> Linear (1 or more)

    Accepts a 4D input tensor (N x C x H x W). A 2D input tensor (N x C) is
    also accepted: it is reshaped to (N x C x 1 x 1) so that the BatchNorm2d
    can be applied, and then processed like any 4D input. If no leading
    BatchNorm2d is wanted, use the "mlp" head directly.
    """

    def __init__(
        self,
        model_config: AttrDict,
        in_channels: int,
        dims: List[int],
        use_bn: bool = False,
        use_relu: bool = False,
    ) -> None:
        """
        Args:
            model_config (AttrDict): dictionary config.MODEL in the config file.
                HEAD.BATCHNORM_EPS and HEAD.BATCHNORM_MOMENTUM are read to
                configure the BatchNorm2d layer; the config is also passed on
                to the MLP.
            in_channels (int): number of channels the input has. This is used
                to size the BatchNorm2d layer.
            dims (List[int]): dimensions of the linear layer(s). Example
                [8192, 1000] means `nn.Linear(8192, 1000, bias=True)` is
                attached. Since the normalized input is flattened from dim 1
                before reaching the MLP, dims[0] presumably has to equal
                C * H * W of the input.
            use_bn (bool): forwarded unchanged to MLP as `use_bn` (see MLP for
                its exact meaning).
            use_relu (bool): forwarded unchanged to MLP as `use_relu` (see MLP
                for its exact meaning).
        """
        super().__init__()
        self.channel_bn = nn.BatchNorm2d(
            in_channels,
            eps=model_config.HEAD.BATCHNORM_EPS,
            momentum=model_config.HEAD.BATCHNORM_MOMENTUM,
        )
        self.clf = MLP(model_config, dims, use_bn=use_bn, use_relu=use_relu)

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        """
        Args:
            batch (torch.Tensor): 4D (N x C x H x W) torch tensor. This layer is
                meant to be attached at several parts of the model to evaluate
                feature representation quality. A 2D (N x C) tensor is
                unsqueezed to (N x C x 1 x 1) and then eval_mlp is applied.

        Returns:
            out (torch.Tensor): 2D output torch tensor produced by the MLP.

        Raises:
            ValueError: if the input is neither 2D nor 4D.
        """
        if batch.dim() == 2:
            # (N x C) -> (N x C x 1 x 1) so the BatchNorm2d can normalize it
            # per channel exactly like a 4D feature map.
            batch = batch.unsqueeze(2).unsqueeze(3)
        # Explicit check instead of `assert`: assert statements are removed
        # when Python runs with -O, which would silently skip this validation.
        if batch.dim() != 4:
            raise ValueError("Eval MLP head expects 4D tensor input")
        out = self.channel_bn(batch)
        # Collapse (C x H x W) into one feature axis for the linear layer(s).
        out = torch.flatten(out, start_dim=1)
        out = self.clf(out)
        return out