Source code for vissl.losses.nce_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import math
import pprint
from typing import List, Union

import numpy as np
import torch
from classy_vision.generic.distributed_util import (
    all_reduce_mean,
    gather_from_all,
    get_rank,
)
from classy_vision.generic.util import is_pos_int
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict


@register_loss("nce_loss_with_memory")
class NCELossWithMemory(ClassyLoss):
    """
    Distributed version of the NCE loss. It performs an "all_gather" to gather
    the allocated buffers like memory onto a single gpu. For this, the Pytorch
    distributed backend is used. If using NCCL, one must ensure that all the
    buffers are on GPU. This class supports training using both NCE and
    CrossEntropy (InfoNCE).

    This loss is used by NPID (https://arxiv.org/pdf/1805.01978.pdf),
    NPID++ and PIRL (https://arxiv.org/abs/1912.01991) approaches.

    Written by: Ishan Misra (imisra@fb.com)

    Config params:
        norm_embedding (bool): whether to normalize embeddings
        temperature (float): the temperature to apply to logits
        norm_constant (int): Z parameter in the NCEAverage
        update_mem_with_emb_index (int): In case we have multiple embeddings used
                                         in the nce loss, specify which embedding
                                         to use to update the memory.
        loss_type (str): options are "nce" | "cross_entropy". Using the
                         cross_entropy turns the loss into InfoNCE loss.
        loss_weights (List[float]): if the NCE loss is computed between multiple pairs,
                                    we can set a loss weight per term to weight
                                    different pair contributions differently
        negative_sampling_params:
            num_negatives (int): how many negatives to contrast with
            type (str): how to select the negatives. options "random"
        memory_params:
            memory_size (int): number of training samples as all the samples are
                               stored in memory
            embedding_dim (int): the projection head output dimension
            momentum (int): momentum to use to update the memory
            norm_init (bool): whether to L2 normalize the initialized memory bank
            update_mem_on_forward (bool): whether to update memory on the forward pass
        num_train_samples (int): number of unique samples in the training dataset
    """

    def __init__(self, loss_config: AttrDict):
        super(NCELossWithMemory, self).__init__()

        self.loss_config = loss_config
        memory_params = self.loss_config.memory_params
        memory_params.memory_size = self.loss_config.num_train_samples
        assert is_pos_int(
            memory_params.memory_size
        ), f"Memory size must be positive: {memory_params.memory_size}"

        assert self.loss_config.loss_type in [
            "nce",
            "cross_entropy",
        ], f"Supported types are nce/cross_entropy. Found {self.loss_config.loss_type}"

        self.loss_type = self.loss_config.loss_type
        self.update_memory_on_forward = memory_params.update_mem_on_forward
        self.update_memory_emb_index = self.loss_config.update_mem_with_emb_index
        if self.update_memory_on_forward is False:
            # we have multiple embeddings used in NCE
            # but we update memory with only one of them
            assert self.update_memory_emb_index >= 0

        # first setup the NCEAverage method to get the scores of the output wrt
        # memory bank negatives
        self.nce_average = NCEAverage(
            memory_params=memory_params,
            negative_sampling_params=self.loss_config.negative_sampling_params,
            T=self.loss_config.temperature,
            Z=self.loss_config.norm_constant,
            loss_type=self.loss_type,
        )

        if self.loss_type == "nce":
            # setup the actual NCE loss
            self.nce_criterion = NCECriterion(self.loss_config.num_train_samples)
        elif self.loss_type == "cross_entropy":
            # cross-entropy loss. Also called InfoNCE
            self.xe_criterion = nn.CrossEntropyLoss()

        # other constants
        self.normalize_embedding = self.loss_config.norm_embedding
        self.loss_weights = self.loss_config.loss_weights
        self.init_sync_memory = False
        self.ignore_index = self.loss_config.get("ignore_index", -1)
    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates NCELossWithMemory from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            NCELossWithMemory instance.
        """
        return cls(loss_config)
    def forward(
        self, output: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor
    ):
        """
        For each output and single target, loss is calculated.
        """
        if isinstance(output, torch.Tensor):
            output = [output]
        assert isinstance(
            output, list
        ), "Model output should be a list of tensors. Got Type {}".format(type(output))

        if not self.init_sync_memory:
            self.sync_memory()

        target = target.squeeze()

        # filter out ignore_index ones
        non_ignore = target != self.ignore_index
        target = target[non_ignore]
        output = [x[non_ignore] for x in output]

        loss = 0
        for (l_idx, l_output) in enumerate(output):
            normalized_output = l_output
            if self.normalize_embedding:
                normalized_output = nn.functional.normalize(l_output, dim=1, p=2)
            if self.update_memory_on_forward is False:
                update_memory_on_forward = self.update_memory_emb_index == l_idx
            else:
                update_memory_on_forward = True

            # compare output embeddings to memory bank embeddings and get scores
            # nce_average is batch x (1 + num_negatives)
            nce_average = self.nce_average(
                normalized_output,
                target,
                update_memory_on_forward=update_memory_on_forward,
            )

            if self.loss_type == "nce":
                curr_loss = self.nce_criterion(nce_average, target)
            elif self.loss_type == "cross_entropy":
                labels = torch.zeros(
                    (nce_average.shape[0], 1),
                    device=nce_average.device,
                    dtype=torch.int64,
                )
                curr_loss = self.xe_criterion(nce_average, labels)
            loss += self.loss_weights[l_idx] * curr_loss
        return loss
    def sync_memory(self):
        """
        Sync memory across all processes before the first forward pass. Only
        needed in the distributed case.
        After the first forward pass, the update_memory function in NCEAverage
        does a gather over all embeddings, so memory stays in sync. Doing a
        gather over embeddings is O(batch size). Syncing memory is O(num items
        in memory). Generally, batch size << num items in memory. So, we prefer
        doing the syncs in update_memory.
        """
        self.nce_average.memory = all_reduce_mean(self.nce_average.memory)
        logging.info(f"Rank: {get_rank()}: Memory synced")
        # set to true once we are done. forward pass in nce_average will sync after.
        self.init_sync_memory = True
    def update_memory(self, embedding, y):
        assert (
            self.nce_average.update_memory_on_forward is False
        ), "Memory was already updated on forward"
        with torch.no_grad():
            if self.normalize_embedding:
                embedding = nn.functional.normalize(embedding, dim=1, p=2)
            self.nce_average.update_memory(embedding, y)
    def __repr__(self):
        repr_dict = {
            "name": self._get_name(),
            "nce_average": self.nce_average,
            "loss_type": self.loss_type,
            "update_emb_index": self.update_memory_emb_index,
        }
        return pprint.pformat(repr_dict, indent=2)
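# Illustrative sketch (not part of the original module): the shape of a config
# that NCELossWithMemory expects, written as a plain nested dict. This assumes
# AttrDict can wrap a nested dict (as in VISSL's hydra config utilities); all
# values below are hypothetical placeholders, not recommended settings.
def _example_nce_loss_config():
    loss_config = AttrDict(
        {
            "norm_embedding": True,
            "temperature": 0.07,
            "norm_constant": -1,  # Z < 0: estimated on the first NCE forward pass
            "update_mem_with_emb_index": -1,
            "loss_type": "nce",  # or "cross_entropy" for InfoNCE
            "loss_weights": [1.0],
            "negative_sampling_params": {"num_negatives": 4096, "type": "random"},
            "memory_params": {
                "memory_size": -1,  # overwritten with num_train_samples in __init__
                "embedding_dim": 128,
                "momentum": 0.5,
                "norm_init": True,
                "update_mem_on_forward": True,
            },
            "num_train_samples": 50000,  # e.g. size of the training set
        }
    )
    # a loss instance would then be built via NCELossWithMemory.from_config(loss_config)
    return loss_config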
# Original implementation: https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/NCEAverage.py # NOQA
# Adapted by: Ishan Misra (imisra@fb.com)
class NCEAverage(nn.Module):
    """
    Computes the scores of the model embeddings against the `positive' and
    `negative' samples from the Memory Bank.
    This class does *NOT* compute the actual loss, just the scores, i.e.,
    inner products followed by normalizations/exponentiation etc.
    """

    def __init__(
        self, memory_params, negative_sampling_params, T=0.07, Z=-1, loss_type="nce"
    ):
        super(NCEAverage, self).__init__()
        self.nLem = memory_params.memory_size
        self.loss_type = loss_type
        self.setup_negative_sampling(negative_sampling_params)
        self.init_memory(memory_params)
        self.register_buffer(
            "params",
            torch.tensor(
                [
                    self.negative_sampling_params.num_negatives,
                    T,
                    Z,
                    self.memory_params.momentum,
                ]
            ),
        )
    def forward(self, embedding, y, idx=None, update_memory_on_forward=None):
        assert embedding.ndim == 2
        assert embedding.shape[1] == self.memory.shape[1]

        K = int(self.params[0].item())
        T = self.params[1].item()
        Z = self.params[2].item()

        batchSize = embedding.shape[0]
        embedding_dim = self.memory.size(1)

        # score computation
        if idx is None:
            idx = self.do_negative_sampling(embedding, y, num_negatives=K)

        # sample
        weight = torch.index_select(self.memory, 0, idx.view(-1)).detach()
        weight = weight.view(batchSize, K + 1, embedding_dim)
        out = torch.bmm(weight, embedding.view(batchSize, embedding_dim, 1))
        out = torch.div(out, T)

        if self.loss_type == "nce":
            out = torch.exp(out)
            # compute partition function
            if Z < 0:
                self.compute_partition_function(out)
                Z = self.params[2].clone().detach().item()
            out = torch.div(out, Z).contiguous().reshape(batchSize, K + 1)

        if update_memory_on_forward or (
            update_memory_on_forward is None and self.update_memory_on_forward
        ):
            self.update_memory(embedding, y)
        return out
    def compute_partition_function(self, out):
        num_items = self.memory.size(0)
        with torch.no_grad():
            batch_mean = out.mean()
            # NOTE: this relies on the "mean" computation being stable and
            # deterministic across all nodes. Could be replaced with smarter ways.
            if torch.distributed.is_available() and torch.distributed.is_initialized():
                batch_mean_gathered = gather_from_all(batch_mean)
                all_batch_mean = batch_mean_gathered.mean().squeeze().item()
            else:
                all_batch_mean = batch_mean.item()
        self.params[2] = all_batch_mean * num_items
        Z = self.params[2].clone().detach().item()
        rank = get_rank()
        logging.info(f"Rank: {rank}; Normalization constant Z is set to {Z}")
    def do_negative_sampling(self, embedding, y, num_negatives):
        with torch.no_grad():
            if self.negative_sampling_params.type in ["random", "debug"]:
                batchSize = embedding.shape[0]
                idx = self.multinomial.draw(batchSize * (num_negatives + 1)).view(
                    batchSize, -1
                )
                idx.select(1, 0).copy_(y.data)
        return idx
    def setup_negative_sampling(self, negative_sampling_params):
        self.negative_sampling_params = negative_sampling_params
        assert self.negative_sampling_params["type"] in ["random", "debug"]
        if self.negative_sampling_params.type == "debug":
            logging.info("Using debug mode for negative sampling.")
            logging.info("Will use slower NumpySampler.")
            self.multinomial = NumpySampler(self.nLem)
        else:
            unigrams = torch.ones(self.nLem)
            self.multinomial = AliasMethod(unigrams)
        self.num_negatives = negative_sampling_params["num_negatives"]
    def init_memory(self, memory_params):
        self.memory_params = memory_params
        num_items = memory_params.memory_size
        embedding_dim = memory_params.embedding_dim
        self.update_memory_on_forward = memory_params.update_mem_on_forward
        stdv = 1.0 / math.sqrt(embedding_dim / 3)
        self.register_buffer(
            "memory", torch.rand(num_items, embedding_dim).mul_(2 * stdv).add_(-stdv)
        )
        if memory_params.norm_init:
            self.memory = nn.functional.normalize(self.memory, p=2, dim=1)
        sample_norm = self.memory[:10].norm(dim=1).mean()
        mem_info = (
            f"Init memory: {self.memory.shape}; stdv: {stdv}; "
            f"normalize: {memory_params.norm_init}; norm: {sample_norm}"
        )
        logging.info(f"Rank: {get_rank()} - {mem_info}")
    def update_memory(self, embedding, y):
        momentum = self.params[3].item()
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # gather all embeddings.
            embedding_gathered = gather_from_all(embedding)
            y_gathered = gather_from_all(y)
        else:
            embedding_gathered = embedding
            y_gathered = y

        # update memory
        with torch.no_grad():
            # Assumption: memory_size >= y.max()
            assert y_gathered.max() < self.memory.shape[0], (
                f"Memory bank {self.memory.shape} is not sufficient "
                f"to hold index: {y_gathered.max()}"
            )
            l_pos = torch.index_select(self.memory, 0, y_gathered.view(-1))
            l_pos.mul_(momentum)
            l_pos.add_(torch.mul(embedding_gathered, 1 - momentum))
            updated_l = nn.functional.normalize(l_pos, p=2, dim=1)
            self.memory.index_copy_(0, y_gathered, updated_l)
    def __repr__(self):
        num_negatives = int(self.params[0].item())
        T = self.params[1].item()
        Z = self.params[2].item()
        momentum = self.params[3].item()
        repr_dict = {
            "name": self._get_name(),
            "T": T,
            "Z": Z,
            "num_negatives": num_negatives,
            "momentum": momentum,
            "memory_buffer_size": self.memory.shape,
            "negative_sampling": self.negative_sampling_params,
            "memory": self.memory_params,
            "update_memory_on_forward": self.update_memory_on_forward,
        }
        return pprint.pformat(repr_dict, indent=2)
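# Illustrative sketch (not part of the original module): how NCEAverage turns a
# batch of embeddings and the sampled memory rows into per-sample scores via a
# batched matrix multiply, i.e. (B, K+1, D) x (B, D, 1) -> (B, K+1, 1). All
# tensors below are toy values.
def _example_score_shapes():
    batch_size, num_negatives, embedding_dim = 4, 8, 16
    embedding = nn.functional.normalize(
        torch.randn(batch_size, embedding_dim), dim=1, p=2
    )
    # one positive + num_negatives rows pulled from a (fake) memory bank per sample
    weight = torch.randn(batch_size, num_negatives + 1, embedding_dim)
    out = torch.bmm(weight, embedding.view(batch_size, embedding_dim, 1))
    assert out.shape == (batch_size, num_negatives + 1, 1)
    return out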
# Credits: https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/alias_multinomial.py # NOQA
class AliasMethod(nn.Module):
    """
    A fast way to sample from a multinomial distribution.
    Faster than torch.multinomial or np.multinomial.

    The setup (__init__) for this class is slow, however `draw' (actual
    sampling) is fast.
    """

    def __init__(self, probs):
        super(AliasMethod, self).__init__()
        if probs.sum() > 1:
            probs.div_(probs.sum())
        K = len(probs)
        self.register_buffer("prob", torch.zeros(K))
        self.register_buffer("alias", torch.LongTensor([0] * K))

        # Sort the data into the outcomes with probabilities that are larger and
        # smaller than 1/K.
        smaller = []
        larger = []
        for kk, prob in enumerate(probs):
            self.prob[kk] = K * prob
            if self.prob[kk] < 1.0:
                smaller.append(kk)
            else:
                larger.append(kk)

        # Loop through and create little binary mixtures that appropriately allocate
        # the larger outcomes over the overall uniform mixture.
        while len(smaller) > 0 and len(larger) > 0:
            small = smaller.pop()
            large = larger.pop()
            self.alias[small] = large
            self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
            if self.prob[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)

        for last_one in smaller + larger:
            self.prob[last_one] = 1
    def draw(self, N):
        """
        Draw N samples from multinomial

        :param N: number of samples
        :return: samples
        """
        K = self.alias.size(0)
        kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
        prob = self.prob.index_select(0, kk)
        alias = self.alias.index_select(0, kk)
        # b is whether a random number is greater than q
        b = torch.bernoulli(prob)
        oq = kk.mul(b.long())
        oj = alias.mul((1 - b).long())
        return oq + oj
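# Illustrative sketch (not part of the original module): drawing from AliasMethod
# and checking that the empirical frequencies roughly track the input
# probabilities. Values are toy placeholders.
def _example_alias_sampling():
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    sampler = AliasMethod(probs)
    samples = sampler.draw(100000)
    # empirical distribution over the 4 outcomes; should be close to `probs`
    counts = torch.bincount(samples, minlength=len(probs)).float()
    return counts / counts.sum()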
# Numpy based sampler. Useful for debugging.
# This Sampler is faster to setup than AliasMethod.
# However, it is **much** slower to sample from it.
class NumpySampler(object):
    def __init__(self, high):
        self.high = high
    def draw(self, num_negatives):
        rand_nums = np.random.choice(self.high, size=num_negatives)
        rand_nums = torch.from_numpy(rand_nums).long().cuda()
        return rand_nums
# Core NCE Criterion that computes cross entropy with a fixed prior on negatives,
# from lemniscate.
# Credits: https://github.com/zhirongw/lemniscate.pytorch/blob/master/lib/NCECriterion.py # NOQA
class NCECriterion(nn.Module):
    def __init__(self, nLem):
        super(NCECriterion, self).__init__()
        self.nLem = nLem
        self.eps = float(np.finfo("float32").eps)
    def forward(self, x, targets):
        batchSize = x.size(0)
        K = x.size(1) - 1
        Pnt = 1 / float(self.nLem)
        Pns = 1 / float(self.nLem)

        # eq 5.1 : P(origin=model) = Pmt / (Pmt + k*Pnt)
        Pmt = x.select(1, 0)
        Pmt_div = Pmt.add(K * Pnt + self.eps)
        lnPmt = torch.div(Pmt, Pmt_div)

        # eq 5.2 : P(origin=noise) = k*Pns / (Pms + k*Pns)
        Pon_div = x.narrow(1, 1, K).add(K * Pns + self.eps)
        Pon = Pon_div.clone().fill_(K * Pns)
        lnPon = torch.div(Pon, Pon_div)

        # equation 6 in ref. A
        lnPmt.log_()
        lnPon.log_()

        lnPmtsum = lnPmt.sum(0)
        lnPonsum = lnPon.view(-1, 1).sum(0)

        loss = -(lnPmtsum + lnPonsum) / batchSize
        return loss
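# Illustrative sketch (not part of the original module): NCECriterion on toy
# scores. With nLem items the noise prior is Pn = 1/nLem; the first column of
# `x` holds the (exponentiated, normalized) positive scores and the remaining K
# columns hold negatives. The targets argument is not used by the criterion
# itself. All numbers below are hypothetical.
def _example_nce_criterion():
    nLem, batch_size, num_negatives = 100, 2, 4
    criterion = NCECriterion(nLem)
    # pretend every score already equals the uniform prior 1/nLem
    x = torch.full((batch_size, num_negatives + 1), 1.0 / nLem)
    targets = torch.zeros(batch_size, dtype=torch.int64)
    return criterion(x, targets)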