Source code for vissl.meters.accuracy_list_meter

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pprint
from typing import List, Union

import torch
from classy_vision.generic.util import is_pos_int
from classy_vision.meters import AccuracyMeter, ClassyMeter, register_meter
from vissl.utils.hydra_config import AttrDict


[docs]@register_meter("accuracy_list_meter") class AccuracyListMeter(ClassyMeter): """ Meter to calculate top-k accuracy for single label image classification task. Supports Single target and multiple output. A list of accuracy meters is constructed and each output has a meter associated. Args: num_meters: number of meters and hence we have same number of outputs topk_values: list of int `k` values. Example: [1, 5] meter_names: list of str indicating the name of meter. Usually corresponds to the output layer name. """ def __init__(self, num_meters: int, topk_values: List[int], meter_names: List[str]): super().__init__() assert is_pos_int(num_meters), "num_meters must be positive" assert isinstance(topk_values, list), "topk_values must be a list" assert len(topk_values) > 0, "topk_values list should have at least one element" assert [ is_pos_int(x) for x in topk_values ], "each value in topk_values must be >= 1" self._num_meters = num_meters self._topk_values = topk_values self._meters = [ AccuracyMeter(self._topk_values) for _ in range(self._num_meters) ] self._meter_names = meter_names self.reset()
[docs] @classmethod def from_config(cls, meters_config: AttrDict): """ Get the AccuracyListMeter instance from the user defined config """ return cls( num_meters=meters_config["num_meters"], topk_values=meters_config["topk_values"], meter_names=meters_config["meter_names"], )
@property def name(self): """ Name of the meter """ return "accuracy_list_meter" @property def value(self): """ Value of the meter globally synced. For each output, all the top-k values are returned. If there are several meters attached to the same layer name, a list of top-k values will be returned for that layer name meter. """ val_dict = {} for ind, meter in enumerate(self._meters): meter_val = meter.value sample_count = meter._total_sample_count val_dict[ind] = {} val_dict[ind]["val"] = meter_val val_dict[ind]["sample_count"] = sample_count # also create dict w.r.t top-k output_dict = {} for k in self._topk_values: top_k_str = f"top_{k}" output_dict[top_k_str] = {} for ind in range(len(self._meters)): meter_name = ( self._meter_names[ind] if (len(self._meter_names) > 0) else ind ) val = 100.0 * round(float(val_dict[ind]["val"][top_k_str]), 6) # we could have several meters with the same name. We append the result # to the dict. if meter_name not in output_dict[top_k_str]: output_dict[top_k_str][meter_name] = [val] else: output_dict[top_k_str][meter_name].append(val) for topk in output_dict: for k in output_dict[topk]: if len(output_dict[topk][k]) == 1: output_dict[topk][k] = output_dict[topk][k][0] return output_dict
[docs] def sync_state(self): """ Globally syncing the state of each meter across all the trainers. """ for _, meter in enumerate(self._meters): meter.sync_state()
[docs] def get_classy_state(self): """ Returns the states of each meter """ meter_states = {} for ind, meter in enumerate(self._meters): state = meter.get_classy_state() meter_states[ind] = {} meter_states[ind]["state"] = state return meter_states
[docs] def set_classy_state(self, state): """ Set the state of each meter """ assert len(state) == len(self._meters), "Incorrect state dict for meters" for ind, meter in enumerate(self._meters): meter.set_classy_state(state[ind]["state"])
def __repr__(self): value = self.value # convert top_k list into csv format for easy copy pasting for k in self._topk_values: top_k_str = f"top_{k}" hr_format = ["%.1f" % (100 * x) for x in value[top_k_str]] value[top_k_str] = ",".join(hr_format) repr_dict = {"name": self.name, "num_meters": self._num_meters, "value": value} return pprint.pformat(repr_dict, indent=2)
[docs] def update( self, model_output: Union[torch.Tensor, List[torch.Tensor]], target: torch.Tensor, ): """ Updates the value of the meter for the given model output list and targets. Args: model_output: list of tensors of shape (B, C) where each value is either logit or class probability. target: tensor of shape (B). NOTE: For binary classification, C=2. """ if isinstance(model_output, torch.Tensor): model_output = [model_output] assert isinstance(model_output, list) assert len(model_output) == self._num_meters for (meter, output) in zip(self._meters, model_output): meter.update(output, target)
[docs] def reset(self): """ Reset all the meters """ [x.reset() for x in self._meters]
[docs] def validate(self, model_output_shape, target_shape): """ Not implemented """ pass