# Source code for vissl.data.ssl_transforms.img_pil_to_tensor

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import numpy as np
import torch
from classy_vision.dataset.transforms import register_transform
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
from PIL import Image


try:
    import accimage
except ImportError:
    accimage = None


# The helpers below mirror those in torchvision/transforms/functional.py (pytorch/vision).
def _is_numpy(img):
    """Tell whether ``img`` is a NumPy array (instances of ndarray subclasses count)."""
    array_type = np.ndarray
    return isinstance(img, array_type)


def _is_pil_image(img):
    """Tell whether ``img`` is a PIL image.

    When the optional ``accimage`` backend imported successfully, its image
    type is accepted as well.
    """
    image_types = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, image_types)


@register_transform("ImgToTensor")
class ImgToTensor(ClassyTransform):
    """
    The Transform that overrides the PyTorch transform to provide better
    transformation speed.

    Converts a PIL image or NumPy array laid out as (H, W, C) into a float32
    ``torch.Tensor`` laid out as (C, H, W), with every value divided by 255.
    A 2-D (H, W) input, such as a single-channel PIL image, is given a
    channel axis of size 1 and becomes a (1, H, W) tensor.

    # credits: mannatsingh@fb.com
    """

    def __call__(self, img: Image.Image):
        """
        Args:
            img: PIL image or NumPy array, shaped (H, W, C) or (H, W).

        Returns:
            torch.Tensor: contiguous float32 tensor of shape (C, H, W), or
            (1, H, W) for 2-D input, holding the input values divided by 255.

        Raises:
            TypeError: if ``img`` is neither a NumPy array nor a PIL image.
        """
        # An explicit raise, unlike assert, still runs under ``python -O``.
        if not (_is_numpy(img) or _is_pil_image(img)):
            raise TypeError(
                "img should be a PIL Image or numpy.ndarray. Got {}".format(type(img))
            )
        arr = np.asarray(img)
        if arr.ndim == 2:
            # Single-channel images come back as (H, W). Without a channel
            # axis, moveaxis(-1, 0) would silently transpose them to (W, H).
            arr = arr[:, :, None]
        # HWC -> CHW is a view. astype copies (its default copy=True) and
        # order="C" lays that one copy out contiguously. Because it is a fresh
        # copy, the in-place division never touches the caller's array and
        # avoids allocating a second full-size buffer.
        arr = np.moveaxis(arr, -1, 0).astype(np.float32, order="C")
        arr /= 255
        return torch.from_numpy(arr)