# Source code for vissl.data.ssl_transforms.shuffle_img_patches

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
from typing import Any, Dict

import numpy as np
import torch
from classy_vision.dataset.transforms import register_transform
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
from fvcore.common.file_io import PathManager
from vissl.utils.io import load_file


@register_transform("ShuffleImgPatches")
class ShuffleImgPatches(ClassyTransform):
    """
    Shuffle a list of tensors (usually image patches of shape C x H x W)
    according to a permutation picked at random from a pre-defined set of
    permutations. The index of the chosen permutation is returned alongside
    the shuffled patches so it can serve as the prediction target.

    This is a common operation used in the Jigsaw approach
    https://arxiv.org/abs/1603.09246
    """

    def __init__(self, perm_file: str):
        """
        Args:
            perm_file (string): path to the file containing pre-defined
                                permutations.
        """
        self.perm_file = perm_file
        # The permutation file is not read here; it is loaded on the first
        # __call__ (see the `perm_loaded` check there).
        self.perm_loaded = False
        self.perms = None

    def _load_perms(self):
        """
        Load the permutations from `self.perm_file` into `self.perms` and
        mark them as loaded.

        If the smallest entry of the loaded permutations is 1, they are
        treated as 1-based and shifted down by one so they can be used
        directly as 0-based indices.

        Raises:
            AssertionError: if `self.perm_file` does not exist.
        """
        assert PathManager.exists(
            self.perm_file
        ), f"Permutation file NOT found: {self.perm_file}"
        logging.info(f"Loading permutation: {self.perm_file}")
        # NOTE(review): the code below relies on `load_file` returning an
        # array-like supporting np.min, `- 1` and `.shape` (presumably a
        # numpy array of shape num_perms x num_patches) -- confirm.
        self.perms = load_file(self.perm_file)
        if np.min(self.perms) == 1:
            self.perms = self.perms - 1
        logging.info(f"Loaded perm: {self.perms.shape}")
        self.perm_loaded = True

    def __call__(self, input_patches):
        """
        The interface `__call__` is used to transform the input data. It
        should contain the actual implementation of data transform.

        Args:
            input_patches (List[torch.tensor]): list of torch tensors

        Returns:
            A tuple `(input_data, out_label)`: `input_data` is the float32
            stack of the patches reordered by the selected permutation
            (num_towers x C x H x W), and `out_label` is a 1-element int64
            tensor holding the index of the selected permutation.
        """
        if not self.perm_loaded:
            self._load_perms()
        # Draw one permutation uniformly using numpy's global RNG.
        perm_index = np.random.randint(self.perms.shape[0])
        # `torch.as_tensor` yields float32 patches regardless of the input
        # dtype; the legacy `torch.FloatTensor(...)` constructor rejects
        # tensors that are not already float32.
        shuffled_patches = [
            torch.as_tensor(input_patches[i], dtype=torch.float32)
            for i in self.perms[perm_index]
        ]
        # num_towers x C x H x W
        input_data = torch.stack(shuffled_patches)
        # Build the label directly as int64 instead of round-tripping the
        # index through a float32 tensor.
        out_label = torch.tensor([perm_index], dtype=torch.long)
        return input_data, out_label

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ShuffleImgPatches":
        """
        Instantiates ShuffleImgPatches from configuration.

        Args:
            config (Dict): arguments for the transform. Must contain the
                           key "perm_file".

        Returns:
            ShuffleImgPatches instance.

        Raises:
            AssertionError: if "perm_file" is missing from `config`.
        """
        assert "perm_file" in config, "Please specify the perm_file"
        return cls(perm_file=config["perm_file"])