# Source code for vissl.data.ssl_transforms.img_pil_to_patches_and_image

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict

import numpy as np
import PIL
import PIL.Image
import torchvision.transforms as pth_transforms
from classy_vision.dataset.transforms import register_transform
from classy_vision.dataset.transforms.classy_transform import ClassyTransform

from vissl.data.ssl_transforms.img_patches_tensor import ImgPatchesFromTensor


@register_transform("ImgPilToPatchesAndImage")
class ImgPilToPatchesAndImage(ClassyTransform):
    """
    Convert an input PIL image to Patches and Image.

    This transform was proposed in PIRL - https://arxiv.org/abs/1912.01991.

    Input:
        PIL Image

    Returns:
        list containing N+1 elements:
            - zeroth element: a RandomResizedCrop of the image
            - remainder: N patches extracted uniformly from a second,
              independent RandomResizedCrop of the image. The patches are
              shuffled into a random order when ``permute_patches`` is True.
    """

    def __init__(
        self,
        crop_scale_image=(0.08, 1.0),
        crop_size_image=224,
        crop_scale_patches=(0.6, 1.0),
        crop_size_patches=255,
        permute_patches=True,
        num_patches=9,
    ):
        """
        Args:
            crop_scale_image (tuple of floats): scale for RandomResizedCrop of image
            crop_size_image (int): size for RandomResizedCrop of image
            crop_scale_patches (tuple of floats): scale for RandomResizedCrop of patches
            crop_size_patches (int): size for RandomResizedCrop of patches
            permute_patches (bool): if True, shuffle the patches into a random
                order on every call. If False, keep the order returned by
                ImgPatchesFromTensor.
            num_patches (int): number of patches to create. Must be a positive
                perfect square integer (e.g. 4, 9, 16).
        """
        assert isinstance(num_patches, int)
        assert num_patches > 0, "Num patches must be a positive integer."
        # Exact integer check. Squaring a float sqrt and comparing it to an
        # int is subject to floating-point rounding.
        splits_per_side = int(round(np.sqrt(num_patches)))
        assert (
            splits_per_side ** 2 == num_patches
        ), "Num patches must be a perfect square integer."
        self.num_patches = num_patches
        self.permute_patches = permute_patches
        assert len(crop_scale_image) == 2
        assert len(crop_scale_patches) == 2
        self.crop_image_tx = pth_transforms.RandomResizedCrop(
            scale=crop_scale_image, size=crop_size_image
        )
        self.crop_patches_tx = pth_transforms.RandomResizedCrop(
            scale=crop_scale_patches, size=crop_size_patches
        )
        self.image_to_patch_tx = ImgPatchesFromTensor(num_patches=num_patches)

    def __call__(self, image):
        """
        Args:
            image (PIL.Image): input image. It is converted with ``np.array``
                and transposed with axes ``(2, 0, 1)``, so it must decode to a
                3-D array (e.g. an RGB image). A single-channel image yields a
                2-D array and the transpose fails.

        Returns:
            list: ``[cropped_image, patch_1, ..., patch_N]``, where every patch
            is a PIL image. The patches are shuffled only if
            ``permute_patches`` was True.
        """
        cropped_image = self.crop_image_tx(image)
        cropped_patch_image = self.crop_patches_tx(image)
        # image to patch accepts a tensor or array. PIL arrays are (H, W, C);
        # move channels first. NOTE(review): presumably ImgPatchesFromTensor
        # expects (C, H, W) - confirm against its implementation.
        image_array = np.array(cropped_patch_image).transpose(2, 0, 1)
        patch_arrays = self.image_to_patch_tx(image_array)
        # Move each patch back to channels-last so PIL can rebuild it
        # (assumes patches come back channel-first, mirroring the input).
        patches = [PIL.Image.fromarray(x.transpose(1, 2, 0)) for x in patch_arrays]
        if self.permute_patches:
            # Permute patches in any order. The permutation is sized from the
            # actual patch list so it can never index past its end.
            perm_order = np.random.permutation(len(patches))
            patches = [patches[x] for x in perm_order]
        # make image the first member of the list
        return [cropped_image] + patches

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ImgPilToPatchesAndImage":
        """
        Instantiates ImgPilToPatchesAndImage from configuration.

        Args:
            config (Dict): arguments for the transform. Every key is forwarded
                as a keyword argument to ``__init__``.

        Returns:
            ImgPilToPatchesAndImage instance.
        """
        return cls(**config)