# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import pprint
import numpy as np
import torch
from classy_vision.generic.distributed_util import get_cuda_device_index, get_rank
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.distributed_gradients import gather_from_all
from vissl.utils.hydra_config import AttrDict
[docs]@register_loss("simclr_info_nce_loss")
class SimclrInfoNCELoss(ClassyLoss):
"""
This is the loss which was proposed in SimCLR https://arxiv.org/abs/2002.05709 paper.
See the paper for the details on the loss.
Config params:
temperature (float): the temperature to be applied on the logits
buffer_params:
world_size (int): total number of trainers in training
embedding_dim (int): output dimensions of the features projects
effective_batch_size (int): total batch size used (includes positives)
"""
def __init__(self, loss_config: AttrDict, device: str = "gpu"):
super(SimclrInfoNCELoss, self).__init__()
self.loss_config = loss_config
# loss constants
self.temperature = self.loss_config.temperature
self.buffer_params = self.loss_config.buffer_params
self.info_criterion = SimclrInfoNCECriterion(
self.buffer_params, self.temperature
)
[docs] @classmethod
def from_config(cls, loss_config: AttrDict):
"""
Instantiates SimclrInfoNCELoss from configuration.
Args:
loss_config: configuration for the loss
Returns:
SimclrInfoNCELoss instance.
"""
return cls(loss_config)
[docs] def forward(self, output, target):
normalized_output = nn.functional.normalize(output, dim=1, p=2)
loss = self.info_criterion(normalized_output)
return loss
def __repr__(self):
repr_dict = {"name": self._get_name(), "info_average": self.info_criterion}
return pprint.pformat(repr_dict, indent=2)
[docs]class SimclrInfoNCECriterion(nn.Module):
"""
The criterion corresponding to the SimCLR loss as defined in the paper
https://arxiv.org/abs/2002.05709.
Args:
temperature (float): the temperature to be applied on the logits
buffer_params:
world_size (int): total number of trainers in training
embedding_dim (int): output dimensions of the features projects
effective_batch_size (int): total batch size used (includes positives)
"""
def __init__(self, buffer_params, temperature: float):
super(SimclrInfoNCECriterion, self).__init__()
self.use_gpu = get_cuda_device_index() > -1
self.temperature = temperature
self.num_pos = 2
self.buffer_params = buffer_params
self.criterion = nn.CrossEntropyLoss()
self.dist_rank = get_rank()
self.pos_mask = None
self.neg_mask = None
self.precompute_pos_neg_mask()
logging.info(f"Creating Info-NCE loss on Rank: {self.dist_rank}")
[docs] def precompute_pos_neg_mask(self):
"""
We precompute the positive and negative masks to speed up the loss calculation
"""
# computed once at the begining of training
total_images = self.buffer_params.effective_batch_size
world_size = self.buffer_params.world_size
batch_size = total_images // world_size
orig_images = batch_size // self.num_pos
rank = self.dist_rank
pos_mask = torch.zeros(batch_size, total_images)
neg_mask = torch.zeros(batch_size, total_images)
all_indices = np.arange(total_images)
pos_members = orig_images * np.arange(self.num_pos)
orig_members = torch.arange(orig_images)
for anchor in np.arange(self.num_pos):
for img_idx in range(orig_images):
delete_inds = batch_size * rank + img_idx + pos_members
neg_inds = torch.tensor(np.delete(all_indices, delete_inds)).long()
neg_mask[anchor * orig_images + img_idx, neg_inds] = 1
for pos in np.delete(np.arange(self.num_pos), anchor):
pos_inds = batch_size * rank + pos * orig_images + orig_members
pos_mask[
torch.arange(
anchor * orig_images, (anchor + 1) * orig_images
).long(),
pos_inds.long(),
] = 1
self.pos_mask = pos_mask.cuda(non_blocking=True) if self.use_gpu else pos_mask
self.neg_mask = neg_mask.cuda(non_blocking=True) if self.use_gpu else neg_mask
[docs] def forward(self, embedding: torch.Tensor):
"""
Calculate the loss. Operates on embeddings tensor.
"""
assert embedding.ndim == 2
assert embedding.shape[1] == int(self.buffer_params.embedding_dim)
batch_size = embedding.shape[0]
T = self.temperature
num_pos = self.num_pos
assert batch_size % num_pos == 0, "Batch size should be divisible by num_pos"
# Step 1: gather all the embeddings. Shape example: 4096 x 128
embeddings_buffer = self.gather_embeddings(embedding)
# Step 2: matrix multiply: 64 x 128 with 4096 x 128 = 64 x 4096 and
# divide by temperature.
similarity = torch.exp(torch.mm(embedding, embeddings_buffer.t()) / T)
pos = torch.sum(similarity * self.pos_mask, 1)
neg = torch.sum(similarity * self.neg_mask, 1)
loss = -(torch.mean(torch.log(pos / (pos + neg))))
return loss
def __repr__(self):
num_negatives = self.buffer_params.effective_batch_size - 2
T = self.temperature
num_pos = self.num_pos
repr_dict = {
"name": self._get_name(),
"temperature": T,
"num_negatives": num_negatives,
"num_pos": num_pos,
"dist_rank": self.dist_rank,
}
return pprint.pformat(repr_dict, indent=2)
[docs] @staticmethod
def gather_embeddings(embedding: torch.Tensor):
"""
Do a gather over all embeddings, so we can compute the loss.
Final shape is like: (batch_size * num_gpus) x embedding_dim
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
embedding_gathered = gather_from_all(embedding)
else:
embedding_gathered = embedding
return embedding_gathered