Resume training from iteration: Stateful data sampler

Issues with the PyTorch DataSampler for large data training

The PyTorch torch.utils.data.distributed.DistributedSampler is the default sampler used for many trainings. However, it becomes limiting to use this sampler for large data trainings for two reasons:

  • Using the PyTorch DataSampler, each trainer shuffles the full data (assuming shuffling is used) and then each trainer gets a view of this shuffled data. If the dataset is large (100 million, 1 billion images or more), generating the very large permutation on each trainer can lead to large CPU memory consumption per machine (see the sketch after this list). Hence, it becomes difficult to use the PyTorch default DataSampler when a user wants to train on large data for several epochs (for example: 10 epochs of 100M images).

  • When using the PyTorch DataSampler and the training is resumed, the sampler will serve the full dataset again. However, in case of large data trainings (like 1 billion images or more), one mostly trains for 1 epoch only. In such cases, when the training resumes from the middle of the epoch, the sampler will serve the full 1 billion images, which is not what we want.
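To make the first issue concrete: the stock DistributedSampler builds a full-length index permutation on every trainer before slicing out that trainer's portion (roughly torch.randperm(len(dataset))). The snippet below only illustrates the resulting memory cost; the dataset size is made up.

import torch

# The default DistributedSampler does roughly the following on every trainer:
#   indices = torch.randperm(len(dataset), generator=g).tolist()
# For a billion-sample dataset, the int64 permutation tensor alone is large,
# and converting it to a Python list multiplies the cost further.
num_samples = 1_000_000_000  # hypothetical dataset size
bytes_per_index = torch.empty(0, dtype=torch.int64).element_size()  # 8 bytes
perm_gib = num_samples * bytes_per_index / 1024 ** 3
print(f"Permutation tensor alone: ~{perm_gib:.1f} GiB per trainer")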

To solve both of the above issues, VISSL provides a custom sampler StatefulDistributedSampler which inherits from the PyTorch DistributedSampler and fixes the above issues in the following manner:

  • The sampler creates a per-trainer view of the data and then shuffles only the data that trainer is supposed to see. This keeps the per-trainer CPU memory requirement in check.

  • The sampler adds a member start_iter which tracks the iteration number the model is at within the given epoch. When the training is resumed, start_iter is set to the last iteration number and the sampler serves only the remainder of the data. Both behaviors are sketched below.
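The following is a minimal, self-contained sketch of these two ideas, not VISSL's actual implementation; the class name, the batch_size argument, and the set_start_iter helper are illustrative, and the real StatefulDistributedSampler in VISSL should be consulted for the exact API.

import torch
from torch.utils.data.distributed import DistributedSampler


class SketchStatefulSampler(DistributedSampler):
    """Illustrative sketch only, not the VISSL implementation."""

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size  # needed to convert iterations to samples
        self.start_iter = 0

    def set_start_iter(self, start_iter):
        # Called on resume: remember how many iterations of this epoch are done.
        self.start_iter = start_iter

    def __iter__(self):
        # Permute only this trainer's contiguous shard, so the permutation has
        # roughly len(dataset) / world_size entries instead of len(dataset).
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        shard_start = self.rank * self.num_samples
        indices = shard_start + torch.randperm(self.num_samples, generator=g)
        indices = indices % len(self.dataset)  # wrap around any padding
        # Skip the samples already consumed before the checkpoint was saved.
        skip = self.start_iter * self.batch_size
        return iter(indices[skip:].tolist())

Note that this sketch assigns each trainer a fixed contiguous shard, which differs from the strided sharding of the stock DistributedSampler, but it keeps the per-trainer permutation small and makes skipping consumed samples straightforward.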

How to use VISSL custom DataSampler

Using the VISSL-provided custom sampler StatefulDistributedSampler is extremely easy and involves simply setting the correct configuration options as below:

DATA:
  TRAIN:
    USE_STATEFUL_DISTRIBUTED_SAMPLER: True
  TEST:
    USE_STATEFUL_DISTRIBUTED_SAMPLER: True
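
Outside of the YAML configuration, the resumption behavior can be illustrated with a plain DataLoader. The snippet below reuses the SketchStatefulSampler from the sketch above; the epoch and iteration values are assumed to come from the training checkpoint and are made up here.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for a large image dataset.
dataset = TensorDataset(torch.arange(1_000_000))
sampler = SketchStatefulSampler(dataset, batch_size=32, num_replicas=8, rank=0)

# On resume, restore the epoch and the iteration reached within that epoch
# (both assumed to be stored in the checkpoint alongside the model weights).
sampler.set_epoch(3)
sampler.set_start_iter(117)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)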

Note

Users can use StatefulDistributedSampler for only the training dataset and use the PyTorch default DataSampler for other splits if desired, i.e. it is not mandatory to use the same sampler type for all data splits.