Source code for vissl.hooks.deepclusterv2_hooks

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from classy_vision.hooks.classy_hook import ClassyHook


[docs]class InitMemoryHook(ClassyHook): """ Initialize the memory banks. Valid only for DeepClusterV2 training """ on_phase_start = ClassyHook._noop on_forward = ClassyHook._noop on_loss_and_meter = ClassyHook._noop on_backward = ClassyHook._noop on_update = ClassyHook._noop on_step = ClassyHook._noop on_phase_end = ClassyHook._noop on_end = ClassyHook._noop
[docs] def on_start(self, task) -> None: """ At the begining of the training, initialize the memory banks """ if not task.config["LOSS"]["name"] == "deepclusterv2_loss": return if task.train_phase_idx >= 0: return task.loss.init_memory(task.dataloaders["train"], task.model)
[docs]class ClusterMemoryHook(ClassyHook): """ Cluster the memory banks with distributed k-means. Valid only for DeepClusterV2 trainings. """ on_start = ClassyHook._noop on_forward = ClassyHook._noop on_loss_and_meter = ClassyHook._noop on_backward = ClassyHook._noop on_update = ClassyHook._noop on_step = ClassyHook._noop on_phase_end = ClassyHook._noop on_end = ClassyHook._noop
[docs] def on_phase_start(self, task) -> None: """ At the beginning of each epochs, cluster the memory banks with distributed k-means """ if not task.config["LOSS"]["name"] == "deepclusterv2_loss": return task.loss.cluster_memory()