Source code for vissl.losses.cross_entropy_multiple_output_single_target

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Union

import torch
from classy_vision.generic.util import is_on_gpu
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict


@register_loss("cross_entropy_multiple_output_single_target")
class CrossEntropyMultipleOutputSingleTargetLoss(ClassyLoss):
    """
    Initializer for the sum cross-entropy loss. For a single
    tensor, this is equivalent to the cross-entropy loss. For a
    list of tensors, this computes the sum of the cross-entropy
    losses for each tensor in the list against the target.

    Config params:
        weight: per-class rescaling weight handed to
            ``torch.nn.CrossEntropyLoss``, optional. A non-tensor value
            (e.g. a list coming from a YAML config) is converted to a
            float tensor.
        ignore_index: target value that is ignored by the loss, optional.
            Default -1
        normalize_output: if True, every output is L2-normalized along
            dim 1 before the loss is computed, optional. Default False
        temperature: the (optionally normalized) output is divided by this
            value before the loss is computed, optional. Default 1.0

    NOTE: a ``reduction`` key is not read by this class; the default
    reduction of ``torch.nn.CrossEntropyLoss`` is always used.
    """

    def __init__(self, loss_config: AttrDict):
        """
        Read the loss options from ``loss_config``.

        Args:
            loss_config: mapping holding any of the config params listed
                in the class docstring. Missing keys fall back to their
                defaults.
        """
        super().__init__()
        self._weight = None
        self._ignore_index = -1
        # One CrossEntropyLoss per model output; filled lazily in forward()
        # because the number of outputs is only known at call time.
        self._losses = torch.nn.modules.ModuleList([])
        self._normalize_output = False
        # Documented default; previously a missing key raised an error.
        self._temperature = 1.0
        if "temperature" in loss_config:
            self._temperature = loss_config["temperature"]
        if "weight" in loss_config:
            weight = loss_config["weight"]
            # CrossEntropyLoss registers ``weight`` as a buffer, which must
            # be a tensor (or None); plain sequences are converted here.
            if weight is not None and not torch.is_tensor(weight):
                weight = torch.tensor(weight, dtype=torch.float)
            self._weight = weight
        if "ignore_index" in loss_config:
            self._ignore_index = loss_config["ignore_index"]
        if "normalize_output" in loss_config:
            self._normalize_output = loss_config["normalize_output"]

    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates CrossEntropyMultipleOutputSingleTargetLoss from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            CrossEntropyMultipleOutputSingleTargetLoss instance.
        """
        return cls(loss_config)

    def _create_loss_function(self):
        """
        Append one more ``CrossEntropyLoss`` to the internal list of losses.

        If ``is_on_gpu`` reports that the existing losses live on the GPU,
        the whole list (including the new entry) is moved to CUDA.

        Returns:
            self
        """
        # Must be queried before appending the new (CPU-side) loss.
        copy_to_gpu = is_on_gpu(self._losses)
        self._losses.append(
            torch.nn.modules.CrossEntropyLoss(
                weight=self._weight, ignore_index=self._ignore_index
            )
        )
        if copy_to_gpu:
            self._losses.cuda()
        return self

    def forward(
        self, output: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor
    ):
        """
        For each output and single target, loss is calculated.
        The returned loss value is the sum loss across all outputs.

        Args:
            output: a single tensor or a list of tensors; dim 1 of every
                tensor is the class dimension.
            target: tensor of class indices shared by all outputs.

        Returns:
            Sum of the cross-entropy losses over all outputs
            (the integer 0 when ``output`` is an empty list).

        Raises:
            AssertionError: if ``output`` is not a tensor / list, if
                ``target`` is not a tensor, or if the largest target index
                does not fit in dim 1 of an output.
        """
        if isinstance(output, torch.Tensor):
            output = [output]
        assert isinstance(
            output, list
        ), "Model output should be a list of tensors. Got Type {}".format(type(output))
        assert torch.is_tensor(target), "Target should be a tensor. Got Type {}".format(
            type(target)
        )

        # The largest target index is the same for every output, so it is
        # computed once (``.item()`` forces a host sync for CUDA tensors).
        # Skipped for an empty output list, which simply returns 0.
        max_target = target.max().item() if output else None

        loss = 0
        for idx, pred in enumerate(output):
            normalized_pred = pred
            if self._normalize_output:
                normalized_pred = nn.functional.normalize(pred, dim=1, p=2)

            assert (
                max_target < pred.shape[1]
            ), f"pred.shape[1]={pred.shape[1]} and target.max().item()={max_target}"

            # Lazily create the loss module for this output position.
            if idx >= len(self._losses):
                self._create_loss_function()

            loss += self._losses[idx](normalized_pred / self._temperature, target)
        return loss