# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
from vissl.data.collators import register_collator
@register_collator("patch_and_image_collator")
def patch_and_image_collator(batch):
    """
    Collate PIRL samples into a batch of images and a batch of patches.

    This collator is used in the PIRL approach.

    Args:
        batch (list): one dict per sample. Each dict is read as follows:
            - "data": a sequence of N+1 tensors. The 1st element is the
              "image" and the remaining N elements are patches.
            - "label": a sequence in which the sample's label is repeated
              N+1 times; only element 0 is read.
            - "data_valid": a sequence in which the validity flag is
              repeated N+1 times; only element 0 is read.
            Only the first sample is checked for the "data" and "label"
            keys.

    Returns:
        dict: each value is wrapped in a single-element list.
            - "images": the per-sample images stacked along a new leading
              dimension (batch_size entries).
            - "patches": all patches stacked along a new leading dimension,
              ordered sample by sample (N * batch_size entries).
            - "label": LongTensor built from the per-sample labels, with
              size-1 dimensions squeezed out.
            - "data_valid": BoolTensor built from the per-sample flags.

    Raises:
        AssertionError: if the first sample lacks "data" or "label".
    """
    assert "data" in batch[0], "data not found in sample"
    assert "label" in batch[0], "label not found in sample"

    data = [sample["data"] for sample in batch]

    # Labels are repeated N+1 times per sample but are all the same, so the
    # first one is enough.
    # NOTE(review): squeeze() removes every size-1 dimension, so a batch with
    # a single sample yields a 0-dim label tensor - confirm that downstream
    # consumers handle that case.
    labels = torch.LongTensor([sample["label"][0] for sample in batch]).squeeze()

    # data_valid is repeated N+1 times per sample as well; keep the first.
    data_valid = torch.BoolTensor([sample["data_valid"][0] for sample in batch])

    # Element 0 of every sample is the image; everything after it is a patch.
    images = torch.stack([views[0] for views in data])
    patches = torch.stack([patch for views in data for patch in views[1:]])

    return {
        "images": [images],
        "patches": [patches],
        "label": [labels],
        "data_valid": [data_valid],
    }