Source code for vissl.engines.train

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import logging
import os
from typing import Any, Callable, List

import torch
from classy_vision.hooks.classy_hook import ClassyHook
from vissl.hooks import default_hook_generator
from vissl.trainer import SelfSupervisionTrainer
from vissl.utils.collect_env import collect_env_info
from vissl.utils.env import (
    get_machine_local_and_dist_rank,
    print_system_env_info,
    set_env_vars,
)
from vissl.utils.hydra_config import AttrDict, print_cfg
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.misc import set_seeds, setup_multiprocessing_method


def train_main(
    cfg: AttrDict,
    dist_run_id: str,
    checkpoint_path: str,
    checkpoint_folder: str,
    local_rank: int = 0,
    node_id: int = 0,
    hook_generator: Callable[[Any], List[ClassyHook]] = default_hook_generator,
):
    """
    Sets up and executes training workflow per machine.

    Args:
        cfg (AttrDict): user specified input config that has optimizer, loss, meters etc
                        settings relevant to the training
        dist_run_id (str): For multi-gpu training with PyTorch, we have to specify
                           how the gpus are going to rendezvous. This requires specifying
                           the communication method: file, tcp and the unique rendezvous
                           run_id that is specific to 1 run.
                           We recommend:
                               1) for 1-node: use init_method=tcp and run_id=auto
                               2) for multi-node: use init_method=tcp and specify
                                  run_id={master_node}:{port}
        checkpoint_path (str): if the training is being resumed from a checkpoint, path to
                           the checkpoint. The tools/run_distributed_engines.py automatically
                           looks for the checkpoint in the checkpoint directory.
        checkpoint_folder (str): what directory to use for checkpointing. The
                           tools/run_distributed_engines.py creates the directory based on
                           user input in the yaml config file.
        local_rank (int): id of the current device on the machine. If using gpus,
                          local_rank = gpu number on the current machine
        node_id (int): id of the current machine. starts from 0. valid for multi-gpu
        hook_generator (Callable): The utility function that prepares all the hooks that
                           will be used in training based on user selection. Some basic
                           hooks are used by default.
    """
    # setup the environment variables
    set_env_vars(local_rank, node_id, cfg)
    dist_rank = int(os.environ["RANK"])

    # setup logging
    setup_logging(__name__, output_dir=checkpoint_folder, rank=dist_rank)

    logging.info(f"Env set for rank: {local_rank}, dist_rank: {dist_rank}")
    # print the environment info for the current node
    if local_rank == 0:
        current_env = os.environ.copy()
        print_system_env_info(current_env)

    # setup the multiprocessing to be forkserver.
    # See https://fb.quip.com/CphdAGUaM5Wf
    setup_multiprocessing_method(cfg.MULTI_PROCESSING_METHOD)

    # set seeds
    logging.info("Setting seed....")
    set_seeds(cfg, node_id)

    # We set the CUDA device here as well as a safe solution for all downstream
    # `torch.cuda.current_device()` calls to return correct device.
    if cfg.MACHINE.DEVICE == "gpu" and torch.cuda.is_available():
        local_rank, _ = get_machine_local_and_dist_rank()
        torch.cuda.set_device(local_rank)

    # print the training settings and system settings
    if local_rank == 0:
        print_cfg(cfg)
        logging.info("System config:\n{}".format(collect_env_info()))

    # get the hooks - these hooks are executed per replica
    hooks = hook_generator(cfg)

    # build the SSL trainer. The trainer first prepares a "task" object which
    # acts as a container for various things needed in a training: datasets,
    # dataloader, optimizers, losses, hooks, etc. "Task" will also have information
    # about both phases (train, test). The trainer then sets up distributed
    # training.
    trainer = SelfSupervisionTrainer(
        cfg, dist_run_id, checkpoint_path, checkpoint_folder, hooks
    )
    trainer.train()
    logging.info("All Done!")

    # close the logging streams including the filehandlers
    shutdown_logging()
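
A minimal usage sketch, not part of this module: it assumes `cfg` has already been built into an AttrDict from a VISSL YAML config (normally done by tools/run_distributed_engines.py) and that a tcp rendezvous on localhost is enough for a single-node, single-GPU run. The port, checkpoint folder, and empty checkpoint path below are placeholder values, not prescribed by VISSL.

    from vissl.engines.train import train_main

    # Hypothetical single-node launch; `cfg` is assumed to be an AttrDict
    # assembled elsewhere from the user's YAML config.
    train_main(
        cfg=cfg,
        dist_run_id="localhost:59910",           # tcp rendezvous for a 1-node run
        checkpoint_path="",                      # empty -> not resuming (assumption)
        checkpoint_folder="/tmp/vissl_checkpoints",
        local_rank=0,                            # gpu 0 on this machine
        node_id=0,                               # single machine
    )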