Source code for vissl.data.collators

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from functools import partial
from pathlib import Path

from classy_vision.generic.registry_utils import import_all_modules
from torch.utils.data.dataloader import default_collate


FILE_ROOT = Path(__file__).parent


COLLATOR_REGISTRY = {}
COLLATOR_NAMES = set()


[docs]def register_collator(name): """ Registers Self-Supervision data collators. This decorator allows VISSL to add custom data collators, even if the collator itself is not part of VISSL. To use it, apply this decorator to a collator function, like this: .. code-block:: python @register_collator('my_collator_name') def my_collator_name(): ... To get a collator from a configuration file, see :func:`get_collator`. """ def register_collator_fn(func): if name in COLLATOR_REGISTRY: raise ValueError("Cannot register duplicate collator ({})".format(name)) if func.__name__ in COLLATOR_NAMES: raise ValueError( "Cannot register task with duplicate collator name ({})".format( func.__name__ ) ) COLLATOR_REGISTRY[name] = func COLLATOR_NAMES.add(func.__name__) return func return register_collator_fn
[docs]def get_collator(collator_name, collate_params): """ Given the collator name and the collator params, return the collator if registered with VISSL. Also supports pytorch default collators. """ if collator_name == "default_collate": return default_collate else: assert collator_name in COLLATOR_REGISTRY, "Unknown collator" # return COLLATOR_REGISTRY[collator_name] return partial(COLLATOR_REGISTRY[collator_name], **collate_params)
# automatically import any Python files in the collators/ directory import_all_modules(FILE_ROOT, "vissl.data.collators")