# Source code for classy_vision.hooks.classy_hook

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any, Dict

from classy_vision.generic.util import log_class_usage


class ClassyHookState:
    """Class to store state within instances of ClassyHook.

    Any serializable data can be stored in the instance's attributes; those
    attributes make up the state dictionary exchanged through
    :meth:`get_classy_state` and :meth:`set_classy_state`.
    """

    def get_classy_state(self) -> Dict[str, Any]:
        """Return a snapshot of the stored state.

        Returns:
            A new dictionary mapping attribute names to values. It is a shallow
            copy of the instance ``__dict__``: adding, removing or reassigning
            keys in the result does not affect this object, and attributes set
            on this object later do not appear in a previously returned
            snapshot. Mutable values (e.g. lists) are still shared with the
            live state.
        """
        # Copy instead of returning ``self.__dict__`` itself. Handing out the
        # live attribute dict aliases the caller's result (e.g. a checkpoint
        # being assembled) to the hook's internal state.
        return dict(self.__dict__)

    def set_classy_state(self, state_dict: Dict[str, Any]) -> None:
        """Load state from ``state_dict`` into this object's attributes.

        Args:
            state_dict: Mapping of attribute names to values, typically the
                output of :meth:`get_classy_state`.
        """
        # We take a conservative approach and only update the dictionary instead of
        # replacing it. This allows hooks to continue functioning in case the state
        # is loaded from older implementations.
        self.__dict__.update(state_dict)


class ClassyHook(ABC):
    """Abstract base class for training-loop hooks.

    A hook injects custom behavior at fixed points of the training loop. The
    callbacks fire in this chronological order (the phase callbacks run once
    per phase, and ``on_step`` runs after every optimizer update)::

        on_start -> on_phase_start ->
            on_step -> on_phase_end -> on_end

    Subclasses must call ``super().__init__()`` and keep any state they need
    on ``self.state``. Everything stored there has to be serializable.
    For example:

    .. code-block:: python

        class MyHook(ClassyHook):
            def __init__(self, a, b):
                super().__init__()
                self.state.a = [1,2,3]
                self.state.b = "my_hook"
                # the following line is not allowed
                # self.state.my_lambda = lambda x: x^2

    """

    def __init__(self):
        # Record that this hook class is in use, then attach a fresh,
        # empty state container.
        log_class_usage("Hooks", self.__class__)
        self.state = ClassyHookState()

    @classmethod
    def from_config(cls, config) -> "ClassyHook":
        """Build a hook, passing every entry of ``config`` as a keyword argument."""
        hook = cls(**config)
        return hook

    def _noop(self, *args, **kwargs) -> None:
        """Accept any arguments and do nothing.

        Subclasses can assign this to a hook function they want to leave
        inactive.
        """

    @classmethod
    def name(cls) -> str:
        """Return the name of this hook class."""
        return cls.__name__

    @abstractmethod
    def on_start(self, task) -> None:
        """Called once, at the start of training."""

    @abstractmethod
    def on_phase_start(self, task) -> None:
        """Called at the start of every phase."""

    @abstractmethod
    def on_step(self, task) -> None:
        """Called every time the optimizer has finished updating the parameters."""

    @abstractmethod
    def on_phase_end(self, task) -> None:
        """Called at the end of every phase (epoch)."""

    @abstractmethod
    def on_end(self, task) -> None:
        """Called once, at the end of training."""

    def get_classy_state(self) -> Dict[str, Any]:
        """Return the hook's state for checkpointing.

        Returns:
            The state dictionary produced by ``self.state``.
        """
        hook_state = self.state
        return hook_state.get_classy_state()

    def set_classy_state(self, state_dict: Dict[str, Any]) -> None:
        """Restore the hook's state, e.g. when loading a checkpoint.

        Args:
            state_dict: The state dictionary. Must be the output of a call to
                :func:`get_classy_state`.
        """
        self.state.set_classy_state(state_dict)