# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import math
import pprint
from typing import List
import numpy as np
import torch
from classy_vision.generic.distributed_util import (
all_reduce_sum,
get_cuda_device_index,
get_world_size,
is_distributed_training_run,
)
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
@register_loss("swav_momentum_loss")
class SwAVMomentumLoss(ClassyLoss):
"""
    This loss extends the SwAV loss proposed by Caron et al. in
    https://arxiv.org/abs/2006.09882. It combines the SwAV approach with the
    momentum encoder used in MoCo.
Config params:
        momentum (float): momentum coefficient used to update the momentum
            encoder
        momentum_eval_mode_iter_start (int): iteration from which the momentum
            encoder network should be put in eval mode
embedding_dim (int): the projection head output dimension
temperature (float): temperature to be applied to the logits
        use_double_precision (bool): whether to compute the loss in double
            precision, which can help avoid NaNs
normalize_last_layer (bool): whether to normalize the last layer
        num_iters (int): number of Sinkhorn-Knopp iterations to run
        epsilon (float): regularization parameter of the Sinkhorn-Knopp
            algorithm (see the paper for details)
num_crops (int): number of crops used
        crops_for_assign (List[int]): indices of the crops used to compute
            the assignments
num_prototypes (List[int]): number of prototypes
queue:
            queue_length (int): total number of features to store and use in
                the scores
            start_iter (int): iteration from which to start using the queue
                for the scores
            local_queue_length (int): length of the queue per GPU
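
    Example config (illustrative values only, not verified defaults; the
    exact layout may differ across VISSL versions):
        momentum: 0.99
        momentum_eval_mode_iter_start: 0
        embedding_dim: 128
        temperature: 0.1
        use_double_precision: False
        normalize_last_layer: True
        num_iters: 3
        epsilon: 0.05
        num_crops: 8
        crops_for_assign: [0, 1]
        num_prototypes: [3000]
        queue:
            queue_length: 0
            start_iter: 0
            local_queue_length: 0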
"""
def __init__(self, loss_config: AttrDict):
super().__init__()
self.loss_config = loss_config
self.momentum_encoder = None
self.checkpoint = None
self.momentum_scores = None
self.momentum_embeddings = None
self.is_distributed = is_distributed_training_run()
self.use_gpu = get_cuda_device_index() > -1
self.softmax = nn.Softmax(dim=1)
        # keep track of the number of iterations
        self.register_buffer("num_iteration", torch.zeros(1, dtype=torch.int64))
# for queue
self.use_queue = False
if self.loss_config.queue.local_queue_length > 0:
self.initialize_queue()
    @classmethod
def from_config(cls, loss_config: AttrDict):
"""
Instantiates SwAVMomentumLoss from configuration.
Args:
loss_config: configuration for the loss
Returns:
SwAVMomentumLoss instance.
"""
return cls(loss_config)
    def initialize_queue(self):
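        # Shapes below are inferred from this code, not from separate docs:
        # per prototype head i, a score queue "local_queue{i}" of shape
        # (len(crops_for_assign), local_queue_length, num_prototypes[i]);
        # plus one embedding queue "local_emb_queue" of shape
        # (len(crops_for_assign), local_queue_length, embedding_dim).
        # All are registered as buffers so they are checkpointed and moved
        # with the module.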
for i, nmb_proto in enumerate(self.loss_config.num_prototypes):
init_queue = (
torch.rand(
len(self.loss_config.crops_for_assign),
self.loss_config.queue.local_queue_length,
nmb_proto,
)
* 2
- 1
)
self.register_buffer("local_queue" + str(i), init_queue)
stdv = 1.0 / math.sqrt(self.loss_config.embedding_dim / 3)
init_queue = (
torch.rand(
len(self.loss_config.crops_for_assign),
self.loss_config.queue.local_queue_length,
self.loss_config.embedding_dim,
)
.mul_(2 * stdv)
.add_(-stdv)
)
self.register_buffer("local_emb_queue", init_queue)
    def load_state_dict(self, state_dict, *args, **kwargs):
"""
Restore the loss state given a checkpoint
Args:
state_dict (serialized via torch.save)
"""
        # If the momentum encoder has not been allocated yet, store the
        # checkpoint for later; otherwise use the normal pytorch restoration
if self.momentum_encoder is None:
self.checkpoint = state_dict
logging.info("Storing the checkpoint for later use")
else:
logging.info("Restoring checkpoint")
super().load_state_dict(state_dict, *args, **kwargs)
    def forward(self, output: List[torch.Tensor], *args, **kwargs):
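        # output[1:] holds one prototype-scores tensor per head; output[0]
        # (unused below, presumably the embedding) is skipped. The queue is
        # only consulted once training reaches queue.start_iter, so it has
        # first been filled with features from a partially trained model.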
self.use_queue = (
self.loss_config.queue.local_queue_length > 0
and self.num_iteration >= self.loss_config.queue.start_iter
)
if self.use_queue:
if self.is_distributed:
self.compute_queue_scores(self.momentum_encoder.module.heads[0])
else:
self.compute_queue_scores(self.momentum_encoder.heads[0])
loss = 0
for head_id, proto_scores in enumerate(output[1:]):
bs = proto_scores.shape[0] // self.loss_config.num_crops
sub_loss = 0
for j, crop_id in enumerate(self.loss_config.crops_for_assign):
with torch.no_grad():
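                    # The assignments ("codes") are computed from the momentum
                    # encoder's scores and act as fixed targets, hence the
                    # no_grad block.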
scores_this_crop = self.momentum_scores[head_id][
j * bs : (j + 1) * bs
]
if self.use_queue:
queue = getattr(self, "local_queue" + str(head_id))[j].clone()
scores_this_crop = torch.cat((scores_this_crop, queue))
assignments = torch.exp(
scores_this_crop / self.loss_config.epsilon
).t()
assignments = self.distributed_sinkhornknopp(assignments)[:bs]
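                    # Only the first bs rows belong to the current batch; the
                    # queue rows stabilize the Sinkhorn-Knopp assignment and
                    # are then discarded by the slicing above.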
idx_crop_pred = np.delete(
np.arange(self.loss_config.num_crops), crop_id
)
subsubloss = 0
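                # Cross-entropy between the fixed code q and the prediction
                # p = softmax(z / temperature), averaged over the other crops:
                #   L = -(1 / |P|) * sum_{p in P} mean_i sum_k q_ik * log p_ik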
for p in idx_crop_pred:
subsubloss -= torch.mean(
torch.sum(
assignments
* torch.log(
self.softmax(
proto_scores[bs * p : bs * (p + 1)]
/ self.loss_config.temperature
)
),
dim=1,
)
)
sub_loss += subsubloss / len(idx_crop_pred)
loss += sub_loss / len(self.loss_config.crops_for_assign)
loss /= len(output) - 1
self.num_iteration += 1
if self.use_queue:
self.update_emb_queue()
return loss
def __repr__(self):
repr_dict = {"name": self._get_name()}
return pprint.pformat(repr_dict, indent=2)
    def distributed_sinkhornknopp(self, Q: torch.Tensor):
"""
Apply the distributed sinknorn optimization on the scores matrix to
find the assignments
"""
with torch.no_grad():
sum_Q = torch.sum(Q, dtype=Q.dtype)
all_reduce_sum(sum_Q)
Q /= sum_Q
k = Q.shape[0]
n = Q.shape[1]
N = get_world_size() * Q.shape[1]
# we follow the u, r, c and Q notations from
# https://arxiv.org/abs/1911.05371
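            # Each iteration rescales rows, then columns:
            #   u = Q @ 1                      (row sums, reduced over workers)
            #   Q <- diag(r / u) @ Q           (match row marginals r)
            #   Q <- Q @ diag(c / (1^T @ Q))   (match column marginals c)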
r = torch.ones(k) / k
c = torch.ones(n) / N
if self.use_gpu:
r = r.cuda(non_blocking=True)
c = c.cuda(non_blocking=True)
curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
all_reduce_sum(curr_sum)
for _ in range(self.loss_config.num_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0, dtype=Q.dtype)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1, dtype=Q.dtype)
all_reduce_sum(curr_sum)
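            # Final column normalization so that each sample's assignment
            # distribution sums to one; transpose to (samples, prototypes).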
return (Q / torch.sum(Q, dim=0, keepdim=True, dtype=Q.dtype)).t().float()
    def update_emb_queue(self):
with torch.no_grad():
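            # FIFO update: shift existing entries back by one batch and write
            # the newest momentum embeddings at the head of each per-crop queue.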
bs = len(self.momentum_embeddings) // self.loss_config.num_crops
for i in range(len(self.loss_config.crops_for_assign)):
queue = self.local_emb_queue[i]
queue[bs:] = queue[:-bs].clone()
queue[:bs] = self.momentum_embeddings[i * bs : (i + 1) * bs]
self.local_emb_queue[i] = queue
    def compute_queue_scores(self, head):
with torch.no_grad():
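            # Re-project the queued momentum embeddings through the current
            # prototype heads so that the queued scores stay consistent with
            # the latest prototypes.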
for i in range(len(self.loss_config.crops_for_assign)):
for h in range(head.nmb_heads):
scores = getattr(head, "prototypes" + str(h))(
self.local_emb_queue[i]
)
getattr(self, "local_queue" + str(h))[i] = scores