Add new Models
VISSL makes it easy to add new models (heads and trunks) and to combine different trunks and heads to train a new model. Follow the steps below to add new heads or trunks.
Adding New Heads
To add a new model head, follow the steps:
Step1: Add the new head my_new_head under vissl/models/heads/my_new_head.py following the template:
    from typing import List, Union

    import torch
    import torch.nn as nn
    from vissl.config import AttrDict
    from vissl.models.heads import register_model_head


    @register_model_head("my_new_head")
    class MyNewHead(nn.Module):
        """
        Add documentation on what this head does and also link any papers where the head is used
        """

        def __init__(self, model_config: AttrDict, param1: val, ....):
            """
            Args:
                add documentation on what are the parameters to the head
            """
            super().__init__()
            # implement what the init of the head should do. For example, it can construct
            # the layers in the head (FC etc.), initialize the parameters or anything else
            ....

        # the input to the head should be a torch Tensor or a list of torch Tensors.
        def forward(self, batch: Union[torch.Tensor, List[torch.Tensor]]):
            """
            add documentation on what the head input structure should be, shapes expected
            and what the output should be
            """
            # implement the forward pass of the head
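As a concrete illustration of the template, here is a minimal sketch of a head that applies a single linear layer. The class name LinearProbeHead, its parameters in_dim and out_dim, and the element-wise handling of list inputs are hypothetical choices for this example, not part of VISSL. The registration decorator and the model_config argument are left out so that the snippet runs with PyTorch alone; inside VISSL they would be added as in the template.

```python
from typing import List, Union

import torch
import torch.nn as nn


class LinearProbeHead(nn.Module):
    """Hypothetical head: one fully connected layer over flattened features."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(
        self, batch: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Assumption for this sketch: a list input is processed element-wise.
        if isinstance(batch, (list, tuple)):
            return [self.fc(torch.flatten(x, start_dim=1)) for x in batch]
        return self.fc(torch.flatten(batch, start_dim=1))


head = LinearProbeHead(in_dim=8, out_dim=3)
out = head(torch.randn(4, 8))
# Each list element is flattened to (batch, 8) before the linear layer.
outs = head([torch.randn(4, 2, 4), torch.randn(2, 8)])
```

Accepting both a single tensor and a list mirrors the input contract stated in the template; how a real head should treat a list depends on the trunk that feeds it.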
Step2: The new head is ready to use. Test it by setting the new head in the configuration file.
    MODEL:
      HEAD:
        PARAMS: [
          ...
          ["my_new_head", {"param1": val, ...}]
          ...
        ]
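To see why the string in PARAMS must match the name passed to the registration decorator, the following is a framework-free sketch of a decorator-based registry. It is not VISSL's implementation: HEAD_REGISTRY, register_head and build_heads are hypothetical names, and a plain class stands in for an nn.Module.

```python
from typing import Any, Callable, Dict, List

# Hypothetical registry mapping a head name to the class that implements it.
HEAD_REGISTRY: Dict[str, type] = {}


def register_head(name: str) -> Callable[[type], type]:
    """Return a class decorator that stores the class under `name`."""

    def wrapper(cls: type) -> type:
        if name in HEAD_REGISTRY:
            raise ValueError(f"Head '{name}' is already registered")
        HEAD_REGISTRY[name] = cls
        return cls

    return wrapper


@register_head("my_new_head")
class MyNewHead:
    def __init__(self, param1: int = 0):
        self.param1 = param1


def build_heads(params: List[list]) -> List[Any]:
    """Instantiate one head per [name, kwargs] entry, in config order."""
    return [HEAD_REGISTRY[name](**kwargs) for name, kwargs in params]


# Same shape as a PARAMS entry in the config: [name, {keyword arguments}].
heads = build_heads([["my_new_head", {"param1": 5}]])
```

In this sketch a name that was never registered fails with a KeyError at build time, which is the kind of error to expect from a misspelled head name.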
Adding New Trunks
To add a new trunk (a new architecture like vision transformers, etc.), follow the steps:
Step1: Add your new trunk my_new_trunk under vissl/models/trunks/my_new_trunk.py following the template:
    from typing import List

    import torch
    import torch.nn as nn
    from vissl.config import AttrDict
    from vissl.models.trunks import register_model_trunk


    @register_model_trunk("my_new_trunk")
    class MyNewTrunk(nn.Module):
        """
        documentation on what the trunk does and links to technical reports
        using this trunk (if applicable)
        """

        def __init__(self, model_config: AttrDict, model_name: str):
            super(MyNewTrunk, self).__init__()
            self.model_config = model_config

            # get the params the trunk takes from the config
            # (this path must match where the params are registered in defaults.yaml)
            trunk_config = self.model_config.TRUNK.TRUNK_PARAMS.MyNewTrunk

            # implement the model trunk and construct all the layers that the trunk uses
            model_layer1 = ??
            model_layer2 = ??
            ...
            ...

            # give a name to the layers of your trunk so that these features
            # can be used for other purposes, like feature extraction.
            # The names are entirely up to the user's discretion. The user may choose
            # to name only one layer, which is the last layer of the model.
            self._feature_blocks = nn.ModuleDict(
                [
                    ("my_layer1_name", model_layer1),
                    ("my_layer2_name", model_layer2),
                    ...
                ]
            )

        def forward(
            self, x: torch.Tensor, out_feat_keys: List[str] = None
        ) -> List[torch.Tensor]:
            # implement the forward pass of the model. See the forward pass of resnext.py
            # for reference.
            # The output should be a list. The list can have one tensor (the trunk output)
            # or multiple tensors (corresponding to several features of the trunk)
            ...
            ...
            return output
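The forward pass described in the comments can be sketched as follows. This is an illustrative assumption about how named feature blocks and out_feat_keys might interact, not a copy of VISSL's resnext.py; the class TinyTrunk and the layer names conv1 and avgpool are hypothetical, and the registration decorator and config handling are omitted so that the snippet runs with PyTorch alone.

```python
from typing import List, Optional

import torch
import torch.nn as nn


class TinyTrunk(nn.Module):
    """Hypothetical two-block trunk used only to illustrate the forward pass."""

    def __init__(self):
        super().__init__()
        self._feature_blocks = nn.ModuleDict(
            [
                ("conv1", nn.Conv2d(3, 4, kernel_size=3, padding=1)),
                ("avgpool", nn.AdaptiveAvgPool2d(1)),
            ]
        )

    def forward(
        self, x: torch.Tensor, out_feat_keys: Optional[List[str]] = None
    ) -> List[torch.Tensor]:
        names = list(self._feature_blocks.keys())
        # With no keys requested, return only the last block's output.
        wanted = out_feat_keys if out_feat_keys else [names[-1]]
        unknown = set(wanted) - set(names)
        if unknown:
            raise ValueError(f"Unknown feature keys: {sorted(unknown)}")
        feats = {}
        # Run the blocks in registration order, remembering each output by name.
        for name, block in self._feature_blocks.items():
            x = block(x)
            feats[name] = x
        # Return features in the order they were requested.
        return [feats[key] for key in wanted]


trunk = TinyTrunk()
images = torch.randn(2, 3, 8, 8)
last_only = trunk(images)
both = trunk(images, out_feat_keys=["conv1", "avgpool"])
```

For simplicity this sketch always runs every block; a real trunk could stop early once the last requested feature has been computed.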
Step2: Inform VISSL about the parameters of the trunk. Register the params with the VISSL configuration by adding them to VISSL's defaults.yaml as follows:
    MODEL:
      TRUNK:
        MyNewTrunk:
          param1: value1
          param2: value2
          ...
Step3: The trunk is ready to use. Set the trunk name and params in your config file:
MODEL.TRUNK.NAME=my_new_trunk
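The MODEL.TRUNK.NAME=my_new_trunk form is a dotted-key override. As a rough, framework-free sketch of how such an override could be merged into nested defaults (VISSL relies on its own configuration machinery for this; apply_override and the sample defaults below are hypothetical):

```python
import copy
from typing import Any, Dict


def apply_override(config: Dict[str, Any], override: str) -> Dict[str, Any]:
    """Return a copy of `config` with one `A.B.C=value` override applied."""
    dotted_key, sep, value = override.partition("=")
    if not sep:
        raise ValueError(f"Expected KEY=VALUE, got '{override}'")
    result = copy.deepcopy(config)
    node = result
    *parents, leaf = dotted_key.split(".")
    for part in parents:
        # Create intermediate sections that the defaults do not define yet.
        node = node.setdefault(part, {})
    node[leaf] = value
    return result


# Hypothetical defaults, shaped like the defaults.yaml fragment in Step2.
defaults = {
    "MODEL": {"TRUNK": {"NAME": "resnet", "MyNewTrunk": {"param1": "value1"}}}
}
config = apply_override(defaults, "MODEL.TRUNK.NAME=my_new_trunk")
```

The override replaces only the addressed key, so the trunk params registered in the defaults remain available to the trunk's constructor.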