# Source code for vissl.losses.swav_momentum_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import math
import pprint

import numpy as np
import torch
from classy_vision.generic.distributed_util import (
    all_reduce_sum,
    get_cuda_device_index,
    get_world_size,
    is_distributed_training_run,
)
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict

@register_loss("swav_momentum_loss")
class SwAVMomentumLoss(ClassyLoss):
    """
    This loss extends the SwAV loss proposed in the paper by Caron et al.
    The loss combines the benefits of using the SwAV approach with the
    momentum encoder as used in MoCo.

    Config params:
        momentum (float): for the momentum encoder
        momentum_eval_mode_iter_start (int): from what iteration should the
            momentum encoder network be in eval mode
        embedding_dim (int): the projection head output dimension
        temperature (float): temperature to be applied to the logits
        use_double_precision (bool): whether to use double precision for the
            loss. This could be a good idea to avoid NaNs.
        normalize_last_layer (bool): whether to normalize the last layer
        num_iters (int): number of sinkhorn algorithm iterations to make
        epsilon (float): see the paper for details
        num_crops (int): number of crops used
        crops_for_assign (List[int]): what crops to use for assignment
        num_prototypes (List[int]): number of prototypes
        queue:
            queue_length (int): number of features to store and used in the
                scores
            start_iter (int): when to start using the queue for the scores
            local_queue_length (int): length of queue per gpu
    """

    def __init__(self, loss_config: AttrDict):
        super().__init__()
        self.loss_config = loss_config

        # These attributes start out as None. `forward` reads
        # `momentum_encoder`, `momentum_scores` and `momentum_embeddings`, so
        # they are expected to be populated by code outside this class before
        # the loss is evaluated.
        self.momentum_encoder = None
        self.checkpoint = None
        self.momentum_scores = None
        self.momentum_embeddings = None
        self.is_distributed = is_distributed_training_run()
        self.use_gpu = get_cuda_device_index() > -1
        self.softmax = nn.Softmax(dim=1)

        # keep track of number of iterations
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))

        # for queue
        self.use_queue = False
        if self.loss_config.queue.local_queue_length > 0:
            self.initialize_queue()

    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates SwAVMomentumLoss from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            SwAVMomentumLoss instance.
        """
        return cls(loss_config)

    def initialize_queue(self):
        """
        Register the randomly initialised queue buffers.

        One ``local_queue<i>`` buffer is created per entry of
        ``num_prototypes`` (values uniform in [-1, 1)), plus a single
        ``local_emb_queue`` buffer of embeddings (values uniform in
        [-stdv, stdv) with ``stdv = 1 / sqrt(embedding_dim / 3)``).
        """
        for i, nmb_proto in enumerate(self.loss_config.num_prototypes):
            init_queue = (
                torch.rand(
                    len(self.loss_config.crops_for_assign),
                    self.loss_config.queue.local_queue_length,
                    nmb_proto,
                )
                * 2
                - 1
            )
            self.register_buffer("local_queue" + str(i), init_queue)

        stdv = 1.0 / math.sqrt(self.loss_config.embedding_dim / 3)
        init_queue = (
            torch.rand(
                len(self.loss_config.crops_for_assign),
                self.loss_config.queue.local_queue_length,
                self.loss_config.embedding_dim,
            )
            .mul_(2 * stdv)
            .add_(-stdv)
        )
        self.register_buffer("local_emb_queue", init_queue)

    def load_state_dict(self, state_dict, *args, **kwargs):
        """
        Restore the loss state given a checkpoint.

        Args:
            state_dict: the checkpointed state of the loss to restore
        """
        # If the encoder has been allocated, use the normal pytorch
        # restoration; otherwise keep the state dict around so it can be
        # applied later.
        if self.momentum_encoder is None:
            self.checkpoint = state_dict
            logging.info("Storing the checkpoint for later use")
        else:
            logging.info("Restoring checkpoint")
            super().load_state_dict(state_dict, *args, **kwargs)

    def forward(self, output: torch.Tensor, *args, **kwargs):
        """
        Compute the loss from the prototype scores of every head.

        Args:
            output: indexable collection whose elements from index 1 onwards
                are the per-head prototype scores; element 0 is not used by
                this method. The first dimension of each score tensor is
                split evenly into ``num_crops`` chunks.

        Returns:
            The loss averaged over heads and over the crops used for
            assignment.
        """
        self.use_queue = (
            self.loss_config.queue.local_queue_length > 0
            and self.num_iteration >= self.loss_config.queue.start_iter
        )
        if self.use_queue:
            # Refresh the score queues from the stored embeddings using the
            # momentum encoder's first head.
            if self.is_distributed:
                self.compute_queue_scores(self.momentum_encoder.module.heads[0])
            else:
                self.compute_queue_scores(self.momentum_encoder.heads[0])

        loss = 0
        for head_id, proto_scores in enumerate(output[1:]):
            bs = proto_scores.shape[0] // self.loss_config.num_crops

            sub_loss = 0
            for j, crop_id in enumerate(self.loss_config.crops_for_assign):
                # The assignments are targets: no gradient flows through them.
                with torch.no_grad():
                    scores_this_crop = self.momentum_scores[head_id][
                        j * bs : (j + 1) * bs
                    ]
                    if self.use_queue:
                        queue = getattr(self, "local_queue" + str(head_id))[j].clone()
                        scores_this_crop = torch.cat((scores_this_crop, queue))
                    assignments = torch.exp(
                        scores_this_crop / self.loss_config.epsilon
                    ).t()
                    # Keep only the rows of the current batch, dropping the
                    # rows that came from the queue.
                    assignments = self.distributed_sinkhornknopp(assignments)[:bs]

                # Predict the assignment of this crop from every other crop.
                idx_crop_pred = np.delete(
                    np.arange(self.loss_config.num_crops), crop_id
                )
                subsubloss = 0
                for p in idx_crop_pred:
                    subsubloss -= torch.mean(
                        torch.sum(
                            assignments
                            * torch.log(
                                self.softmax(
                                    proto_scores[bs * p : bs * (p + 1)]
                                    / self.loss_config.temperature
                                )
                            ),
                            dim=1,
                        )
                    )
                sub_loss += subsubloss / len(idx_crop_pred)
            loss += sub_loss / len(self.loss_config.crops_for_assign)
        loss /= len(output) - 1

        self.num_iteration += 1
        if self.use_queue:
            self.update_emb_queue()

        return loss

    def __repr__(self):
        repr_dict = {"name": self._get_name()}
        return pprint.pformat(repr_dict, indent=2)

    def distributed_sinkhornknopp(self, Q: torch.Tensor):
        """
        Apply the distributed sinkhorn optimization on the scores matrix to
        find the assignments.

        Args:
            Q: 2D scores matrix; it is modified in place.

        Returns:
            The transposed matrix, with each column of ``Q`` normalised to
            sum to 1 before transposing, cast to float.
        """
        with torch.no_grad():
            sum_Q = torch.sum(Q, dtype=Q.dtype)
            all_reduce_sum(sum_Q)
            Q /= sum_Q

            k = Q.shape[0]
            n = Q.shape[1]
            N = get_world_size() * Q.shape[1]

            # u, r, c and Q follow the notation of the reference the original
            # comment pointed to (the link is missing from this copy of the
            # file).
            r = torch.ones(k) / k
            c = torch.ones(n) / N
            if self.use_gpu:
                r = r.cuda(non_blocking=True)
                c = c.cuda(non_blocking=True)

            # Row sums are all-reduced; column sums are computed locally.
            curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
            all_reduce_sum(curr_sum)

            for _ in range(self.loss_config.num_iters):
                u = curr_sum
                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
                curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(curr_sum)
            return (Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)).t().float()

    def update_emb_queue(self):
        """
        Push the latest momentum embeddings to the front of the embedding
        queue, shifting the older entries back and dropping the oldest ones.
        """
        with torch.no_grad():
            bs = len(self.momentum_embeddings) // self.loss_config.num_crops
            for i in range(len(self.loss_config.crops_for_assign)):
                queue = self.local_emb_queue[i]
                # clone() so the shift does not read from memory it is
                # overwriting
                queue[bs:] = queue[:-bs].clone()
                queue[:bs] = self.momentum_embeddings[i * bs : (i + 1) * bs]
                self.local_emb_queue[i] = queue

    def compute_queue_scores(self, head):
        """
        Recompute every score queue by applying the prototype layers of
        ``head`` to the stored embedding queue.

        Args:
            head: object exposing ``nmb_heads`` and one callable
                ``prototypes<h>`` attribute per head index.
        """
        with torch.no_grad():
            for i in range(len(self.loss_config.crops_for_assign)):
                for h in range(head.nmb_heads):
                    scores = getattr(head, "prototypes" + str(h))(
                        self.local_emb_queue[i]
                    )
                    getattr(self, "local_queue" + str(h))[i] = scores