# Source code for vissl.data.data_helper

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import queue

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


def get_mean_image(crop_size):
    """
    Build a uniformly gray RGB PIL image (every channel value is 128).

    Args:
        crop_size (int): side length; the backing array has shape
            (crop_size x crop_size x 3).

    Returns:
        img: PIL Image
    """
    gray_pixels = np.full((crop_size, crop_size, 3), 128, dtype=np.uint8)
    return Image.fromarray(gray_pixels)
class StatefulDistributedSampler(DistributedSampler):
    """
    More fine-grained stateful sampler that uses both the training iteration
    and the epoch.

    PyTorch DistributedSampler only uses the epoch for shuffling and always
    starts sampling data from the start. When training on very large data we
    may train for a single epoch only, so on resume we want the sampler to
    continue from the training iteration reached so far. Use
    ``set_start_iter`` to provide that marker.
    """

    def __init__(self, dataset, batch_size=None):
        """
        Initializes the instance of StatefulDistributedSampler.

        Args:
            dataset (Dataset): Pytorch dataset that the sampler will shuffle.
            batch_size (int): batch size the sampler samples for. It must be
                set to a positive value before the sampler is iterated.
        """
        # Shuffling is done manually in __iter__, so disable the parent's.
        super().__init__(dataset, shuffle=False)
        self.start_iter = 0
        self.batch_size = batch_size
        # Drop the last len(dataset) % num_replicas samples so every replica
        # receives exactly the same number of samples.
        self.total_size = len(dataset) - (len(dataset) % self.num_replicas)
        self.num_samples = self.total_size // self.num_replicas
        logging.info(f"rank: {self.rank}: Sampler created...")

    def __iter__(self):
        """
        Return an iterator over this rank's dataset indices.

        Each rank owns the contiguous index range
        [rank * num_samples, (rank + 1) * num_samples). That range is shuffled
        with a generator seeded by the epoch, and the first
        ``start_iter * batch_size`` indices are skipped.

        Raises:
            AssertionError: if ``batch_size`` is not set to a positive value.
        """
        # Validate before doing any work. ``None > 0`` would raise an
        # unhelpful TypeError instead of this message.
        assert (
            self.batch_size is not None and self.batch_size > 0
        ), "batch_size not set for the sampler"
        # Seeding with the epoch makes the order identical every time the
        # same epoch is iterated, which is what makes skipping ahead valid.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        shuffling = torch.randperm(self.num_samples, generator=g).tolist()
        offset = self.rank * self.num_samples
        indices = [offset + i for i in shuffling]
        # make sure we have correct number of samples per replica
        assert len(indices) == self.num_samples
        # resume the sampler: skip samples consumed by earlier iterations
        return iter(indices[self.start_iter * self.batch_size :])

    def set_start_iter(self, start_iter):
        """
        Set the iteration number from which sampling should start.

        ``__iter__`` skips ``start_iter * batch_size`` indices of this rank's
        permutation. This lets a resumed run continue where it stopped.
        """
        self.start_iter = start_iter
class QueueDataset(Dataset):
    """
    Helps deal with invalid images in the dataset by using two queues.

    One queue (``enqueue_images_queue``) only receives new valid seen images.
    The other (``dequeue_images_queue``) is only read from when backfilling an
    image that failed to load. Images move from the first queue to the second
    in bulk once the first is full. The intent is that the same batch never
    gets duplicate images. If no valid image can be dequeued, ``on_failure``
    returns None for that instance.

    The queues are created lazily by ``_init_queues()``, which must be called
    before ``on_sucess`` / ``on_failure``.

    Args:
        queue_size: size of each queue (ideally set it to the batch size).
            Both queues have the same size.
    """

    # Images with more pixels than this are never buffered, to bound CPU memory.
    _LARGE_IMAGE_NUM_PIXELS = 10000000

    def __init__(self, queue_size):
        self.queue_size = queue_size
        # CPU-side FIFO buffers of valid seen images, used to replace invalid
        # images in the minibatch:
        # a) dequeue queue: only read from, to backfill failed images.
        # b) enqueue queue: only written to, with new incoming valid images.
        self.queue_init = False
        self.dequeue_images_queue = None
        self.enqueue_images_queue = None

    def _init_queues(self):
        """Create both bounded queues and mark the dataset as initialized."""
        self.dequeue_images_queue = queue.Queue(maxsize=self.queue_size)
        self.enqueue_images_queue = queue.Queue(maxsize=self.queue_size)
        self.queue_init = True
        logging.info(f"QueueDataset enabled. Using queue_size: {self.queue_size}")

    def _refill_dequeue_buffer(self):
        """
        Move buffered images from the enqueue queue into the dequeue queue.

        Moving stops when the dequeue queue has no free slot left or the
        enqueue queue runs empty. It never blocks.
        """
        num_free_slots = self.queue_size - self._get_dequeue_buffer_size()
        for _ in range(num_free_slots):
            # Non-blocking: a blocking get() without timeout would hang
            # forever here if the enqueue queue ran dry.
            try:
                img = self.enqueue_images_queue.get_nowait()
            except queue.Empty:
                break
            try:
                self.dequeue_images_queue.put_nowait(img)
            except queue.Full:
                # No room left. The queues are a best-effort backfill cache,
                # so dropping this one image is acceptable.
                break

    def _enqueue_valid_image(self, img):
        """Best-effort add of ``img`` to the enqueue queue; dropped if full."""
        if self._get_enqueue_buffer_size() >= self.queue_size:
            return
        try:
            self.enqueue_images_queue.put(img, block=True, timeout=0.1)
        except queue.Full:
            return

    def _dequeue_valid_image(self):
        """Return a buffered valid image, or None if none is available."""
        if self._get_dequeue_buffer_size() == 0:
            return None
        try:
            return self.dequeue_images_queue.get(block=True, timeout=0.1)
        except queue.Empty:
            return None

    def _get_enqueue_buffer_size(self):
        return self.enqueue_images_queue.qsize()

    def _get_dequeue_buffer_size(self):
        return self.dequeue_images_queue.qsize()

    def _is_large_image(self, sample):
        """Return True if ``sample`` has more pixels than the buffering limit."""
        # sample is presumably a PIL Image, whose .size is (width, height).
        # Only the product matters here.
        width, height = sample.size
        return width * height > self._LARGE_IMAGE_NUM_PIXELS

    def on_sucess(self, sample):
        """
        If we encounter a successful image and the queue is not full, store it
        in the queue.

        Very large images are not added, because otherwise CPU memory would
        grow a lot. Once the enqueue queue is full and the dequeue queue has
        room, the buffered images are moved into the dequeue queue.
        """
        if self._is_large_image(sample):
            return
        self._enqueue_valid_image(sample)
        if self.enqueue_images_queue.full() and not self.dequeue_images_queue.full():
            self._refill_dequeue_buffer()

    def on_failure(self):
        """
        If getting the original image failed, try to use a valid seen image
        from the queue in its place.

        Returns:
            (sample, is_success): a previously seen valid image and True, or
            (None, False) when no buffered image is available.
        """
        sample, is_success = None, False
        if self._get_dequeue_buffer_size() > 0:
            sample = self._dequeue_valid_image()
            if sample is not None:
                is_success = True
        return sample, is_success