Source code for vissl.data.collators.multicrop_collator

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
from vissl.data.collators import register_collator


[docs]@register_collator("multicrop_collator") def multicrop_collator(batch): """ This collator is used in SwAV approach. The collators collates the batch for the following input (assuming k-copies of image): Input: batch: Example batch = [ {"data" : [img1_0, ..., img1_k], ..}, {"data" : [img2_0, ..., img2_k], ...}, ... ] Returns: Example output: output = [ { "data": torch.tensor([img1_0, ..., imgN_0], [img1_k, ..., imgN_k]) .. }, ] """ assert "data" in batch[0], "data not found in sample" assert "label" in batch[0], "label not found in sample" data = [x["data"] for x in batch] labels = [torch.tensor(x["label"]) for x in batch] data_valid = [torch.tensor(x["data_valid"]) for x in batch] data_idx = [torch.tensor(x["data_idx"]) for x in batch] num_duplicates, num_images = len(data[0]), len(data) output_data, output_label, output_data_valid, output_data_idx = [], [], [], [] for pos in range(num_duplicates): _output_data = [] for idx in range(num_images): _output_data.append(data[idx][pos]) output_label.append(labels[idx][pos]) output_data_valid.append(data_valid[idx][pos]) output_data_idx.append(data_idx[idx][pos]) output_data.append(torch.stack(_output_data)) output_batch = { "data": [output_data], "label": [torch.stack(output_label)], "data_valid": [torch.stack(output_data_valid)], "data_idx": [torch.stack(output_data_idx)], } return output_batch