Add new Models

VISSL makes it easy to add new models (heads and trunks) and to combine different trunks and heads to train a new model. Follow the steps below to add new heads or trunks.

Adding New Heads

To add a new model head, follow the steps:

  • Step1: Add the new head my_new_head under vissl/models/heads/my_new_head.py following the template:

from typing import List, Union

import torch
import torch.nn as nn
# AttrDict is exported from vissl.config in recent VISSL releases
# (older releases expose it from vissl.utils.hydra_config)
from vissl.config import AttrDict
from vissl.models.heads import register_model_head

@register_model_head("my_new_head")
class MyNewHead(nn.Module):
    """
    Add documentation on what this head does and link any papers where the head is used
    """

    def __init__(self, model_config: AttrDict, param1: val, ....):
        """
        Args:
            document the parameters of the head
        """
        super().__init__()
        # implement what the init of the head should do. For example, it can construct
        # the layers in the head (FC etc.), initialize the parameters, or anything else
        ...

    # the input to the head should be a torch Tensor or a list of torch Tensors.
    def forward(self, batch: Union[torch.Tensor, List[torch.Tensor]]):
        """
        add documentation on what the head input structure should be, shapes expected
        and what the output should be
        """
        # implement the forward pass of the head
  • Step2: The new head is ready to use. Test it by setting the new head in the configuration file.

MODEL:
  HEAD:
    PARAMS: [
        ...
        ["my_new_head", {"param1": val, ...}]
        ...
    ]
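
To make the flow above concrete, here is a minimal, framework-free sketch of how a decorator-based registry and a list of [name, params] pairs can fit together. It is not VISSL's actual implementation: HEAD_REGISTRY, register_head, and build_heads are hypothetical names chosen for this example, and the head is a plain class rather than an nn.Module.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry mapping a head name to its class.
HEAD_REGISTRY: Dict[str, type] = {}


def register_head(name: str) -> Callable[[type], type]:
    """Decorator that records a class under `name` (illustrative only)."""

    def wrapper(cls: type) -> type:
        if name in HEAD_REGISTRY:
            raise ValueError(f"head '{name}' is already registered")
        HEAD_REGISTRY[name] = cls
        return cls

    return wrapper


@register_head("my_new_head")
class MyNewHead:
    """Stand-in head; a real head would subclass nn.Module."""

    def __init__(self, model_config: Dict[str, Any], param1: int):
        self.model_config = model_config
        self.param1 = param1


def build_heads(model_config: Dict[str, Any], head_params: List[list]) -> list:
    """Instantiate one head per [name, kwargs] entry, preserving order."""
    heads = []
    for name, kwargs in head_params:
        head_cls = HEAD_REGISTRY[name]  # KeyError if the name was never registered
        heads.append(head_cls(model_config, **kwargs))
    return heads


# Mirrors the shape of MODEL.HEAD.PARAMS: a list of [head_name, {param: value}].
config = {"HEAD": {"PARAMS": [["my_new_head", {"param1": 8}]]}}
heads = build_heads(config, config["HEAD"]["PARAMS"])
```

The point of the design is that the name string in the config is the only coupling between the configuration file and the head class: registering under a different name, or forgetting to import the module that defines the head, makes the lookup fail.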

Adding New Trunks

To add a new trunk (a new architecture like vision transformers, etc.), follow the steps:

  • Step1: Add your new trunk my_new_trunk under vissl/models/trunks/my_new_trunk.py following the template:

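from typing import List

import torch
import torch.nn as nn
# AttrDict is exported from vissl.config in recent VISSL releases
# (older releases expose it from vissl.utils.hydra_config)
from vissl.config import AttrDict
from vissl.models.trunks import register_model_trunk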

@register_model_trunk("my_new_trunk")
class MyNewTrunk(nn.Module):
    """
    documentation on what the trunk does and links to technical reports
    using this trunk (if applicable)
    """

    def __init__(self, model_config: AttrDict, model_name: str):
        super(MyNewTrunk, self).__init__()
        self.model_config = model_config

        # get the params trunk takes from the config
        trunk_config = self.model_config.TRUNK.MyNewTrunk

        # implement the model trunk and construct all the layers that the trunk uses
        model_layer1 = ??
        model_layer2 = ??
        ...
        ...

        # give a name to the layers of your trunk so that these features
        # can be used for other purposes, like feature extraction.
        # The names are entirely up to the user's discretion. The user may choose to
        # name only one layer, which is the last layer of the model.
        # Each name must be unique: a repeated key overwrites the earlier entry.
        self._feature_blocks = nn.ModuleDict(
            [
                ("my_layer1_name", model_layer1),
                ("my_layer2_name", model_layer2),
                ...
            ]
        )

    def forward(
        self, x: torch.Tensor, out_feat_keys: List[str] = None
    ) -> List[torch.Tensor]:
        # implement the forward pass of the model. See the forward pass of resnext.py
        # for reference.
        # The output should be a list. The list can have one tensor (the trunk output)
        # or multiple tensors (corresponding to several features of the trunk)
        ...
        ...

        return output
  • Step2: Inform VISSL about the parameters of the trunk. Register the params with VISSL Configuration by adding the params in VISSL defaults.yaml as follows:

MODEL:
  TRUNK:
    MyNewTrunk:
      param1: value1
      param2: value2
      ...
  • Step3: The trunk is ready to use. Set the trunk name and params in your config file, for example MODEL.TRUNK.NAME=my_new_trunk
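
As a rough illustration of the forward contract described in the trunk template, the sketch below walks named feature blocks in order and returns a list holding either the final output or the outputs of the requested blocks. To stay dependency-free it uses plain Python callables on numbers instead of nn.Module layers on tensors. The block names and the run_trunk helper are hypothetical, and VISSL's own trunks (for example resnext.py) may handle this differently.

```python
from collections import OrderedDict
from typing import List, Optional

# Stand-ins for the trunk layers; in a real trunk these would be nn.Module
# objects stored in an nn.ModuleDict.
feature_blocks = OrderedDict(
    [
        ("my_layer1_name", lambda x: x + 1),
        ("my_layer2_name", lambda x: x * 10),
    ]
)


def run_trunk(x, out_feat_keys: Optional[List[str]] = None) -> list:
    """Run the blocks in order and collect the requested features."""
    if not out_feat_keys:
        # Default: return only the output of the last block.
        out_feat_keys = [list(feature_blocks.keys())[-1]]

    unknown = set(out_feat_keys) - set(feature_blocks)
    if unknown:
        raise KeyError(f"unknown feature names: {sorted(unknown)}")

    features = {}
    for name, block in feature_blocks.items():
        x = block(x)
        if name in out_feat_keys:
            features[name] = x

    # Return the features in the order the caller asked for them.
    return [features[key] for key in out_feat_keys]


last_only = run_trunk(2)
both = run_trunk(2, ["my_layer1_name", "my_layer2_name"])
```

Because the result is always a list, downstream code can treat the single-output and multi-feature cases the same way.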

Adding New Base Model

VISSL uses BaseSSLMultiInputOutputModel as its base model class, which invokes the trunk and head models. When altering the head or trunk does not offer enough flexibility, a user may wish to override the entire base model.

NOTE: Usually, implementing a new head or trunk should fulfill your needs. Only use this if necessary.

  • Step1: Add the new model my_new_model under vissl/models/my_new_model.py following the template for full compatibility with VISSL:

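from classy_vision.models import ClassyModel, register_model
# AttrDict is exported from vissl.config in recent VISSL releases
# (older releases expose it from vissl.utils.hydra_config)
from vissl.config import AttrDict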

@register_model("my_new_model")
class MyNewModel(ClassyModel):
    """
    Add documentation on what this model is.
    """

    def __init__(self, model_config: AttrDict, param1: val, ....):
        """
        Args:
            document the parameters of the model
        """
        super().__init__()
        # implement what the init of model should do.
        ...

    def forward(self, batch):
        """
        Main forward of the model. Depending on the model type the calls are patched
        to the suitable function.
        """
        ...

    def freeze_head(self):
        """
        Freeze the model head.
        """
        ...

    def freeze_trunk(self):
        """
        Freeze the model trunk
        """
        ...

    def freeze_head_and_trunk(self):
        """
        Freeze the model trunk and head.
        """
        ...

    def is_fully_frozen_model(self):
        """
        If the model is fully frozen.
        """
        ...

    def get_classy_state(self, deep_copy=False):
        """
        Return the model state (trunk + heads) to checkpoint.
        """
        ...

    def set_classy_state(self, state):
        """
        Initialize the model trunk and head from the state dictionary.
        """
        ...

    def init_model_from_weights_params_file(self):
        """
        We initialize the weights from this checkpoint.
        """
        ...
  • Step2: The new model is ready to use. Test it by setting the new model in the configuration file.

MODEL:
  # default model. User can define their own model and use that instead.
  BASE_MODEL_NAME: multi_input_output_model
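
The checkpointing pair get_classy_state / set_classy_state from the template can be pictured as a round trip through a nested dictionary. The sketch below is dependency-free and purely illustrative: ToyModel and the {"model": {"trunk": ..., "heads": ...}} layout are assumptions made for this example, not a statement of VISSL's checkpoint format, and plain lists stand in for parameter tensors.

```python
import copy
from typing import Any, Dict


class ToyModel:
    """Hypothetical stand-in for a model with a trunk and heads."""

    def __init__(self) -> None:
        self.trunk = {"weight": [1.0, 2.0]}
        self.heads = {"weight": [0.5]}

    def get_classy_state(self, deep_copy: bool = False) -> Dict[str, Any]:
        """Return the trunk and head state to checkpoint."""
        state = {"model": {"trunk": self.trunk, "heads": self.heads}}
        # A deep copy decouples the snapshot from later in-place updates.
        return copy.deepcopy(state) if deep_copy else state

    def set_classy_state(self, state: Dict[str, Any]) -> None:
        """Initialize the trunk and heads from a state dictionary."""
        self.trunk = copy.deepcopy(state["model"]["trunk"])
        self.heads = copy.deepcopy(state["model"]["heads"])


source = ToyModel()
source.trunk["weight"][0] = 42.0
snapshot = source.get_classy_state(deep_copy=True)
source.trunk["weight"][0] = -1.0  # later change; must not leak into the snapshot

target = ToyModel()
target.set_classy_state(snapshot)
```

Whatever layout a custom model chooses, the two methods need to agree on it: set_classy_state should accept exactly what get_classy_state produces.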