Source code for vissl.losses.moco_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import logging
import pprint
from collections import namedtuple

import torch
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.misc import concat_all_gather


_MoCoLossConfig = namedtuple(
    "_MoCoLossConfig", ["embedding_dim", "queue_size", "momentum", "temperature"]
)


class MoCoLossConfig(_MoCoLossConfig):
    """ Settings for the MoCo loss"""

    @staticmethod
    def defaults() -> "MoCoLossConfig":
        return MoCoLossConfig(
            embedding_dim=128, queue_size=65536, momentum=0.999, temperature=0.2
        )
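
Because MoCoLossConfig is a plain namedtuple, a configuration can be built from the defaults and selectively overridden with _replace. The following minimal sketch is illustrative only (the override value is not a recommendation and is not part of the vissl module):

# Hypothetical usage sketch: start from the defaults and override one field.
# config = MoCoLossConfig.defaults()               # embedding_dim=128, queue_size=65536, ...
# small_cfg = config._replace(queue_size=4096)     # namedtuple: returns a modified copy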

@register_loss("moco_loss")
class MoCoLoss(ClassyLoss):
    """
    This is the loss which was proposed in the "Momentum Contrast
    for Unsupervised Visual Representation Learning" paper, from Kaiming He et al.
    See http://arxiv.org/abs/1911.05722 for details
    and https://github.com/facebookresearch/moco for a reference implementation,
    reused here.

    Config params:
        embedding_dim (int): head output dimension
        queue_size (int): number of elements in queue
        momentum (float): encoder momentum value for the update
        temperature (float): temperature to use on the logits
    """

    def __init__(self, config: MoCoLossConfig):
        super().__init__()
        self.loss_config = config

        # Create the queue
        self.register_buffer(
            "queue",
            torch.randn(self.loss_config.embedding_dim, self.loss_config.queue_size),
        )
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
        self.criterion = nn.CrossEntropyLoss()

        self.initialized = False
        self.key = None
        self.sample = None
        self.moco_encoder = None
        self.checkpoint = None

    @classmethod
    def from_config(cls, config: MoCoLossConfig):
        """
        Instantiates MoCoLoss from configuration.

        Args:
            config: configuration for the loss

        Returns:
            MoCoLoss instance.
        """
        return cls(config)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, key: torch.Tensor):
        """
        Discard the oldest keys from the MoCo queue and save the newest ones,
        through a round-robin mechanism
        """
        # gather keys before updating the queue
        # /!\ the queue is duplicated on all GPUs
        keys = concat_all_gather(key)
        batch_size = keys.shape[0]

        # for simplicity, we remove the case where the batch overlaps with the end
        # of the queue
        assert (
            self.loss_config.queue_size % batch_size == 0
        ), "The queue size needs to be a multiple of the batch size"

        # replace the keys at ptr (dequeue and enqueue)
        ptr = int(self.queue_ptr)
        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.loss_config.queue_size  # move pointer, round robin

        self.queue_ptr[0] = ptr

    def forward(self, query: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Given the encoder queries, the key and the queue of the previous queries,
        compute the cross entropy loss for this batch

        Args:
            query: output of the encoder given the current batch

        Returns:
            loss
        """
        if not self.initialized:
            self.queue = self.queue.to(query.device)
            self.initialized = True

        # --
        # Normalize the encoder raw outputs
        query = nn.functional.normalize(query, dim=1)

        # --
        # Compute all the logits and the expected labels.
        # Einstein sum is used in MoCo, deemed more intuitive.
        # Another option is `torch.diag(torch.matmul(query, self.key.T))`

        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [query, self.key]).unsqueeze(-1)

        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [query, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply the temperature
        logits /= self.loss_config.temperature

        # labels: the positive is always the first logit.
        # This is essentially a classification problem among all the samples
        labels = torch.zeros(logits.shape[0], dtype=torch.long).to(query.device)

        # ---
        # Update the queue for the next batch
        self._dequeue_and_enqueue(self.key)

        # ---
        # Then just apply the cross entropy loss
        return self.criterion(logits, labels)

    def __repr__(self):
        repr_dict = {"name": self._get_name()}
        return pprint.pformat(repr_dict, indent=2)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """
        Restore the loss state given a checkpoint

        Args:
            state_dict (serialized via torch.save)
        """
        # If the encoder has not yet been allocated, store the checkpoint for later;
        # otherwise use the normal pytorch restoration
        if self.moco_encoder is None:
            self.checkpoint = state_dict
            logging.info("Storing the checkpoint for later use")
        else:
            logging.info("Restoring checkpoint")
            super().load_state_dict(state_dict, *args, **kwargs)
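
For reference, the core of forward() is the standard InfoNCE construction: one positive logit per query (its momentum-encoder key) and K negative logits against the queue. The standalone sketch below reproduces that computation with random tensors; the shapes N, C, K and the temperature 0.2 are illustrative assumptions, and it skips the queue update and the distributed gather used by the module:

# Standalone sketch of the logits construction used in MoCoLoss.forward
# (illustrative shapes and values, not part of the vissl module).
import torch
import torch.nn.functional as F

N, C, K = 4, 128, 16                              # batch size, embedding dim, queue size
query = F.normalize(torch.randn(N, C), dim=1)     # encoder output for the current batch
key = F.normalize(torch.randn(N, C), dim=1)       # momentum-encoder output (positives)
queue = F.normalize(torch.randn(C, K), dim=0)     # past keys stored column-wise (negatives)

l_pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)   # Nx1
l_neg = torch.einsum("nc,ck->nk", [query, queue])              # NxK
logits = torch.cat([l_pos, l_neg], dim=1) / 0.2                # apply the default temperature
labels = torch.zeros(N, dtype=torch.long)                      # positive is always index 0
loss = F.cross_entropy(logits, labels)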