# Source code for vissl.data.dataloader_sync_gpu_wrapper

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Iterable, Iterator, Optional

import torch
from classy_vision.dataset.dataloader_wrapper import DataloaderWrapper
from classy_vision.generic.util import recursive_copy_to_gpu


class DataloaderSyncGPUWrapper(DataloaderWrapper):
    """
    Dataloader which wraps another dataloader and moves each batch to GPU
    synchronously: the host-to-device copy happens inside ``__next__``
    (issued with ``non_blocking=True``), and no batch is prefetched ahead
    of the model iteration.

    NOTE(review): the original docstring claimed the copy is done "in async
    manner so as to overlap" with the previous model iteration, but this
    class performs no prefetching — ``__iter__`` only builds the iterator
    and ``__next__`` copies inline. That description belongs to the async
    wrapper variant; corrected here to match the code.
    """

    def __init__(self, dataloader: Iterable) -> None:
        """
        Args:
            dataloader: the dataloader to wrap; its batches are moved to
                GPU as they are consumed.

        Raises:
            AssertionError: if CUDA is not available.
        """
        assert torch.cuda.is_available(), "This Dataloader wrapper needs a CUDA setup"
        super().__init__(dataloader)
        # Current iterator over the wrapped dataloader; rebuilt on every
        # call to __iter__ (None until iteration starts).
        self._iter: Optional[Iterator[Any]] = None

    def __iter__(self) -> Iterator[Any]:
        # The wrapped dataloader may have been changed in place, so
        # rebuild a fresh iterator each time iteration starts.
        self._iter = iter(self.dataloader)
        return self

    def __next__(self) -> Any:
        # Pull the next batch and copy it to GPU before returning it.
        # `next` raises StopIteration when the underlying iterator is
        # exhausted; we deliberately let it propagate.
        return recursive_copy_to_gpu(next(self._iter), non_blocking=True)

    def __len__(self) -> int:
        # Delegate length to the wrapped dataloader.
        return len(self.dataloader)