Source code for vissl.data.ssl_transforms.img_pil_multicrop_random_apply

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
from typing import Any, Dict, List

import torchvision.transforms as pth_transforms
from classy_vision.dataset.transforms import build_transform, register_transform
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
from PIL import Image


@register_transform("ImgPilMultiCropRandomApply")
class ImgPilMultiCropRandomApply(ClassyTransform):
    """
    Apply a list of transforms on multi-crop input. The transforms are
    composed together and randomly applied to each crop using the
    probability specified for that crop.
    This is used in BYOL https://arxiv.org/pdf/2006.07733.pdf

    Multi-crops are several crops of a given image. This is most commonly
    used in contrastive learning. For example SimCLR, SwAV approaches use
    multi-crop input.
    """

    def __init__(self, transforms: List[Dict[str, Any]], prob: List[float]):
        """
        Args:
            transforms (List[Dict[str, Any]]): List of transform configs. Each
                config is built with ``build_transform`` and the resulting
                transforms are composed into a single transform that is
                applied to each crop.
            prob (List[float]): Probability of RandomApply for the transforms
                composition on each crop (one entry per crop).
                example: for 2 crop in BYOL, for solarization:
                prob = [0.0, 0.2]
        """
        self.prob = prob
        self.transforms = None
        self._build_transform(transforms)

    def _build_transform(self, transforms: List[Dict[str, Any]]) -> None:
        """
        Build one RandomApply wrapper per crop and store them in
        ``self.transforms``.

        Args:
            transforms (List[Dict[str, Any]]): List of transform configs to
                build and compose.
        """
        composed = pth_transforms.Compose(
            [build_transform(transform_config) for transform_config in transforms]
        )
        # The same composed transform is shared by every crop; only the
        # probability of applying it differs from crop to crop.
        self.transforms = [
            pth_transforms.RandomApply([composed], p=crop_prob)
            for crop_prob in self.prob
        ]

    def __call__(self, image_list: List[Image.Image]) -> List[Any]:
        """
        Apply the per-crop random transform to each crop.

        Args:
            image_list (List[Image.Image]): The crops of one image. Must have
                exactly one entry per probability given at construction.

        Returns:
            List with one output per input crop, in the same order.

        Raises:
            AssertionError: if ``image_list`` is not a list or its length does
                not match the number of configured probabilities/transforms.
        """
        # AssertionError is kept (rather than TypeError/ValueError) so that
        # existing callers observe the same exception type as before.
        assert isinstance(image_list, list), "image_list must be a list"
        assert len(image_list) == len(
            self.prob
        ), f"expected {len(self.prob)} crops, got {len(image_list)}"
        assert len(image_list) == len(
            self.transforms
        ), f"expected {len(self.transforms)} crops, got {len(image_list)}"
        return [
            crop_transform(image)
            for crop_transform, image in zip(self.transforms, image_list)
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ImgPilMultiCropRandomApply":
        """
        Instantiates ImgPilMultiCropRandomApply from configuration.

        Args:
            config (Dict): arguments for the transform. The keys read are
                "transforms" and "prob"; each defaults to an empty list
                when missing.

        Returns:
            ImgPilMultiCropRandomApply instance.
        """
        transforms = config.get("transforms", [])
        prob = config.get("prob", [])
        logging.info(f"ImgPilMultiCropRandomApply | Using prob: {prob}")
        return cls(transforms=transforms, prob=prob)