Source code for vissl.models.heads.swav_prototypes_head

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List

import torch
import torch.nn as nn
from vissl.models.heads import register_model_head
from vissl.utils.hydra_config import AttrDict


@register_model_head("swav_head")
class SwAVPrototypesHead(nn.Module):
    """
    SwAV head used in the https://arxiv.org/pdf/2006.09882.pdf paper.

    The head is composed of 2 parts:
        1) projection of features to a lower dimension like 128
        2) feature classification into clusters (also called prototypes)

    The projected features are L2 normalized before the clustering step.

    Input: 2D torch.tensor of shape (N x C)
    Output: a list of 2D torch.tensors: the projected embeddings of shape
            (N x dims[-1]) if return_embeddings is True, followed by one
            tensor of shape (N x num_clusters[i]) per prototype layer
    """
    def __init__(
        self,
        model_config: AttrDict,
        dims: List[int],
        use_bn: bool,
        num_clusters: List[int],
        use_bias: bool = True,
        return_embeddings: bool = True,
        skip_last_bn: bool = True,
        normalize_feats: bool = True,
    ):
        """
        Args:
            model_config (AttrDict): dictionary config.MODEL in the config file
            dims (List[int]): dimensions of the linear layers. Must have length
                at least 2.
                Example: [2048, 2048, 128] attaches
                Linear(2048, 2048) -> BN -> ReLU -> Linear(2048, 128)
            use_bn (bool): whether to attach BatchNorm after each Linear layer
            num_clusters (List[int]): number of prototypes or clusters per
                prototype layer. Typically 3000.
                Example: num_clusters=[3000] attaches 1 prototype layer,
                num_clusters=[3000, 3000] attaches 2 prototype layers
            use_bias (bool): whether the Linear layers should have a bias
            return_embeddings (bool): whether to also return the projected
                embeddings
            skip_last_bn (bool): whether to skip attaching BN + ReLU after the
                last Linear layer of the projection head.
                Example: [2048, 2048, 128] with skip_last_bn=True attaches
                Linear(2048, 2048) -> BN -> ReLU -> Linear(2048, 128)
                [2048, 2048, 128] with skip_last_bn=False attaches
                Linear(2048, 2048) -> BN -> ReLU -> Linear(2048, 128) -> BN -> ReLU

                This could be particularly useful when performing full
                finetuning on hidden layers.
            normalize_feats (bool): whether to L2 normalize the projected
                features before the prototype layers
        """
        super().__init__()
        self.normalize_feats = normalize_feats

        # build the projection head
        layers = []
        last_dim = dims[0]
        for i, dim in enumerate(dims[1:]):
            layers.append(nn.Linear(last_dim, dim, bias=use_bias))
            if (i == len(dims) - 2) and skip_last_bn:
                break
            if use_bn:
                layers.append(
                    nn.BatchNorm1d(
                        dim,
                        eps=model_config.HEAD.BATCHNORM_EPS,
                        momentum=model_config.HEAD.BATCHNORM_MOMENTUM,
                    )
                )
            layers.append(nn.ReLU(inplace=True))
            last_dim = dim
        self.projection_head = nn.Sequential(*layers)

        # prototypes (i.e. centroids) layers
        if len(num_clusters) > 0:
            self.nmb_heads = len(num_clusters)
            for i, k in enumerate(num_clusters):
                self.add_module(
                    "prototypes" + str(i), nn.Linear(dims[-1], k, bias=False)
                )
        else:
            self.nmb_heads = 0
        self.return_embeddings = return_embeddings
    def forward(self, batch: torch.Tensor):
        """
        Args:
            batch (2D torch.tensor): shape (N x C)

        Returns:
            a list of 2D torch.tensors: the L2-normalized embeddings of shape
            (N x dims[-1]) if return_embeddings is True, followed by one
            tensor of shape (N x num_clusters[i]) per prototype layer
        """
        batch = self.projection_head(batch)

        if self.normalize_feats:
            batch = nn.functional.normalize(batch, dim=1, p=2)

        out = []
        if self.return_embeddings:
            out.append(batch)

        if self.nmb_heads > 0:
            for i in range(self.nmb_heads):
                out.append(getattr(self, "prototypes" + str(i))(batch))
        return out
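
A minimal usage sketch, not part of the module above: it assumes AttrDict can be built from a plain nested dict carrying only the HEAD.BATCHNORM_EPS and HEAD.BATCHNORM_MOMENTUM keys the head reads, and the dimensions and values below are illustrative choices rather than VISSL defaults. It reuses the imports at the top of this module.

if __name__ == "__main__":
    # Illustrative config values only; in practice this comes from config.MODEL.
    model_config = AttrDict(
        {"HEAD": {"BATCHNORM_EPS": 1e-5, "BATCHNORM_MOMENTUM": 0.1}}
    )
    head = SwAVPrototypesHead(
        model_config,
        dims=[2048, 2048, 128],  # Linear(2048, 2048) -> BN -> ReLU -> Linear(2048, 128)
        use_bn=True,
        num_clusters=[3000],  # one prototype layer with 3000 clusters
    )
    head.eval()  # use BatchNorm running stats for this toy batch
    features = torch.randn(4, 2048)  # N x C flattened trunk features
    embeddings, scores = head(features)  # return_embeddings=True by default
    assert embeddings.shape == (4, 128)  # L2-normalized projected features
    assert scores.shape == (4, 3000)  # prototype scores per sample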