Resume training from iteration: Stateful data sampler
Issue with PyTorch DataSampler for large data training
PyTorch's torch.utils.data.distributed.DistributedSampler is the default sampler used for many trainings. However, this sampler becomes limiting for large data trainings, for two reasons:
With the PyTorch DistributedSampler, each trainer shuffles the full data (assuming shuffling is used) and then gets a view of this shuffled data. If the dataset is large (100 million, 1 billion or more samples), generating a very large permutation on each trainer can lead to large CPU memory consumption per machine. Hence, it becomes difficult to use the PyTorch default DistributedSampler when the user wants to train on large data for several epochs (for example, 10 epochs of 100M images).
When the PyTorch DistributedSampler is used and the training is resumed, the sampler serves the full dataset again. However, for large data trainings (1 billion images or more), one mostly trains for 1 epoch only. In such cases, when the training resumes from the middle of the epoch, the sampler serves the full 1 billion images rather than only the images not yet seen, which is not what we want.
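The first issue can be illustrated with a small pure-Python sketch. This is a toy model of a "shuffle everything, then slice per rank" sampler, not PyTorch's actual code; the function name `default_style_indices` and its parameters are hypothetical and chosen only for illustration.

```python
import random


def default_style_indices(dataset_size, num_replicas, rank, epoch, seed=0):
    """Toy model of a full-shuffle distributed sampler (illustration only)."""
    rng = random.Random(seed + epoch)
    # Every rank materializes a permutation of the WHOLE dataset ...
    full_permutation = list(range(dataset_size))
    rng.shuffle(full_permutation)
    # ... and then keeps only its own strided slice of it.
    return full_permutation, full_permutation[rank::num_replicas]


full, mine = default_style_indices(dataset_size=12, num_replicas=4, rank=1, epoch=0)
print(len(full), len(mine))  # 12 3
```

Each rank holds a permutation as long as the whole dataset even though it only consumes a fraction of it. The sketch also keeps no record of how many samples were already served, so calling it again after a restart yields the full per-rank slice.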
To solve both issues, VISSL provides a custom sampler, StatefulDistributedSampler, which inherits from the PyTorch DistributedSampler and addresses them in the following manner:
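The sampler first creates the view of the data for each trainer and then shuffles only the data that the trainer is supposed to view. This keeps the CPU memory requirement per trainer bounded by the size of that trainer's view rather than the size of the full dataset.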
The sampler adds a member, start_iter, which tracks which iteration of the given epoch the model is at. When the training is resumed, start_iter is set to the last iteration number and the sampler serves only the remainder of the data.
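The two ideas above can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the VISSL implementation: the class `ToyStatefulSampler`, its constructor arguments, and the `set_start_iter` method are hypothetical names, and it assumes each rank owns a contiguous slice of the dataset and that any tail samples that do not divide evenly are dropped.

```python
import random


class ToyStatefulSampler:
    """Illustrative only: per-rank shuffling plus a resume offset."""

    def __init__(self, dataset_size, num_replicas, rank, batch_size, seed=0):
        self.rank = rank
        self.batch_size = batch_size
        self.seed = seed
        self.epoch = 0
        self.start_iter = 0
        # Assumption: drop the tail so every rank sees the same number of samples.
        self.num_samples = dataset_size // num_replicas

    def set_epoch(self, epoch):
        self.epoch = epoch

    def set_start_iter(self, start_iter):
        self.start_iter = start_iter

    def __iter__(self):
        # Contiguous view owned by this rank; only this view is shuffled.
        first = self.rank * self.num_samples
        view = list(range(first, first + self.num_samples))
        random.Random(self.seed + self.epoch).shuffle(view)
        # Skip the samples already consumed before the checkpoint.
        return iter(view[self.start_iter * self.batch_size:])


sampler = ToyStatefulSampler(dataset_size=1000, num_replicas=4, rank=2, batch_size=10)
full_epoch = list(sampler)
sampler.set_start_iter(5)  # pretend we resume after 5 iterations
remainder = list(sampler)
print(len(full_epoch), len(remainder))  # 250 200
```

Because the shuffle is seeded by the epoch, the resumed iterator reproduces the same order and simply skips the first `start_iter * batch_size` samples, so `remainder` equals `full_epoch[50:]` here.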
How to use VISSL custom DataSampler
Using the custom sampler provided by VISSL, StatefulDistributedSampler, only requires setting the following configuration options:
    DATA:
      TRAIN:
        USE_STATEFUL_DISTRIBUTED_SAMPLER: True
      TEST:
        USE_STATEFUL_DISTRIBUTED_SAMPLER: True
Users can use StatefulDistributedSampler for the training dataset only and keep the PyTorch default DistributedSampler for the other splits if desired, i.e. it is not mandatory to use the same sampler type for all data splits.
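As an illustration, a configuration that enables the stateful sampler for the training split only might look like the fragment below. It assumes that setting the option to False for a split leaves that split on the PyTorch default sampler.

```yaml
DATA:
  TRAIN:
    USE_STATEFUL_DISTRIBUTED_SAMPLER: True
  TEST:
    USE_STATEFUL_DISTRIBUTED_SAMPLER: False
```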