Source code for vissl.losses.swav_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import math
import os
import pprint
from typing import List

import numpy as np
import torch
from classy_vision.generic.distributed_util import (
    all_reduce_max,
    all_reduce_sum,
    get_cuda_device_index,
    get_rank,
    get_world_size,
)
from classy_vision.losses import ClassyLoss, register_loss
from fvcore.common.file_io import PathManager
from torch import nn
from vissl.utils.hydra_config import AttrDict


[docs]@register_loss("swav_loss") class SwAVLoss(ClassyLoss): """ This loss is proposed by the SwAV paper https://arxiv.org/abs/2006.09882 by Caron et al. See the paper for more details about the loss. Config params: embedding_dim (int): the projection head output dimension temperature (float): temperature to be applied to the logits use_double_precision (bool): whether to use double precision for the loss. This could be a good idea to avoid NaNs. normalize_last_layer (bool): whether to normalize the last layer num_iters (int): number of sinkhorn algorithm iterations to make epsilon (float): see the paper for details num_crops (int): number of crops used crops_for_assign (List[int]): what crops to use for assignment num_prototypes (List[int]): number of prototypes temp_hard_assignment_iters (int): whether to do hard assignment for the initial few iterations output_dir (str): for dumping the debugging info in case loss becomes NaN queue: queue_length (int): number of features to store and used in the scores start_iter (int): when to start using the queue for the scores local_queue_length (int): length of queue per gpu """ def __init__(self, loss_config: AttrDict): super().__init__() self.loss_config = loss_config self.queue_start_iter = self.loss_config.queue.start_iter self.swav_criterion = SwAVCriterion( self.loss_config.temperature, self.loss_config.crops_for_assign, self.loss_config.num_crops, self.loss_config.num_iters, self.loss_config.epsilon, self.loss_config.use_double_precision, self.loss_config.num_prototypes, self.loss_config.queue.local_queue_length, self.loss_config.embedding_dim, self.loss_config.temp_hard_assignment_iters, self.loss_config.output_dir, )
    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates SwAVLoss from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            SwAVLoss instance.
        """
        return cls(loss_config)
    def forward(self, output: torch.Tensor, target: torch.Tensor):
        self.swav_criterion.use_queue = (
            self.swav_criterion.local_queue_length > 0
            and self.swav_criterion.num_iteration >= self.queue_start_iter
        )
        loss = 0
        for i, prototypes_scores in enumerate(output[1:]):
            loss += self.swav_criterion(prototypes_scores, i)
        loss /= len(output) - 1
        self.swav_criterion.num_iteration += 1
        if self.swav_criterion.use_queue:
            self.swav_criterion.update_emb_queue(output[0].detach())
        return loss
    def __repr__(self):
        repr_dict = {
            "name": self._get_name(),
            "epsilon": self.loss_config.epsilon,
            "use_double_precision": self.loss_config.use_double_precision,
            "local_queue_length": self.loss_config.queue.local_queue_length,
            "temperature": self.loss_config.temperature,
            "num_prototypes": self.loss_config.num_prototypes,
            "num_crops": self.loss_config.num_crops,
            "nmb_sinkhornknopp_iters": self.loss_config.num_iters,
            "embedding_dim": self.loss_config.embedding_dim,
            "temp_hard_assignment_iters": self.loss_config.temp_hard_assignment_iters,
        }
        return pprint.pformat(repr_dict, indent=2)
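For illustration, a minimal sketch of how SwAVLoss could be instantiated and called directly. The config keys mirror the ones read in __init__ above, but all values are placeholders, the AttrDict-from-nested-dict construction is assumed, and the single-process CPU call is illustrative only; in real training VISSL builds this config from its YAML defaults and runs the loss under its distributed trainer. The call sketch also assumes that `output` is a list whose first element is the embedding and whose remaining elements are per-head prototype scores, which is how forward indexes it.

    # Hypothetical config; keys follow the docstring above, values are illustrative only.
    example_config = AttrDict(
        {
            "embedding_dim": 128,
            "temperature": 0.1,
            "use_double_precision": False,
            "normalize_last_layer": True,
            "num_iters": 3,
            "epsilon": 0.05,
            "num_crops": 2,
            "crops_for_assign": [0, 1],
            "num_prototypes": [3000],
            "temp_hard_assignment_iters": 0,
            "output_dir": ".",
            "queue": {"queue_length": 0, "start_iter": 0, "local_queue_length": 0},
        }
    )
    swav_loss = SwAVLoss.from_config(example_config)

    # Toy forward pass: 32 samples x 2 crops, one prototype head (shapes are illustrative).
    embeddings = torch.randn(32 * 2, 128)
    prototype_scores = torch.randn(32 * 2, 3000)
    loss = swav_loss([embeddings, prototype_scores], target=None)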
class SwAVCriterion(nn.Module):
    """
    This criterion is used by the SwAV paper https://arxiv.org/abs/2006.09882
    by Caron et al. See the paper for more details about the loss.

    Config params:
        embedding_dim (int): the projection head output dimension
        temperature (float): temperature to be applied to the logits
        num_iters (int): number of sinkhorn algorithm iterations to make
        epsilon (float): see the paper for details
        num_crops (int): number of crops used
        crops_for_assign (List[int]): what crops to use for assignment
        num_prototypes (List[int]): number of prototypes
        temp_hard_assignment_iters (int): number of initial iterations for which
            to do hard assignment
        output_dir (str): for dumping the debugging info in case loss becomes NaN
        local_queue_length (int): length of queue per gpu
    """

    def __init__(
        self,
        temperature: float,
        crops_for_assign: List[int],
        num_crops: int,
        num_iters: int,
        epsilon: float,
        use_double_prec: bool,
        num_prototypes: List[int],
        local_queue_length: int,
        embedding_dim: int,
        temp_hard_assignment_iters: int,
        output_dir: str,
    ):
        super(SwAVCriterion, self).__init__()
        self.use_gpu = get_cuda_device_index() > -1
        self.temperature = temperature
        self.crops_for_assign = crops_for_assign
        self.num_crops = num_crops
        self.nmb_sinkhornknopp_iters = num_iters
        self.epsilon = epsilon
        self.use_double_prec = use_double_prec
        self.num_prototypes = num_prototypes
        self.nmb_heads = len(self.num_prototypes)
        self.embedding_dim = embedding_dim
        self.temp_hard_assignment_iters = temp_hard_assignment_iters
        self.local_queue_length = local_queue_length
        self.dist_rank = get_rank()
        self.world_size = get_world_size()
        self.log_softmax = nn.LogSoftmax(dim=1).cuda()
        self.softmax = nn.Softmax(dim=1).cuda()
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))
        self.use_queue = False
        if local_queue_length > 0:
            self.initialize_queue()
        self.output_dir = output_dir
    def distributed_sinkhornknopp(self, Q: torch.Tensor):
        """
        Apply the distributed sinkhorn optimization on the scores matrix to
        find the assignments
        """
        eps_num_stab = 1e-12
        with torch.no_grad():
            # remove potential infs in Q
            # replace the inf entries with the max of the finite entries in Q
            mask = torch.isinf(Q)
            ind = torch.nonzero(mask)
            if len(ind) > 0:
                for i in ind:
                    Q[i[0], i[1]] = 0
                m = torch.max(Q)
                for i in ind:
                    Q[i[0], i[1]] = m
            sum_Q = torch.sum(Q, dtype=Q.dtype)
            all_reduce_sum(sum_Q)
            Q /= sum_Q

            k = Q.shape[0]
            n = Q.shape[1]
            N = self.world_size * Q.shape[1]

            # we follow the u, r, c and Q notations from
            # https://arxiv.org/abs/1911.05371
            r = torch.ones(k) / k
            c = torch.ones(n) / N
            if self.use_double_prec:
                r, c = r.double(), c.double()

            if self.use_gpu:
                r = r.cuda(non_blocking=True)
                c = c.cuda(non_blocking=True)

            for _ in range(self.nmb_sinkhornknopp_iters):
                u = torch.sum(Q, dim=1, dtype=Q.dtype)
                all_reduce_sum(u)

                # for numerical stability, add a small epsilon value
                # for non-zero Q values.
                if len(torch.nonzero(u == 0)) > 0:
                    Q += eps_num_stab
                    u = torch.sum(Q, dim=1, dtype=Q.dtype)
                    all_reduce_sum(u)
                u = r / u
                # remove potential infs in "u"
                # replace the inf entries with the max of the finite entries in "u"
                mask = torch.isinf(u)
                ind = torch.nonzero(mask)
                if len(ind) > 0:
                    for i in ind:
                        u[i[0]] = 0
                    m = torch.max(u)
                    for i in ind:
                        u[i[0]] = m

                Q *= u.unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
            Q = (Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)).t().float()

            # hard assignment
            if self.num_iteration < self.temp_hard_assignment_iters:
                index_max = torch.max(Q, dim=1)[1]
                Q.zero_()
                Q.scatter_(1, index_max.unsqueeze(1), 1)
        return Q
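    # NOTE: for readers, a minimal single-process sketch of the optimization above,
    # with the all-reduce calls and the inf/zero guards removed (illustrative only).
    # Q is (num_prototypes, batch) and already non-negative, as prepared in forward():
    #
    #   Q /= Q.sum()
    #   r = torch.ones(Q.shape[0]) / Q.shape[0]   # uniform marginal over prototypes
    #   c = torch.ones(Q.shape[1]) / Q.shape[1]   # uniform marginal over samples
    #   for _ in range(num_sinkhorn_iters):
    #       Q *= (r / Q.sum(dim=1)).unsqueeze(1)  # rescale rows to match r
    #       Q *= (c / Q.sum(dim=0)).unsqueeze(0)  # rescale columns to match c
    #   assignments = (Q / Q.sum(dim=0, keepdim=True)).t()  # (batch, num_prototypes)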
    def forward(self, scores: torch.Tensor, head_id: int):
        assert scores.shape[0] % self.num_crops == 0
        bs = scores.shape[0] // self.num_crops
        total_loss = 0
        n_term_loss = 0

        # 2 big crops are normally used for the assignment
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                scores_this_crop = scores[bs * crop_id : bs * (crop_id + 1)]
                if self.use_queue:
                    queue = getattr(self, "local_queue" + str(head_id))[i].clone()
                    scores_this_crop = torch.cat((scores_this_crop, queue))
                if self.use_double_prec:
                    assignments = torch.exp(
                        scores_this_crop.double() / np.float64(self.epsilon)
                    ).t()
                    assignments = assignments.double()
                else:
                    assignments = scores_this_crop / self.epsilon
                    # use the log-sum-exp trick for numerical stability.
                    M = torch.max(assignments)
                    all_reduce_max(M)
                    assignments -= M
                    assignments = torch.exp(assignments).t()
                assignments = self.distributed_sinkhornknopp(assignments)[:bs]
                idx_crop_pred = np.delete(np.arange(self.num_crops), crop_id)
            loss = 0
            for p in idx_crop_pred:
                if self.use_double_prec:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)].double()
                                / np.float64(self.temperature)
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
                else:
                    loss -= torch.mean(
                        torch.sum(
                            assignments
                            * self.log_softmax(
                                scores[bs * p : bs * (p + 1)] / self.temperature
                            ),
                            dim=1,
                            dtype=assignments.dtype,
                        )
                    )
            loss /= len(idx_crop_pred)
            total_loss += loss
            n_term_loss += 1
            # stop training if NaN appears and log the output to help debugging
            # TODO (prigoyal): extract the logic to be common for all losses
            # debug_state() method that all losses can override
            if torch.isnan(loss):
                logging.info(
                    f"Infinite Loss or NaN. Loss value: {loss}, rank: {self.dist_rank}"
                )
                scores_output_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_scores" + str(i) + ".pth",
                )
                assignments_out_file = os.path.join(
                    self.output_dir,
                    "rank" + str(self.dist_rank) + "_assignments" + str(i) + ".pth",
                )
                with PathManager.open(scores_output_file, "wb") as fwrite:
                    torch.save(scores, fwrite)
                with PathManager.open(assignments_out_file, "wb") as fwrite:
                    torch.save(assignments, fwrite)
                logging.info(f"Saved the scores matrix to: {scores_output_file}")
                logging.info(f"Saved the assignment matrix to: {assignments_out_file}")
        total_loss /= n_term_loss
        return total_loss
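    # NOTE: each term accumulated in forward() above is a cross-entropy between the
    # soft codes of an assignment crop and the temperature-scaled predictions of
    # another crop, i.e. roughly (illustrative, single head and single crop pair):
    #
    #   q = assignments                             # (bs, K), each row sums to 1
    #   p_scores = scores[bs * p : bs * (p + 1)]    # (bs, K) prototype scores of crop p
    #   term = -torch.mean(
    #       torch.sum(q * torch.log_softmax(p_scores / self.temperature, dim=1), dim=1)
    #   )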
    def update_emb_queue(self, emb):
        with torch.no_grad():
            bs = len(emb) // self.num_crops
            for i, crop_id in enumerate(self.crops_for_assign):
                queue = self.local_emb_queue[i]
                queue[bs:] = queue[:-bs].clone()
                queue[:bs] = emb[crop_id * bs : (crop_id + 1) * bs]
                self.local_emb_queue[i] = queue
    def compute_queue_scores(self, head):
        with torch.no_grad():
            for crop_id in range(len(self.crops_for_assign)):
                for i in range(head.nmb_heads):
                    scores = getattr(head, "prototypes" + str(i))(
                        self.local_emb_queue[crop_id]
                    )
                    getattr(self, "local_queue" + str(i))[crop_id] = scores
    def initialize_queue(self):
        for i in range(self.nmb_heads):
            init_queue = (
                torch.rand(
                    len(self.crops_for_assign),
                    self.local_queue_length,
                    self.num_prototypes[i],
                )
                * 2
                - 1
            )
            self.register_buffer("local_queue" + str(i), init_queue)
        stdv = 1.0 / math.sqrt(self.embedding_dim / 3)
        init_queue = (
            torch.rand(
                len(self.crops_for_assign), self.local_queue_length, self.embedding_dim
            )
            .mul_(2 * stdv)
            .add_(-stdv)
        )
        self.register_buffer("local_emb_queue", init_queue)
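    # NOTE: with stdv = 1 / sqrt(embedding_dim / 3) = sqrt(3 / embedding_dim), each
    # coordinate of the embedding queue is drawn from U(-stdv, stdv), whose variance is
    # stdv**2 / 3 = 1 / embedding_dim; summed over embedding_dim coordinates, a freshly
    # initialized queue entry therefore has expected squared norm of roughly 1, i.e. it
    # is approximately unit-norm like a normalized embedding.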
    def __repr__(self):
        repr_dict = {
            "name": self._get_name(),
            "use_queue": self.use_queue,
            "local_queue_length": self.local_queue_length,
            "temperature": self.temperature,
            "num_prototypes": self.num_prototypes,
            "num_crops": self.num_crops,
            "nmb_sinkhornknopp_iters": self.nmb_sinkhornknopp_iters,
            "embedding_dim": self.embedding_dim,
            "temp_hard_assignment_iters": self.temp_hard_assignment_iters,
        }
        return pprint.pformat(repr_dict, indent=2)