# Source code for vissl.data.collators.targets_one_hot_default_collator

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch
from vissl.data.collators import register_collator


def convert_to_one_hot(pos_lbl, neg_lbl, num_classes: int) -> torch.Tensor:
    """
    Convert positive / negative class indices into a {-1, 0, 1} target vector.

    Every entry starts at -1 ("ignore"). Indices listed in ``pos_lbl`` are set
    to 1, then indices listed in ``neg_lbl`` are set to 0. An index that
    appears in both lists therefore ends up as 0 (negative wins).

    Args:
        pos_lbl: list of positive class indices. Anything that is not a
            non-empty ``list`` is treated as "no positives".
        neg_lbl: list of negative class indices, same convention as
            ``pos_lbl``.
        num_classes (int): length of the returned vector.

    Returns:
        torch.Tensor: 1-D ``torch.long`` tensor of shape ``(num_classes,)``.
        It is never squeezed, so ``num_classes == 1`` still yields shape
        ``(1,)`` and stacking per-sample vectors always gives
        ``(batch, num_classes)``.

    Raises:
        AssertionError: if any listed index is outside ``[0, num_classes)``.
    """
    one_hot_targets = torch.full((num_classes,), -1, dtype=torch.long)
    # Positives are written first and negatives second, so negatives override.
    for labels, value in ((pos_lbl, 1), (neg_lbl, 0)):
        if isinstance(labels, list) and len(labels) > 0:
            assert (
                max(labels) < num_classes
            ), "Class Index must be less than number of classes"
            assert min(labels) >= 0, "Class Index must be non-negative"
            # Build the index tensor directly as int64; torch.Tensor(...) would
            # go through float32 first.
            index = torch.tensor(labels, dtype=torch.long)
            one_hot_targets.scatter_(0, index, value)
    return one_hot_targets
@register_collator("targets_one_hot_default_collator")
def targets_one_hot_default_collator(batch, num_classes: int):
    """
    Collate samples carrying separate positive / negative label lists.

    Each sample in ``batch`` must be a dict with a single data source
    (``len(sample["data"]) == 1``) and a ``"label"`` entry whose first element
    is a ``[positive_indices, negative_indices]`` pair. Only the first sample
    is validated; the rest are assumed to share its structure.

    Example ``label[0]`` entries for two samples, with ``num_classes = 12``::

        [[1, 3, 6], [4, 9]]
        [[1, 5], [6, 8, 10, 11]]

    are turned by ``convert_to_one_hot`` into (1 = positive, 0 = negative,
    -1 = unspecified / ignore)::

        [[-1, 1, -1, 1, 0, -1, 1, -1, -1, 0, -1, -1],
         [-1, 1, -1, -1, -1, 1, 0, -1, 0, -1, 0, 0]]

    Args:
        batch: list of sample dicts with keys ``"data"``, ``"label"``,
            ``"data_valid"`` and ``"data_idx"``. ``data[0]`` entries must
            be tensors ``torch.stack`` can combine. ``data_valid[0]`` and
            ``data_idx[0]`` must be values ``torch.tensor`` accepts
            (presumably scalars; confirm against the dataset).
        num_classes (int): number of classes; must be > 0.

    Returns:
        dict: ``{"data": [Tensor], "label": [Tensor], "data_valid": [Tensor],
        "data_idx": [Tensor]}``. Each value is a one-element list holding the
        stacked batch. ``"label"`` holds one row per sample, as produced by
        ``convert_to_one_hot``.
    """
    assert num_classes > 0, "num_classes not specified for the collator"
    assert "data" in batch[0], "data not found in sample"
    assert "label" in batch[0], "label not found in sample"
    assert len(batch[0]["data"]) == 1, (
        "This collator supports only 1 data source. "
        "Please extend it to support many data sources."
    )
    assert (
        len(batch[0]["label"][0]) == 2
    ), "This collator takes positive and negative labels separately. Please modify it to suit your needs."

    data = torch.stack([x["data"][0] for x in batch])
    data_valid = torch.stack([torch.tensor(x["data_valid"][0]) for x in batch])
    data_idx = torch.stack([torch.tensor(x["data_idx"][0]) for x in batch])

    # label[0] of every sample is a (positive_indices, negative_indices) pair.
    labels = [x["label"][0] for x in batch]
    output_labels = [
        convert_to_one_hot(lbl[0], lbl[1], num_classes) for lbl in labels
    ]

    output_batch = {
        "data": [data],
        "label": [torch.stack(output_labels)],
        "data_valid": [data_valid],
        "data_idx": [data_idx],
    }
    return output_batch