"""
Shortcuts: Metrics

Ready-to-use metric classes built on top of `BaseMetric`.
"""

import torch.distributed as dist
import warnings

# BASE CLASSES
from ignite_framework.metrics import BaseMetric

# FEATURE CLASSES
from ignite_framework.exceptions import FeatureRuntimeError
from ignite_framework.feature_dev_tools.state_objects import StateParameter

# UTILS
from ignite_framework.feature_dev_tools import create_bso_ctr_name_prefix_from_input_ref

class AverageOutput(BaseMetric):
    """
    Averaged output over `Metric`'s period `start_event` till `completed_event` with measures taken at
    every `iteration_completed_event`.

    Args:
        metric_name: name under which the metric is registered.
        input_ref: reference whose `caller_name` value is accumulated at every update.
        started_ref: event reference triggering `reset` (default: `'engine_run_started_ref'`).
        iteration_completed_ref: event reference triggering `update`.
        completed_ref: event reference triggering `compute`.
        batch_size_fn: callable mapping an iterable input to its sample count; only used
            when the input is iterable (default: `len`).
        device: device string forwarded to `BaseMetric`.
        ignore_zero_division: if True, `compute` with zero accumulated samples only warns
            and returns None instead of raising `FeatureRuntimeError`.
    """
    def __init__(self, metric_name, input_ref, started_ref='engine_run_started_ref',
                 iteration_completed_ref='n_iteration_completed_ref', completed_ref='n_epoch_completed_ref',
                 batch_size_fn=lambda x: len(x), device='', ignore_zero_division=True):

        super().__init__(metric_name=metric_name, input_ref=input_ref, started_ref=started_ref,
                         iteration_completed_ref=iteration_completed_ref, completed_ref=completed_ref, device=device)
        self._batch_size_fn = batch_size_fn
        self._ignore_zero_division = ignore_zero_division
        self.reset()

    def reset(self):
        # Resetting during class instantiation and at `self.started_event`
        self._sum = 0
        self._num_accumulations = 0

    def update(self):
        # Read the input once and compute both contributions BEFORE mutating state.
        # The original code incremented `self._sum` first, so a TypeError raised by
        # `batch_size_fn` (after a successful `sum()`) double-counted the input in
        # the scalar fallback branch.
        output = self._input_ref.caller_name
        try:
            batch_sum = sum(output)
            batch_size = self._batch_size_fn(output)
        except TypeError:
            # Non-iterable (scalar) output: count it as a single sample.
            batch_sum = output
            batch_size = 1
        self._sum += batch_sum
        self._num_accumulations += batch_size

    def compute(self):
        if self._num_accumulations == 0:
            # NOTE: message previously referenced `self._num_samples`, an attribute
            # that does not exist on this class; corrected to `_num_accumulations`.
            error_msg = '`state.{}.compute()` cannot be calculated because number of accumulated samples ' \
                        '`self._num_accumulations == 0`. At least one sample must be accumulated before ' \
                        '`self.compute()` can be called.'.format(self.name)
            if self._ignore_zero_division:
                warnings.warn(error_msg)
                return
            else:
                raise FeatureRuntimeError(error_msg)
        return self._sum / self._num_accumulations


class OutputMetric(BaseMetric):
    """
    Pass-through metric: `compute` returns the current value of `input_ref`'s
    `caller_name` without any accumulation or transformation.

    When `caller_ref` is left empty, `input_ref` itself is used as the
    completion trigger passed to `BaseMetric`.
    """

    def __init__(self, metric_name, input_ref, caller_ref='', device=''):
        # Default the completion trigger to the input reference itself.
        if caller_ref == '':
            caller_ref = input_ref

        super().__init__(metric_name=metric_name, input_ref=input_ref, started_ref=None,
                         iteration_completed_ref=None, completed_ref=caller_ref, device=device)

    def reset(self):
        # Stateless metric: nothing to reset.
        pass

    def update(self):
        # Stateless metric: nothing to accumulate.
        pass

    def compute(self):
        return self._input_ref.caller_name


class RunningAverage(BaseMetric):
    """
    Compute the running (exponential moving) average of an output, e.g. a metric
    output or the output of a process function.

    Update rule: ``output = output * alpha + (1.0 - alpha) * new_value``. When no
    previous output is available yet (the multiplication raises ``TypeError``),
    the raw input value is returned unchanged.

    Note: a metric class as `metric_output_ref` is not handled here anymore due to the enhancement of the
    metric arguments to customize the callback variables (events). Nevertheless, this feature with default
    callback variables can easily be implemented if desired.
    """
    # Smoothing factor exposed as a state parameter.
    alpha = StateParameter()

    def __init__(self, metric_name, input_ref, caller_ref='', alpha=0.98, device=''):

        # Argument checking & formatting: default the completion trigger to the input itself.
        caller_ref = input_ref if caller_ref == '' else caller_ref

        super().__init__(metric_name=metric_name, input_ref=input_ref, started_ref=None,
                         iteration_completed_ref=None, completed_ref=caller_ref, device=device)
        # The check excludes 0.0 and includes 1.0; the previous message
        # ("between 0.0 and 1.0") misstated that interval.
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float in the interval (0.0, 1.0].")
        self.alpha = alpha

    def reset(self):
        # No internal accumulator: the running value lives in `self.output`.
        pass

    def update(self):
        # All work happens in `compute`; nothing to accumulate per iteration.
        pass

    def compute(self):
        try:
            return self.output * self.alpha + (1.0 - self.alpha) * self._input_ref.caller_name
        except TypeError:
            # First computation: `self.output` is presumably not yet numeric
            # (set by BaseMetric — TODO confirm), so seed with the raw input.
            return self._input_ref.caller_name