# Source code for vissl.data.ssl_transforms.img_pil_to_lab_tensor

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Any, Dict

import numpy as np
import torch
from classy_vision.dataset.transforms import register_transform
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
from vissl.utils.misc import is_opencv_available


@register_transform("ImgPil2LabTensor")
class ImgPil2LabTensor(ClassyTransform):
    """
    Convert a PIL image to a LAB float tensor of shape C x H x W.

    This transform was proposed in Colorization - https://arxiv.org/abs/1603.08511

    The input is a PIL Image. Images whose PIL mode is not "RGB" (for example
    "L", "RGBA", "P") are first converted to RGB with PIL so that the channel
    layout assumed below actually holds. The image is then turned into an HWC
    numpy array (RGB channel order), flipped to BGR, and converted to LAB with
    OpenCV.

    OpenCV's 8-bit LAB image has ranges L [0, 255], A [0, 255], B [0, 255].
    These are rescaled to:
        L [-50, 50]    (L * 100 / 255, then centered by subtracting 50)
        A [-128, 127]
        B [-128, 127]

    The output is a float torch tensor of shape 3 x H x W with channel order
    L, A, B.
    """

    def __init__(self, indices):
        # NOTE(review): `indices` is stored but never read by this class;
        # presumably kept for config compatibility with other transforms.
        self.indices = indices

    def __call__(self, image):
        """
        Args:
            image: PIL Image (any mode convertible to RGB), or an HWC RGB
                array-like without a `mode` attribute, which is used as-is.

        Returns:
            torch.FloatTensor of shape 3 x H x W holding L, A, B channels.
        """
        image = self._ensure_rgb(image)
        img_rgb = np.array(image)
        # PIL image tensor is RGB. OpenCV expects BGR, so flip the channels.
        img_bgr = img_rgb[:, :, ::-1]
        img_lab = self._convertbgr2lab(img_bgr.astype(np.uint8))
        # convert HWC -> CHW. The image is LAB.
        img_lab = np.transpose(img_lab, (2, 0, 1))
        return torch.from_numpy(img_lab).float()

    @staticmethod
    def _ensure_rgb(image):
        """Convert a non-RGB PIL image to RGB; return anything else unchanged."""
        mode = getattr(image, "mode", None)
        if mode is not None and mode != "RGB":
            # Without this, grayscale/palette images yield a 2-D array and
            # RGBA/CMYK yield 4 channels, which the BGR->LAB path cannot use.
            return image.convert("RGB")
        return image

    def _convertbgr2lab(self, img):
        """
        Convert an 8-bit BGR HWC image to a float32 LAB image.

        Args:
            img (np.ndarray): H x W x 3, uint8, BGR channel order, range [0, 255].

        Returns:
            np.ndarray: H x W x 3 float32 with L in [-50, 50] and A, B in
            [-128, 127].
        """
        # opencv is not a hard dependency for VISSL so we do the import locally
        assert (
            is_opencv_available()
        ), "Please install OpenCV using: pip install opencv-python"
        import cv2

        # img is [0, 255] , HWC, BGR format, uint8 type
        assert len(img.shape) == 3, "Image should have dim H x W x 3"
        assert img.shape[2] == 3, "Image should have dim H x W x 3"
        assert img.dtype == np.uint8, "Image should be uint8 type"
        img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

        # 8-bit image range -> L [0, 255], A [0, 255], B [0, 255].
        # Rescale it to: L [-50, 50] (centered), A [-128, 127], B [-128, 127]
        img_lab = img_lab.astype(np.float32)
        img_lab[:, :, 0] = (img_lab[:, :, 0] * (100.0 / 255.0)) - 50.0
        img_lab[:, :, 1:] = img_lab[:, :, 1:] - 128.0
        return img_lab

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ImgPil2LabTensor":
        """
        Instantiates ImgPil2LabTensor from configuration.

        Args:
            config (Dict): arguments for the transform

        Returns:
            ImgPil2LabTensor instance.
        """
        indices = config.get("indices", [])
        return cls(indices=indices)