Source code for vissl.optimizers.param_scheduler.cosine_warm_restart_scheduler

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import math
from enum import Enum
from typing import Any, Dict

import numpy as np
from classy_vision.optim.param_scheduler import (
    ClassyParamScheduler,
    UpdateInterval,
    register_param_scheduler,
)


class CosineWaveTypes(str, Enum):
    half = "half"
    full = "full"
@register_param_scheduler("cosine_warm_restart")
class CosineWarmRestartScheduler(ClassyParamScheduler):
    """
    Changes the param value after every epoch based on a `cosine schedule
    <https://arxiv.org/pdf/1608.03983.pdf>`_. The schedule is updated after every
    train step by default. Can be used for cosine learning rate with warm restarts.

    For restarts, we calculate what the maximum learning rate will be after every
    restart. There are 3 options:

    - Option 1: LR after every restart is the same as the original max LR
    - Option 2: LR after every restart decays with a fixed LR multiplier
    - Option 3: LR after every restart is adaptively calculated such that the
      resulting max LR matches the original cosine wave LR

    Args:
        wave_type: half | full
        lr_multiplier: float value -> LR after every restart decays with this
            fixed LR multiplier
        is_adaptive: True -> after every restart, the maximum LR is adaptively
            calculated such that the resulting max LR matches the original
            cosine wave LR
        update_interval: step | epoch -> whether the LR should be updated after
            every training iteration or after every training epoch

    Example:

        .. code-block:: python

            start_value: 0.1
            end_value: 0.0001
            restart_interval_length: 0.5  # for 1 restart
            wave_type: half
            lr_multiplier: 1.0  # for using a decayed max LR value at every restart
            use_adaptive_decay: False
    """

    def __init__(
        self,
        start_value: float,
        end_value: float,
        restart_interval_length: float,
        wave_type: str,
        lr_multiplier: float,
        is_adaptive: bool,
        update_interval: UpdateInterval = UpdateInterval.STEP,
    ):
        super().__init__(update_interval=update_interval)
        self.restart_interval_length = restart_interval_length
        self.wave_type = CosineWaveTypes(wave_type)
        self._start_value = start_value
        self._end_value = end_value
        self._max_lr = self._start_value
        self.lr_multiplier = lr_multiplier
        self.is_adaptive = is_adaptive
        self.restart_steps = []
        self.lr_restart_values = []
        self._init_restart_step_values()
        # we calculate what the maximum learning rate will be after every restart.
        # There are 3 options:
        #   Option 1: LR after every restart is the same as the original max LR
        #   Option 2: LR after every restart decays with a fixed LR multiplier
        #   Option 3: LR after every restart is adaptively calculated such that the
        #             resulting max LR matches the original cosine wave LR
        self._compute_lr_restart_values()

    def _init_restart_step_values(self):
        # calculate where the LR is restarted. We need this so we can find
        # what restart number we are at based on "where" we are in the training.
        self.restart_steps = [0.0]
        if self.wave_type == CosineWaveTypes.full:
            self.restart_steps.extend(
                np.arange(
                    self.restart_interval_length, 1.0, 2 * self.restart_interval_length
                ).tolist()
            )
        elif self.wave_type == CosineWaveTypes.half:
            self.restart_steps.extend(
                np.arange(
                    self.restart_interval_length, 1.0, self.restart_interval_length
                ).tolist()
            )

    def _compute_lr_restart_values(self):
        for idx in range(len(self.restart_steps)):
            if self.is_adaptive:
                if self.wave_type == CosineWaveTypes.half:
                    desired_max_lr = self._end_value + 0.5 * (
                        self._start_value - self._end_value
                    ) * (1 + math.cos(math.pi * self.restart_steps[idx]))
                elif self.wave_type == CosineWaveTypes.full:
                    desired_max_lr = self._end_value + 0.5 * (
                        self._start_value - self._end_value
                    ) * (
                        1
                        + math.cos(
                            math.pi
                            * (self.restart_interval_length + self.restart_steps[idx])
                        )
                    )
                lr_multiplier = desired_max_lr / self._start_value
            else:
                lr_multiplier = pow(self.lr_multiplier, idx)
            restart_lr = self._start_value * lr_multiplier
            self.lr_restart_values.append(restart_lr)
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CosineWarmRestartScheduler":
        """
        Instantiates a CosineWarmRestartScheduler from a configuration.

        Args:
            config: A configuration for a CosineWarmRestartScheduler.
                See :func:`__init__` for parameters expected in the config.

        Returns:
            A CosineWarmRestartScheduler instance.
        """
        assert (
            "start_value" in config and "end_value" in config
        ), "CosineWarmRestartScheduler requires a start_value and an end_value"
        assert (
            "restart_interval_length" in config and "wave_type" in config
        ), "CosineWarmRestartScheduler requires restart_interval_length and wave_type"
        return cls(
            start_value=config["start_value"],
            end_value=config["end_value"],
            restart_interval_length=config["restart_interval_length"],
            wave_type=config.get("wave_type", CosineWaveTypes.half),
            lr_multiplier=config.get("lr_multiplier", 1.0),
            is_adaptive=config.get("is_adaptive", False),
            update_interval=UpdateInterval.from_config(config, UpdateInterval.STEP),
        )
    def __call__(self, where: float):
        # find which restart we are in from the training progress `where`,
        # pick the max LR for that restart, then remap `where` to the position
        # within the current cosine cycle
        if self.wave_type == CosineWaveTypes.half:
            restart_num = max(bisect.bisect(self.restart_steps, where) - 1, 0)
            self._max_lr = self.lr_restart_values[restart_num]
            where = (where % float(self.restart_interval_length)) / float(
                self.restart_interval_length
            )
        elif self.wave_type == CosineWaveTypes.full:
            restart_num = max(bisect.bisect(self.restart_steps, where) - 1, 0)
            self._max_lr = self.lr_restart_values[restart_num]
            where = where / float(self.restart_interval_length)
        return self._end_value + 0.5 * (self._max_lr - self._end_value) * (
            1 + math.cos(math.pi * where)
        )
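

# A minimal usage sketch (not part of the VISSL module): it builds the scheduler
# via `from_config` using the values from the docstring example and queries the LR
# at a few training-progress fractions. The `is_adaptive` key matches what
# `from_config` reads; the sample `where` values are illustrative assumptions.
if __name__ == "__main__":
    scheduler = CosineWarmRestartScheduler.from_config(
        {
            "start_value": 0.1,
            "end_value": 0.0001,
            "restart_interval_length": 0.5,  # one restart halfway through training
            "wave_type": "half",
            "lr_multiplier": 1.0,  # keep the same max LR after the restart
            "is_adaptive": False,
        }
    )
    # `where` is the fraction of training completed, in [0, 1).
    # Expect the LR to decay from 0.1 toward 0.0001, jump back up at where=0.5,
    # and decay again.
    for where in [0.0, 0.25, 0.49, 0.5, 0.75, 0.99]:
        print(f"where={where:.2f} -> lr={scheduler(where):.6f}")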