Source code for vissl.data.collators.siamese_collator

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
from vissl.data.collators import register_collator


[docs]@register_collator("siamese_collator") def siamese_collator(batch): """ This collator is used in Jigsaw approach. Input: batch: Example batch = [ {"data": [img1,], "label": [lbl1, ]}, #img1 {"data": [img2,], "label": [lbl2, ]}, #img2 . . {"data": [imgN,], "label": [lblN, ]}, #imgN ] where: img{x} is a tensor of size: num_towers x C x H x W lbl{x} is an integer Returns: Example output: output = [ { "data": torch.tensor([img1_0, ..., imgN_0]) .. }, ] where the output is of dimension: (N * num_towers) x C x H x W """ assert "data" in batch[0], "data not found in sample" assert "label" in batch[0], "label not found in sample" num_data_sources = len(batch[0]["data"]) batch_size = len(batch) data = [x["data"] for x in batch] labels = [x["label"] for x in batch] data_valid = torch.BoolTensor([x["data_valid"][0] for x in batch]) output_data, output_label = [], [] for idx in range(num_data_sources): # each image is of shape: num_towers x C x H x W # num_towers x C x H x W -> N x num_towers x C x H x W idx_data = torch.stack([data[i][idx] for i in range(batch_size)]) idx_labels = [labels[i][idx] for i in range(batch_size)] batch_size, num_siamese_towers, channels, height, width = idx_data.size() # N x num_towers x C x H x W -> (N * num_towers) x C x H x W idx_data = idx_data.view( batch_size * num_siamese_towers, channels, height, width ) output_data.append(idx_data) should_flatten = False if idx_labels[0].ndim == 1: should_flatten = True idx_labels = torch.stack(idx_labels).squeeze() if should_flatten: idx_labels = idx_labels.flatten() output_label.append(idx_labels) output_batch = { "data": output_data, "label": output_label, "data_valid": [data_valid], } return output_batch