Source code for vissl.data.collators.moco_collator

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict, List

import torch
from vissl.data.collators import register_collator


def _squeeze_keep_batch(tensor: torch.Tensor) -> torch.Tensor:
    """Drop singleton dimensions of ``tensor`` but always keep dimension 0.

    A plain ``tensor.squeeze()`` also removes the leading (batch) dimension
    when the batch holds a single sample; this helper puts it back so the
    result is always indexed by sample first.
    """
    squeezed = tensor.squeeze()
    if tensor.shape[0] == 1:
        # squeeze() removed the batch axis together with the other singleton
        # dimensions; restore it.
        squeezed = squeezed.unsqueeze(0)
    return squeezed


@register_collator("moco_collator")
def moco_collator(batch: List[Dict[str, Any]]) -> Dict[str, List[torch.Tensor]]:
    """
    This collator is specific to the MoCo approach
    http://arxiv.org/abs/1911.05722

    It collates a batch of samples where every sample holds several copies
    of the same image. Copy 0 is routed to the encoder ("data") and copy 1
    to the momentum encoder ("data_momentum").

    Input:
        batch: Example
            batch = [
                {"data": [img1_0, ..., img1_k], "label": ...,
                 "data_valid": ..., "data_idx": ...},
                {"data": [img2_0, ..., img2_k], "label": ...,
                 "data_valid": ..., "data_idx": ...},
                ...
            ]

    Returns: Example output:
        output = {
            "data": [torch.stack([img1_0, img2_0, ...])],
            "data_momentum": [torch.stack([img1_1, img2_1, ...])],
            "label": [tensor],
            "data_valid": [tensor],
            "data_idx": [tensor],
        }
        Every value is a single-element list. Singleton dimensions are
        squeezed out of each tensor, except that the leading batch dimension
        is always kept, so a batch with one sample is collated the same way
        as a larger one.

    Raises:
        AssertionError: if the first sample has no "data" or "label" entry.
        KeyError: if a sample lacks "data_valid" or "data_idx".
    """
    assert "data" in batch[0], "data not found in sample"
    assert "label" in batch[0], "label not found in sample"

    data = [torch.stack(x["data"]) for x in batch]
    labels = [torch.tensor(x["label"]) for x in batch]
    data_valid = [torch.tensor(x["data_valid"]) for x in batch]
    data_idx = [torch.tensor(x["data_idx"]) for x in batch]

    # Stack once and reuse for both views. Dimension 0 is the sample index and
    # dimension 1 the image copy; the batch axis must survive the squeeze or
    # the copy selection below would index the wrong axis for a single-sample
    # batch.
    images = _squeeze_keep_batch(torch.stack(data))

    output_batch = {
        "data": [_squeeze_keep_batch(images[:, 0, :, :, :])],  # encoder
        "data_momentum": [
            _squeeze_keep_batch(images[:, 1, :, :, :])
        ],  # momentum encoder
        "label": [_squeeze_keep_batch(torch.stack(labels))],
        "data_valid": [_squeeze_keep_batch(torch.stack(data_valid))],
        "data_idx": [_squeeze_keep_batch(torch.stack(data_idx))],
    }

    return output_batch