# Source code for vissl.losses.deepclusterv2_loss

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import math
import pprint

import torch
import torch.distributed as dist
from classy_vision.generic.distributed_util import (
    all_reduce_sum,
    gather_from_all,
    get_rank,
    get_world_size,
)
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict
from vissl.utils.misc import get_indices_sparse


@register_loss("deepclusterv2_loss")
class DeepClusterV2Loss(ClassyLoss):
    """
    Loss used for DeepClusterV2 approach as provided in SwAV paper
    https://arxiv.org/abs/2006.09882

    Config params:
        DROP_LAST (bool):             automatically inferred from DATA.TRAIN.DROP_LAST
        BATCHSIZE_PER_REPLICA (int):  256  # automatically inferred from
                                             DATA.TRAIN.BATCHSIZE_PER_REPLICA
        num_crops (int):              2    # automatically inferred from
                                             DATA.TRAIN.TRANSFORMS
        temperature (float):          0.1
        num_clusters (List[int]):     [3000, 3000, 3000]
        kmeans_iters (int):           10
        memory_params:
            crops_for_mb:             [0]
            embedding_dim:            128
        num_train_samples (int):      -1   # @auto-filled
    """

    def __init__(self, loss_config: AttrDict):
        """
        Read the loss configuration and allocate the per-process memory bank,
        the per-head assignment table and one centroid buffer per head.

        Args:
            loss_config: configuration for the loss (see the class docstring
                for the keys that are read).
        """
        super().__init__()

        self.loss_config = loss_config
        size_dataset = self.loss_config.num_train_samples
        # Each process stores an equal share of the dataset (rounded up).
        size_memory_per_process = int(math.ceil(size_dataset * 1.0 / get_world_size()))

        if self.loss_config.DROP_LAST:
            # With DROP_LAST the bank is trimmed to a whole number of batches.
            size_memory_per_process -= (
                size_memory_per_process % self.loss_config.BATCHSIZE_PER_REPLICA
            )

        self.nmb_mbs = len(self.loss_config.memory_params.crops_for_mb)
        self.nmb_heads = len(self.loss_config.num_clusters)
        self.num_clusters = self.loss_config.num_clusters
        self.embedding_dim = self.loss_config.memory_params.embedding_dim
        self.crops_for_mb = self.loss_config.memory_params.crops_for_mb
        self.nmb_unique_idx = self.loss_config.BATCHSIZE_PER_REPLICA
        self.num_crops = self.loss_config.num_crops
        self.temperature = self.loss_config.temperature
        self.nmb_kmeans_iters = self.loss_config.kmeans_iters
        # Write cursor into the local memory bank; reset by cluster_memory().
        self.start_idx = 0

        # One bank of embeddings per entry of crops_for_mb.
        self.register_buffer(
            "local_memory_embeddings",
            torch.zeros(self.nmb_mbs, size_memory_per_process, self.embedding_dim),
        )
        # Dataset index of the sample stored in each memory-bank row.
        self.register_buffer(
            "local_memory_index", torch.zeros(size_memory_per_process).long()
        )
        # Cluster id per (head, dataset sample). -100 matches the ignore_index
        # of the cross-entropy below, so unassigned samples add no loss.
        self.register_buffer(
            "assignments", -100 * torch.ones(self.nmb_heads, size_dataset).long()
        )
        # One randomly initialised centroid matrix per head.
        for i, k in enumerate(self.loss_config.num_clusters):
            self.register_buffer(
                "centroids" + str(i), torch.rand(k, self.embedding_dim)
            )

        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=-100)

    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates DeepClusterV2Loss from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            DeepClusterV2Loss instance.
        """
        return cls(loss_config)

    def forward(self, output: torch.Tensor, idx: torch.Tensor):
        """
        Compute the loss for one batch and record the batch in the memory bank.

        The embeddings are L2-normalised along dim 1. For every head, the
        temperature-scaled dot products with that head's centroids are scored
        with cross-entropy against the stored assignments of ``idx``; the
        result is averaged over heads.

        Args:
            output: embeddings of the batch; rows are expected crop-major
                (see update_memory_bank).
            idx: dataset indices of the rows of ``output``; it is sliced, used
                with ``len`` and used to index the assignment table, so it is
                presumably a 1-D integer tensor -- confirm against the caller.

        Returns:
            The loss averaged over heads.
        """
        output = nn.functional.normalize(output, dim=1, p=2)

        loss = 0
        for i in range(self.nmb_heads):
            scores = (
                torch.mm(output, getattr(self, "centroids" + str(i)).t())
                / self.temperature
            )
            loss += self.cross_entropy_loss(scores, self.assignments[i][idx])
        loss /= self.nmb_heads

        self.update_memory_bank(output, idx)
        return loss

    def init_memory(self, dataloader, model):
        """
        Fill the local memory bank with one pass over ``dataloader``.

        Runs under ``torch.no_grad()``. For each batch, the crops listed in
        ``crops_for_mb`` are passed through ``model``, the first element of
        the model output is L2-normalised and stored, together with the
        sample indices, at consecutive positions of the bank.

        Args:
            dataloader: iterable of batches providing ``inputs["data"][0]``
                (indexable by crop) and ``inputs["data_idx"][0]``.
            model: callable whose output's first element is the embedding.
        """
        logging.info(f"Rank: {get_rank()}, Start initializing memory banks")
        start_idx = 0
        with torch.no_grad():
            for inputs in dataloader:
                # data_idx holds one entry per crop; keep the first
                # len // num_crops entries as the unique sample indices.
                nmb_unique_idx = len(inputs["data_idx"][0]) // self.num_crops
                index = inputs["data_idx"][0][:nmb_unique_idx].cuda(non_blocking=True)

                # get embeddings
                outputs = []
                for crop_idx in self.crops_for_mb:
                    inp = inputs["data"][0][crop_idx].cuda(non_blocking=True)
                    outputs.append(nn.functional.normalize(model(inp)[0], dim=1, p=2))

                # fill the memory bank
                self.local_memory_index[start_idx : start_idx + nmb_unique_idx] = index
                for mb_idx, embeddings in enumerate(outputs):
                    self.local_memory_embeddings[mb_idx][
                        start_idx : start_idx + nmb_unique_idx
                    ] = embeddings
                start_idx += nmb_unique_idx
        logging.info(
            f"Rank: {get_rank()}, Memory banks initialized: "
            "full first forward pass done"
        )

    def update_memory_bank(self, emb, idx):
        """
        Store the current batch at the write cursor of the local memory bank.

        Args:
            emb: embeddings of the batch. Rows
                ``[c * n, (c + 1) * n)`` are read as crop ``c``, where
                ``n = len(idx) // num_crops``.
            idx: dataset indices of the batch; only the first ``n`` entries
                are stored.
        """
        # The bank is plain storage consumed by k-means under no_grad.
        # Writing a grad-tracking tensor into the buffer in place would attach
        # the buffer to the autograd graph and extend that graph every step.
        emb = emb.detach()

        nmb_unique_idx = len(idx) // self.num_crops
        idx = idx[:nmb_unique_idx]
        lo = self.start_idx
        hi = lo + nmb_unique_idx

        self.local_memory_index[lo:hi] = idx
        for i, crop_idx in enumerate(self.crops_for_mb):
            self.local_memory_embeddings[i][lo:hi] = emb[
                crop_idx * nmb_unique_idx : (crop_idx + 1) * nmb_unique_idx
            ]
        self.start_idx = hi

    def cluster_memory(self):
        """
        Run distributed k-means on the memory bank, once per head.

        For every head the centroids are initialised from rows of rank 0's
        bank and broadcast, then refined for ``kmeans_iters`` iterations with
        counts and sums reduced across processes. The resulting centroids
        are copied into the head's ``centroids<i>`` buffer and the assignments
        of all processes are gathered into ``self.assignments``. Heads cycle
        through the available memory banks. The write cursor of the memory
        bank is reset to 0.
        """
        self.start_idx = 0
        # Scratch tensors must sit next to the memory bank they are multiplied
        # with, so allocate them on the buffer's device.
        device = self.local_memory_embeddings.device
        j = 0
        with torch.no_grad():
            for i_K, K in enumerate(self.num_clusters):
                # run distributed k-means
                bank = self.local_memory_embeddings[j]

                # init centroids with elements from memory bank of rank 0
                centroids = torch.empty(K, self.embedding_dim, device=device)
                if get_rank() == 0:
                    random_idx = torch.randperm(len(bank))[:K]
                    # The slice is shorter than K when the bank has fewer
                    # than K rows.
                    assert len(random_idx) >= K, "please reduce the number of centroids"
                    centroids = bank[random_idx]
                dist.broadcast(centroids, 0)

                # The extra (last) iteration only recomputes the assignments
                # for the final centroids.
                for n_iter in range(self.nmb_kmeans_iters + 1):

                    # E step
                    dot_products = torch.mm(bank, centroids.t())
                    _, assignments = dot_products.max(dim=1)

                    # finish
                    if n_iter == self.nmb_kmeans_iters:
                        break

                    # M step
                    where_helper = get_indices_sparse(assignments.cpu().numpy())
                    counts = torch.zeros(K, dtype=torch.int32, device=device)
                    emb_sums = torch.zeros(K, self.embedding_dim, device=device)
                    for k, members in enumerate(where_helper):
                        rows = members[0]
                        if len(rows) > 0:
                            emb_sums[k] = torch.sum(bank[rows], dim=0)
                            counts[k] = len(rows)
                    all_reduce_sum(counts)
                    # Clusters that are empty on every process keep their
                    # previous centroid.
                    mask = counts > 0
                    all_reduce_sum(emb_sums)
                    centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                    # normalize centroids
                    centroids = nn.functional.normalize(centroids, dim=1, p=2)

                getattr(self, "centroids" + str(i_K)).copy_(centroids)

                # gather the assignments
                assignments_all = gather_from_all(assignments)
                indexes_all = gather_from_all(self.local_memory_index)
                self.assignments[i_K] = -100
                self.assignments[i_K][indexes_all] = assignments_all

                # next head uses the next memory bank (wrapping around)
                j = (j + 1) % self.nmb_mbs

        logging.info(f"Rank: {get_rank()}, clustering of the memory bank done")

    def __repr__(self):
        """Return a pretty-printed dict holding the module name."""
        repr_dict = {"name": self._get_name()}
        return pprint.pformat(repr_dict, indent=2)