Source code for vissl.data.ssl_dataset

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging

import numpy as np
from classy_vision.generic.distributed_util import get_world_size
from fvcore.common.file_io import PathManager
from torch.utils.data import Dataset
from vissl.data import dataset_catalog
from vissl.data.ssl_transforms import get_transform
from vissl.utils.env import get_machine_local_and_dist_rank


def _convert_lbl_to_long(lbl):
    """
    if the labels are int32, we convert them to int64 since pytorch
    needs a long (int64) type for labels to index. See
    https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-long-but-got-scalar-type-float-when-using-crossentropyloss/30542/5  # NOQA
    """
    out_lbl = lbl
    if isinstance(lbl, np.ndarray) and (lbl.dtype == np.int32):
        out_lbl = lbl.astype(np.int64)
    elif isinstance(lbl, list):
        out_lbl = [_convert_lbl_to_long(item) for item in lbl]
    elif isinstance(lbl, np.int32):
        out_lbl = out_lbl.astype(np.int64)
    return out_lbl
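# Illustrative usage of the helper above (not part of the original module):
# an int32 label array is widened to int64 so it can be used as indices by
# losses such as CrossEntropyLoss; int64 inputs pass through unchanged.
#
#   lbls = np.array([0, 1, 2], dtype=np.int32)
#   lbls = _convert_lbl_to_long(lbls)
#   assert lbls.dtype == np.int64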


class GenericSSLDataset(Dataset):
    """
    Base Self Supervised Learning Dataset Class.

    The GenericSSLDataset class is defined to support reading data from
    multiple data sources. For example: data = [dataset1, dataset2] and the
    minibatches generated will have the corresponding data from each dataset.
    For this reason, we also support labels from multiple sources. For
    example targets = [dataset1 targets, dataset2 targets].

    In order to support multiple data sources, the dataset configuration
    always has list inputs:

    - DATA_SOURCES, LABEL_SOURCES, DATASET_NAMES, DATA_PATHS, LABEL_PATHS

    For several data sources, we also support specifying on what dataset the
    transforms should be applied. By default, the transforms are applied on
    data from all datasets.

    Args:
        cfg (AttrDict): configuration defined by user
        split (str): the dataset split for which we are constructing the
                     Dataset object
        dataset_source_map (Dict[str, Callable]): The dictionary that maps
            what data sources are supported and what object to use to read
            data from those sources. For example:
            DATASET_SOURCE_MAP = {
                "disk_filelist": DiskImageDataset,
                "disk_folder": DiskImageDataset,
                "synthetic": SyntheticImageDataset,
            }
    """

    def __init__(self, cfg, split, dataset_source_map):
        self.split = split
        self.cfg = cfg
        self.data_objs = []
        self.label_objs = []
        self.data_paths = []
        self.label_paths = []
        self.batchsize_per_replica = self.cfg["DATA"][split]["BATCHSIZE_PER_REPLICA"]
        self.data_sources = self.cfg["DATA"][split].DATA_SOURCES
        self.label_sources = self.cfg["DATA"][split].LABEL_SOURCES
        self.dataset_names = self.cfg["DATA"][split].DATASET_NAMES
        self.label_type = self.cfg["DATA"][split].LABEL_TYPE
        self.transform = get_transform(self.cfg["DATA"][split].TRANSFORMS)
        self._labels_init = False
        self._verify_data_sources(split, dataset_source_map)
        self._get_data_files(split)

        if len(self.label_sources) > 0 and len(self.label_paths) > 0:
            assert len(self.label_sources) == len(self.label_paths), (
                f"len(label_sources) != len(label paths) "
                f"{len(self.label_sources)} vs. {len(self.label_paths)}"
            )

        for idx in range(len(self.data_sources)):
            datasource_cls = dataset_source_map[self.data_sources[idx]]
            self.data_objs.append(
                datasource_cls(
                    cfg=self.cfg,
                    path=self.data_paths[idx],
                    split=split,
                    dataset_name=self.dataset_names[idx],
                    data_source=self.data_sources[idx],
                )
            )

    def _verify_data_sources(self, split, dataset_source_map):
        """
        For each data source, verify that the specified data source is
        supported in VISSL. See DATASET_SOURCE_MAP for what sources are
        supported.
        """
        for idx in range(len(self.data_sources)):
            assert self.data_sources[idx] in dataset_source_map, (
                f"Unknown data source: {self.data_sources[idx]}, supported: "
                f"{list(dataset_source_map.keys())}"
            )

    def _get_data_files(self, split):
        """
        For the given dataset split (train or test), get the paths to the
        dataset data and label files.

        1. If the user has explicitly specified the data_sources, we simply
           use those and don't do a lookup in the datasets registered with
           VISSL in the dataset catalog.
        2. If the user hasn't specified the path, look for the dataset in the
           datasets catalog registered with VISSL. For a given list of
           datasets and a given partition (train/test), we first verify that
           we have the dataset and the correct source as specified by the
           user. Then for each dataset in the list, we get the data path
           (make sure it exists, sources match). The label file is optional.
        """
        local_rank, _ = get_machine_local_and_dist_rank()
        self.data_paths, self.label_paths = dataset_catalog.get_data_files(
            split, dataset_config=self.cfg["DATA"]
        )
        logging.info(
            f"Rank: {local_rank} split: {split} Data files:\n{self.data_paths}"
        )
        logging.info(
            f"Rank: {local_rank} split: {split} Label files:\n{self.label_paths}"
        )
    def load_single_label_file(self, path):
        """
        Load a single label file. We only support numpy label files, which
        the user can specify when using the disk_filelist label source.

        To save memory, if MMAP_MODE is set to True for loading, we try to
        load the labels in mmap_mode. If that fails, we simply load the
        labels without mmap.
        """
        assert PathManager.isfile(path), f"Path to labels {path} is not a file"
        assert path.endswith("npy"), "Please specify a numpy file for labels"
        if self.cfg["DATA"][self.split].MMAP_MODE:
            try:
                with PathManager.open(path, "rb") as fopen:
                    labels = np.load(fopen, allow_pickle=True, mmap_mode="r")
            except ValueError as e:
                logging.info(
                    f"Could not mmap {path}: {e}. Trying without PathManager"
                )
                labels = np.load(path, allow_pickle=True, mmap_mode="r")
                logging.info("Successfully loaded without PathManager")
            except Exception:
                logging.info("Could not mmap without PathManager. Trying without mmap")
                with PathManager.open(path, "rb") as fopen:
                    labels = np.load(fopen, allow_pickle=True)
        else:
            with PathManager.open(path, "rb") as fopen:
                labels = np.load(fopen, allow_pickle=True)
        return labels
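    # Illustrative sketch (not part of the original module): a label file
    # compatible with load_single_label_file() is simply a .npy array, ideally
    # saved as int64 so no conversion is needed later. The path below is
    # hypothetical.
    #
    #   labels = np.arange(1000, dtype=np.int64)
    #   np.save("/path/to/train_labels.npy", labels)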
    def _load_labels(self):
        """
        Load the labels if the dataset has labels. In the self-supervised
        pre-training task, we don't use labels. However, we use labels for
        the evaluations of the self-supervised models on downstream tasks.

        Two label sources are supported: disk_filelist and disk_folder.

        In the case of disk_filelist, we iteratively read labels for each
        specified file. See load_single_label_file().

        In the case of disk_folder, we use the ImageFolder object created
        during the data loading itself.
        """
        local_rank, _ = get_machine_local_and_dist_rank()
        for idx, label_source in enumerate(self.label_sources):
            if label_source == "disk_filelist":
                paths = self.label_paths[idx]
                # in case of filelist, we support multiple label files.
                # we rely on the user to have a proper collator to handle
                # the multiple labels
                logging.info(f"Loading labels: {paths}")
                if isinstance(paths, list):
                    labels = []
                    for path in paths:
                        path_labels = self.load_single_label_file(path)
                        labels.append(path_labels)
                else:
                    labels = self.load_single_label_file(paths)
            elif label_source == "disk_folder":
                # In this case we use the labels inferred from the directory
                # structure. We enforce that the data source also be a disk
                # folder in this case.
                assert self.data_sources[idx] == self.label_sources[idx]
                if local_rank == 0:
                    logging.info(
                        f"Using {label_source} labels from {self.data_paths[idx]}"
                    )
                # Use the ImageFolder object created when loading images.
                # We do not create it again since it can be an expensive operation.
                labels = [x[1] for x in self.data_objs[idx].image_dataset.samples]
                labels = np.array(labels).astype(np.int64)
            else:
                raise ValueError(f"unknown label source: {label_source}")
            self.label_objs.append(labels)
    def __getitem__(self, idx):
        """
        Get the input sample for the minibatch for a specified data index.
        For each data object (if we are loading several datasets in a
        minibatch), we get the sample consisting of:
            {
              - image data,
              - label (if applicable) otherwise idx
              - data_valid: 1 or -1 indicating whether the data is a valid image
              - data_idx: index of the data in the dataset for book-keeping
                and debugging
            }

        Once the sample data is available, we apply the data transform on
        the sample. The final transformed sample is returned to be added
        into the minibatch.
        """
        if not self._labels_init and len(self.label_sources) > 0:
            self._load_labels()
            self._labels_init = True

        # TODO: this doesn't yet handle the case where the length of datasets
        # could be different.
        item = {"data": [], "data_valid": [], "data_idx": []}
        for source in self.data_objs:
            data, valid = source[idx]
            item["data"].append(data)
            item["data_idx"].append(idx)
            item["data_valid"].append(1 if valid else -1)

        if (len(self.label_objs) > 0) or self.label_type == "standard":
            item["label"] = []
            for source in self.label_objs:
                if isinstance(source, list):
                    lbl = [entry[idx] for entry in source]
                else:
                    lbl = _convert_lbl_to_long(source[idx])
                item["label"].append(lbl)
        elif self.label_type == "sample_index":
            item["label"] = []
            for _ in range(len(self.data_objs)):
                item["label"].append(idx)
        else:
            raise ValueError(f"Unknown label type: {self.label_type}")

        # apply the transforms on the image
        if self.transform:
            item = self.transform(item)
        return item
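    # Illustrative sketch (not part of the original module) of the dict
    # returned by __getitem__ for a single data source with standard labels,
    # before the collator batches samples together. The exact contents of
    # "data" depend on the configured TRANSFORMS.
    #
    #   {
    #       "data": [<transformed image or list of crops>],
    #       "data_valid": [1],          # -1 if the image could not be loaded
    #       "data_idx": [idx],
    #       "label": [<int64 label>],   # idx itself when LABEL_TYPE == "sample_index"
    #   }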
    def __len__(self):
        """
        Size of the dataset. Assumes there is only one data source.
        """
        return len(self.data_objs[0])
    def get_image_paths(self):
        """
        Get the image paths for all the data sources.

        Returns:
            image_paths (List[List[str]]): list containing the image paths
                list for each data source.
        """
        image_paths = []
        for source in self.data_objs:
            image_paths.append(source.get_image_paths())
        return image_paths
    def get_available_splits(self, dataset_config):
        """
        Get the available splits in the dataset config. Not specific to the
        split for which the SSLDataset is being constructed.

        NOTE: this is a deprecated method.
        """
        return [key for key in dataset_config if key.lower() in ["train", "test"]]
    def num_samples(self, source_idx=0):
        """
        Size of the dataset for the specified data source (default: the
        first source).
        """
        return len(self.data_objs[source_idx])
    def get_batchsize_per_replica(self):
        """
        Get the batch size per trainer (replica).
        """
        # batchsize_per_replica is set in __init__; default to 1 if it is absent
        return getattr(self, "batchsize_per_replica", 1)
    def get_global_batchsize(self):
        """
        The global batch size across all the trainers.
        """
        return self.get_batchsize_per_replica() * get_world_size()
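    # Illustrative example (not part of the original module): with
    # BATCHSIZE_PER_REPLICA = 32 and 8 trainers (world size 8), the global
    # batch size is 32 * 8 = 256.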