# Source code for vissl.losses.bce_logits_multiple_output_single_target

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import List, Union

import torch
from classy_vision.generic.util import is_on_gpu
from classy_vision.losses import ClassyLoss, register_loss
from torch import nn
from vissl.utils.hydra_config import AttrDict

@register_loss("bce_logits_multiple_output_single_target")
class BCELogitsMultipleOutputSingleTargetLoss(ClassyLoss):
    """
    Sum of binary cross-entropy (with logits) losses between each of one or
    more model outputs and a single shared target.

    Target entries equal to -1 are ignored: their per-element loss is set to 0.
    The remaining per-element losses are divided by the number of non-ignored
    entries of their class (counted along dim 0 of the target) before summing.
    """

    def __init__(self, loss_config: AttrDict):
        """
        Initializer for the summed BCE-with-logits loss. For a single tensor,
        this is a masked, per-class-normalized BCEWithLogitsLoss. For a list of
        tensors, this computes the sum of those losses for each tensor in the
        list against the same target.

        Config params:
            reduction: reduction passed to nn.BCEWithLogitsLoss, optional
                (defaults to "none"). forward() masks and normalizes the loss
                element-wise, so it needs the unreduced per-element losses that
                "none" produces; a reducing value makes forward() fail when the
                target-shaped mask is applied in place.
            normalize_output: Whether to L2 normalize the outputs (along dim 1)
            world_size: total number of gpus in training. automatically
                inferred by vissl
        """
        super().__init__()
        self.loss_config = loss_config
        # One BCEWithLogitsLoss per model output; created lazily in forward().
        self._losses = torch.nn.modules.ModuleList([])
        self._reduction = loss_config.get("reduction", "none")
        self._normalize_output = loss_config.get("normalize_output", False)
        self._world_size = loss_config["world_size"]

    @classmethod
    def from_config(cls, loss_config: AttrDict):
        """
        Instantiates BCELogitsMultipleOutputSingleTargetLoss from configuration.

        Args:
            loss_config: configuration for the loss

        Returns:
            BCELogitsMultipleOutputSingleTargetLoss instance.
        """
        return cls(loss_config)

    def _create_loss_function(self):
        """
        Append a new nn.BCEWithLogitsLoss (using the configured reduction) to
        ``self._losses``. If ``is_on_gpu`` reports the existing module list as
        being on GPU, the list is moved to GPU again so the new module follows.

        Returns:
            self
        """
        copy_to_gpu = is_on_gpu(self._losses)
        self._losses.append(nn.BCEWithLogitsLoss(reduction=self._reduction))
        if copy_to_gpu:
            self._losses.cuda()
        return self

    def forward(
        self, output: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor
    ):
        """
        For each output and the single target, the loss is calculated. The
        returned loss value is the sum of the losses across all outputs.

        Args:
            output: a logits tensor, or a list of logits tensors (one per head).
            target: labels where -1 marks entries to ignore. Presumably
                (batch_size, num_classes), matching each output — TODO confirm
                with callers.

        Returns:
            The summed loss; the integer 0 if ``output`` is an empty list.
        """
        if isinstance(output, torch.Tensor):
            output = [output]
        assert isinstance(
            output, list
        ), "Model output should be a list of tensors. Got Type {}".format(type(output))
        assert torch.is_tensor(target), "Target should be a tensor. Got Type {}".format(
            type(target)
        )

        # The ignore mask, per-class normalizer and float target depend only on
        # `target`, so compute them once instead of once per output head.
        ignore_mask = target == -1
        # Number of valid (label != -1) entries per class. Classes with no valid
        # entries get 1 to avoid dividing by zero; all of their losses are
        # masked to 0 anyway.
        num_per_class = torch.sum(~ignore_mask, dim=0).clamp(min=1)
        float_target = target.float()

        loss = 0
        for idx, pred in enumerate(output):
            if self._normalize_output:
                pred = nn.functional.normalize(pred, dim=1, p=2)
            # Lazily create one loss module per output head.
            if idx >= len(self._losses):
                self._create_loss_function()
            per_element = self._losses[idx](pred, float_target)
            # In-place fill is intentional: it requires an element-wise
            # (unreduced) loss, failing loudly for a reduced scalar instead of
            # silently broadcasting it.
            loss += torch.sum(per_element.masked_fill_(ignore_mask, 0) / num_per_class)
        return loss