Source code for vissl.losses.multicrop_simclr_info_nce_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import pprint

import numpy as np
import torch
from classy_vision.losses import register_loss
from vissl.losses.simclr_info_nce_loss import SimclrInfoNCECriterion, SimclrInfoNCELoss
from vissl.utils.hydra_config import AttrDict


@register_loss("multicrop_simclr_info_nce_loss")
class MultiCropSimclrInfoNCELoss(SimclrInfoNCELoss):
    """
    Multi-crop variant of the SimCLR InfoNCE loss.

    The plain SimCLR loss works only on 2 positives per image. This loss
    extends it to more positives, following the multi-crop augmentation
    proposed in the SwAV paper (https://arxiv.org/abs/2006.09882).

    Config params:
        temperature (float): the temperature to be applied on the logits
        num_crops (int): number of positives used
        buffer_params:
            world_size (int): total number of trainers in training
            embedding_dim (int): output dimensions of the features projects
            effective_batch_size (int): total batch size used (includes positives)
    """

    def __init__(self, loss_config: AttrDict, device: str = "gpu"):
        # Deliberately starts the MRO lookup *after* SimclrInfoNCELoss, so that
        # class's own __init__ is skipped.
        # NOTE(review): presumably because the parent constructor would build
        # the 2-positive criterion -- confirm against SimclrInfoNCELoss.
        super(SimclrInfoNCELoss, self).__init__()

        # ``device`` is accepted for signature compatibility but not read here.
        config = loss_config
        self.loss_config = config

        # Loss constants, copied from the config for direct access.
        self.temperature = config.temperature
        self.buffer_params = config.buffer_params
        self.num_crops = config.num_crops

        # The criterion object holds the masks and computes the actual loss.
        self.info_criterion = MultiCropSimclrInfoNCECriterion(
            self.buffer_params, self.temperature, self.num_crops
        )

    def __repr__(self):
        """Return a pretty-printed summary of the loss settings."""
        return pprint.pformat(
            {
                "name": self._get_name(),
                "info_average": self.info_criterion,
                "num_crops": self.num_crops,
            },
            indent=2,
        )
class MultiCropSimclrInfoNCECriterion(SimclrInfoNCECriterion):
    """
    The criterion corresponding to the expanded SimCLR loss (as defined in the
    paper https://arxiv.org/abs/2002.05709) using the multi-crop augmentation
    proposed in SwAV paper. The multi-crop augmentation allows using more
    positives per image.

    Args:
        temperature (float): the temperature to be applied on the logits
        num_crops (int): number of positives
        buffer_params:
            world_size (int): total number of trainers in training
            embedding_dim (int): output dimensions of the features projects
            effective_batch_size (int): total batch size used (includes positives)
    """

    def __init__(self, buffer_params, temperature: float, num_crops: int):
        # num_crops is stored before the parent constructor runs.
        # NOTE(review): presumably the parent __init__ calls
        # precompute_pos_neg_mask(), which reads self.num_crops -- confirm
        # against SimclrInfoNCECriterion before reordering.
        self.num_crops = num_crops
        super(MultiCropSimclrInfoNCECriterion, self).__init__(
            buffer_params, temperature
        )
        logging.info(f"Setting multicrop num_crops: {num_crops}")

    def precompute_pos_neg_mask(self):
        """
        We precompute the positive and negative masks to speed up the loss
        calculation.

        Sets:
            self.pos_mask: list of ``num_crops - 1`` float64 tensors, one per
                non-zero crop offset. Each has shape
                ``(local_crops, world_size * local_crops)`` and is non-zero only
                inside this rank's column block.
            self.neg_mask: a single tensor of the same shape: ones in this
                rank's column block, minus the identity and every positive mask.
        """
        # computed once at the beginning of training
        total_images = self.buffer_params.effective_batch_size
        world_size = self.buffer_params.world_size
        local_orig_images = total_images // world_size
        local_crops = local_orig_images * self.num_crops
        rank = self.dist_rank

        zero_block = np.zeros((local_crops, local_crops))

        def _spread_over_ranks(own_block):
            # Put own_block in this rank's column slot, zeros for other ranks.
            return np.hstack(
                [own_block if r == rank else zero_block for r in range(world_size)]
            )

        # One cyclically shifted identity per crop offset. The two np.eye terms
        # together wrap the diagonal around the block; offset 0 is the plain
        # identity (every row paired with itself).
        # NOTE(review): this assumes rows are ordered crop-major (all images of
        # crop 0, then crop 1, ...) -- confirm against the data pipeline.
        shifted_masks = [
            _spread_over_ranks(
                np.eye(local_crops, k=offset * local_orig_images)
                + np.eye(local_crops, k=offset * local_orig_images - local_crops)
            )
            for offset in range(self.num_crops)
        ]

        # The all-ones template does not depend on the crop offset, so it is
        # built once instead of once per offset.
        local_pairs = _spread_over_ranks(np.ones((local_crops, local_crops)))

        # Offset 0 (self-similarity) is not a positive, so it is skipped here
        # but still subtracted from the negatives below.
        pos_mask = [torch.from_numpy(mask) for mask in shifted_masks[1:]]
        neg_mask = torch.from_numpy(local_pairs - sum(shifted_masks))

        if self.use_gpu:
            pos_mask = [mask.cuda(non_blocking=True) for mask in pos_mask]
            neg_mask = neg_mask.cuda(non_blocking=True)
        self.pos_mask, self.neg_mask = pos_mask, neg_mask

    def forward(self, embedding: torch.Tensor):
        """
        Calculate the loss. Operates on embeddings tensor.

        Args:
            embedding (torch.Tensor): 2D tensor whose second dimension equals
                ``buffer_params.embedding_dim`` and whose first dimension is
                divisible by ``num_crops``.

        Returns:
            The summed negative log-probabilities over all positive masks,
            divided by the number of rows that had a positive.

        Raises:
            AssertionError: if the shape checks above fail.
            ZeroDivisionError: if there are no positive masks (``num_crops < 2``),
                because the normaliser stays at 0.0.
        """
        assert embedding.ndim == 2
        assert embedding.shape[1] == int(self.buffer_params.embedding_dim)

        batch_size = embedding.shape[0]
        T, num_crops = self.temperature, self.num_crops
        assert (
            batch_size % num_crops == 0
        ), "Batch size should be divisible by num_crops"

        # Step 1: gather all the embeddings. Shape example: 4096 x 128
        embeddings_buffer = self.gather_embeddings(embedding)

        # Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
        # divide by temperature.
        similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)

        # The negative term is the same for every positive mask, so it is
        # computed once rather than inside the loop.
        neg = torch.sum(similarity * self.neg_mask, 1)

        Z, loss = 0.0, 0.0
        for mask in self.pos_mask:
            pos = torch.sum(similarity * mask, 1)
            # Rows with no positive under this mask (row sum below 1).
            idx = ((1 - torch.sum(mask, 1)) > 0).detach()
            term_prob = pos / (pos + neg)
            # Probability 1 gives log 0, so these rows add nothing to the loss.
            term_prob[idx] = 1.0
            term_loss = torch.log(term_prob)
            # Normalise only by the rows that actually had a positive.
            Z += torch.sum(~idx).detach()
            loss -= torch.sum(term_loss)
        loss /= Z
        return loss

    def __repr__(self):
        """Return a pretty-printed summary of the criterion settings."""
        repr_dict = {
            "name": self._get_name(),
            "temperature": self.temperature,
            "dist_rank": self.dist_rank,
        }
        return pprint.pformat(repr_dict, indent=2)