import torch.distributed as dist
import warnings
# BASE CLASSES
from ignite_framework.metrics import BaseMetric
# FEATURE CLASSES
from ignite_framework.exceptions import FeatureRuntimeError
from ignite_framework.feature_dev_tools.state_objects import StateParameter
# UTILS
from ignite_framework.feature_dev_tools import create_bso_ctr_name_prefix_from_input_ref
class AverageOutput(BaseMetric):
    """
    Average an output over the metric's period, from `started_ref` until `completed_ref`, taking one
    measurement at every `iteration_completed_ref`.

    Each measurement reads `self._input_ref.caller_name`:

    * If the value is iterable, its element sum is accumulated and `batch_size_fn(value)` is added to the
      sample count.
    * Otherwise (i.e. `sum()` or `batch_size_fn` raises `TypeError`, e.g. for a plain number) the value itself
      is accumulated and counted as a single sample.

    Args:
        metric_name: name of the metric, forwarded to `BaseMetric`.
        input_ref: reference whose `caller_name` attribute provides the value to average.
        started_ref: forwarded to `BaseMetric`; presumably the event at which `reset()` is triggered
            (confirm in `BaseMetric`).
        iteration_completed_ref: forwarded to `BaseMetric`; presumably triggers `update()`.
        completed_ref: forwarded to `BaseMetric`; presumably triggers `compute()`.
        batch_size_fn: callable returning the number of samples contained in an iterable output
            (default: its `len`).
        device: forwarded to `BaseMetric`.
        ignore_zero_division: if True, `compute()` with no accumulated samples emits a warning and returns
            `None`; if False it raises `FeatureRuntimeError`.
    """

    def __init__(self, metric_name, input_ref, started_ref='engine_run_started_ref',
                 iteration_completed_ref='n_iteration_completed_ref', completed_ref='n_epoch_completed_ref',
                 batch_size_fn=lambda x: len(x), device='', ignore_zero_division=True):
        super().__init__(metric_name=metric_name, input_ref=input_ref, started_ref=started_ref,
                         iteration_completed_ref=iteration_completed_ref, completed_ref=completed_ref,
                         device=device)
        self._batch_size_fn = batch_size_fn
        self._ignore_zero_division = ignore_zero_division
        # Start from empty accumulators.
        self.reset()

    def reset(self):
        """Clear the running sum and sample count.

        Called at instantiation; presumably also invoked by `BaseMetric` at the started event.
        """
        self._sum = 0
        self._num_accumulations = 0

    def update(self):
        """Accumulate the current value of `self._input_ref.caller_name`.

        Both increments are computed before any state is modified, so a `TypeError` raised by
        `batch_size_fn` after `sum()` succeeded cannot add the same output to `_sum` twice.
        """
        output = self._input_ref.caller_name
        try:
            # Batched (iterable) output: add its element sum and its number of samples.
            batch_sum = sum(output)
            batch_size = self._batch_size_fn(output)
        except TypeError:
            # Non-iterable output (e.g. a scalar): treat it as a single sample.
            batch_sum = output
            batch_size = 1
        self._sum += batch_sum
        self._num_accumulations += batch_size

    def compute(self):
        """
        Return the mean of all accumulated samples, `self._sum / self._num_accumulations`.

        Returns:
            The average, or `None` if no samples were accumulated and `ignore_zero_division` is True
            (a warning is emitted in that case).

        Raises:
            FeatureRuntimeError: if no samples were accumulated and `ignore_zero_division` is False.
        """
        if self._num_accumulations == 0:
            error_msg = ('`state.{}.compute()` cannot be calculated because number of accumulated samples '
                         '`self._num_accumulations == 0`. At least one sample must be accumulated before '
                         '`self.compute()` can be called.'.format(self.name))
            if self._ignore_zero_division:
                warnings.warn(error_msg)
                return None
            raise FeatureRuntimeError(error_msg)
        return self._sum / self._num_accumulations
class OutputMetric(BaseMetric):
    """
    Metric whose computed value is simply the current value of `input_ref.caller_name`.

    No per-iteration accumulation is done: `started_ref` and `iteration_completed_ref` are passed to
    `BaseMetric` as `None`, and `reset()` / `update()` are no-ops.

    Args:
        metric_name: name of the metric, forwarded to `BaseMetric`.
        input_ref: reference whose `caller_name` attribute is returned by `compute()`.
        caller_ref: forwarded to `BaseMetric` as `completed_ref`; an empty string (the default) means
            "use `input_ref`". Presumably this is the event that triggers `compute()` -- confirm in
            `BaseMetric`.
        device: forwarded to `BaseMetric`.
    """

    def __init__(self, metric_name, input_ref, caller_ref='', device=''):
        if caller_ref == '':
            # Default: follow the input reference itself.
            caller_ref = input_ref
        super().__init__(metric_name=metric_name,
                         input_ref=input_ref,
                         started_ref=None,
                         iteration_completed_ref=None,
                         completed_ref=caller_ref,
                         device=device)

    def reset(self):
        """Nothing to reset: this metric keeps no accumulated state."""
        pass

    def update(self):
        """Nothing to update: the value is read directly in `compute()`."""
        pass

    def compute(self):
        """Return the current value exposed by the input reference."""
        current_value = self._input_ref.caller_name
        return current_value
class RunningAverage(BaseMetric):
    """
    Compute the running average of an output, e.g. a metric output or the output of the process function.

    Each `compute()` returns ``self.output * alpha + (1 - alpha) * current_input``, where `self.output` is
    provided by `BaseMetric` (presumably the previously computed value -- confirm). When that expression
    raises `TypeError` (e.g. there is no previous output yet), the current input is returned unchanged.

    Note: passing a metric class as `metric_output_ref` is no longer supported here, because the metric
    arguments were extended to allow customizing the callback variables (events). This feature could still
    easily be implemented with default callback variables if desired.

    Args:
        metric_name: name of the metric, forwarded to `BaseMetric`.
        input_ref: reference whose `caller_name` attribute provides the value to average.
        caller_ref: forwarded to `BaseMetric` as `completed_ref`; an empty string (the default) means
            "use `input_ref`".
        alpha: smoothing factor, must satisfy ``0.0 < alpha <= 1.0``.
        device: forwarded to `BaseMetric`.

    Raises:
        ValueError: if `alpha` is outside ``(0.0, 1.0]``.
    """

    # State parameter
    alpha = StateParameter()

    def __init__(self, metric_name, input_ref, caller_ref='', alpha=0.98, device=''):
        # Argument checking & formatting. `alpha` is validated before `BaseMetric.__init__` runs, so an
        # invalid value fails fast without any base-class setup having been performed.
        if not (0.0 < alpha <= 1.0):
            raise ValueError("Argument alpha should be a float between 0.0 and 1.0.")
        caller_ref = input_ref if caller_ref == '' else caller_ref
        super().__init__(metric_name=metric_name, input_ref=input_ref, started_ref=None,
                         iteration_completed_ref=None, completed_ref=caller_ref, device=device)
        self.alpha = alpha

    def reset(self):
        """No-op: the running value lives in `self.output`, which this class does not reset."""
        pass

    def update(self):
        """No-op: all work is done in `compute()`."""
        pass

    def compute(self):
        """Blend the previous output with the current input using `alpha`."""
        current = self._input_ref.caller_name
        try:
            return self.output * self.alpha + (1.0 - self.alpha) * current
        except TypeError:
            # No usable previous output (presumably `None` before the first computation): start from the
            # current input.
            return current