from ignite_framework.states import state
from ignite_framework.framework.framework import FrameMetric
from ignite_framework.engines import BaseEngine
from ignite_framework.exceptions import FeatureRuntimeError
from ignite_framework.feature_dev_tools.state_objects import StateDevice, StateParameter, StateVariable
from ignite_framework.feature_dev_tools import create_bso_ctr_name_prefix_from_input_ref
from abc import abstractmethod
from functools import wraps
import numbers
import torch
import torch.distributed as dist
import warnings
class BaseMetric(FrameMetric):
#TODO:
# - integrate distributed processing, not tested at all so far!!!
"""
Base class for all Metrics.
Args:
        metric_name (str):
            Name of the metric, which is prefixed with the engine's name separated by an underscore
            (`_`). The engine's name is derived from `input_ref`.
        input_ref (StateObjectsReference):
            Reference of the output variable from which the metric is calculated, normally the output of an
            engine, e.g. `state.trainer` (i.e. `state.engines.trainer`).
        started_ref (StateObjectsReference or str or None, optional):
            reference whose callback list `started()` (and thereby `reset()`) is appended to. The default
            string resolves to the preceding engine's `engine_run_started_ref`; `None` skips the callback.
        iteration_completed_ref (StateObjectsReference or str or None, optional):
            reference whose callback list `iteration_completed()` (and thereby `update()`) is appended to.
            The default string resolves to the preceding engine's `n_iteration_completed_ref`; `None` skips
            the callback.
        completed_ref (StateObjectsReference or str or None, optional):
            reference whose callback list `completed()` (and thereby `compute()`) is appended to. The default
            string resolves to the preceding engine's `n_epoch_completed_ref`; `None` skips the callback.
        device (str or torch.device, optional):
            device specification in case of distributed computation. In most cases it can be defined as
            `"cuda:local_rank"`, or `"cuda"` if `torch.cuda.set_device(local_rank)` has already been called.
            Defaults to `state.configs.default_configs.device_for_metrics_ref`; an empty string falls back
            to the same default.
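
    Example:
        A minimal sketch of a custom metric; the `(y_pred, y)` layout of the engine output and the
        `state.trainer` reference are illustrative assumptions and depend on the actual engine setup:

            class CustomAccuracy(BaseMetric):

                def reset(self):
                    self._num_correct = 0
                    self._num_examples = 0

                def update(self):
                    # Engine output is read through the input reference (see `update()`)
                    y_pred, y = self._input_ref.caller_name
                    self._num_correct += (y_pred.argmax(dim=1) == y).sum().item()
                    self._num_examples += y.shape[0]

                def compute(self):
                    return self._num_correct / self._num_examples

            metric = CustomAccuracy(metric_name='custom_accuracy', input_ref=state.trainer)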
"""
output = StateVariable()
device = StateDevice()
    # `@torch.no_grad()` is used because `input_ref` is passed as an argument and
    # `input_ref._owner_instances['caller_name'].__dict__['output']` refers to the model output, possibly a
    # tensor with gradient tracking enabled, which the decorator disables for the operations in `__init__`.
@torch.no_grad()
def __init__(self, metric_name, input_ref, started_ref='engine_run_started_ref',
iteration_completed_ref='n_iteration_completed_ref', completed_ref='n_epoch_completed_ref',
device=state.configs.default_configs.device_for_metrics_ref):
# Automatic/Assisted naming
name_prefix = create_bso_ctr_name_prefix_from_input_ref(input_ref=input_ref)
name = '_'.join([name_prefix, metric_name])
        # transform_func_name = output_transform
super().__init__(name=name)
        # Argument checking: either the engine holding the default reference arguments, or all reference
        # arguments must be given explicitly
self._input_ref = input_ref
        # Fall back to the default `state.configs.default_configs.device_for_metrics_ref` if `device` is
        # passed as an empty string
self.device = state.configs.default_configs.device_for_metrics_ref if device == '' else device
refs = self._set_argument_or_default_refs(started_ref, iteration_completed_ref, completed_ref)
self._set_callbacks(refs)
#TODO: - INTEGRATE INTO FRAMEWORK
# Check device if distributed is initialized:
if dist.is_available() and dist.is_initialized():
            # Check if the `reset` and `update` methods are decorated; `compute` may remain undecorated
if not (hasattr(self.reset, "_decorated") and hasattr(self.update, "_decorated")):
warnings.warn("{} class does not support distributed setting. Computed result is not collected "
"across all computing devices".format(self.__class__.__name__),
RuntimeWarning)
# if device is None:
# device = "cuda"
# device = torch.device(device)
self._is_reduced = False
self.reset()
@abstractmethod
def reset(self):
"""
        Resets the metric to its initial state.
This is called at the start of each epoch.
"""
pass
@abstractmethod
def update(self):
"""
        Updates the metric's state using the batch output referenced by `self._input_ref.caller_name`,
        i.e. the output of the engine's process function.
        This is called once for each batch.
"""
pass
@abstractmethod
def compute(self):
"""
        Computes the metric based on its accumulated state.
This is called at the end of each epoch.
Returns:
Any: the actual quantity of interest.
Raises:
NotComputableError: raised when the metric cannot be computed.
"""
pass
def _sync_all_reduce(self, tensor):
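        """
        Sum-reduce `tensor` (a `torch.Tensor` or a plain number) across all processes of the default
        distributed group and return the result; numbers come back as numbers, tensors as tensors moved
        to `self.device`. If no distributed group is initialized, `tensor` is returned unchanged.
        """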
if not (dist.is_available() and dist.is_initialized()):
# Nothing to reduce
return tensor
tensor_to_number = False
if isinstance(tensor, numbers.Number):
tensor = torch.tensor(tensor, device=self.device)
tensor_to_number = True
if isinstance(tensor, torch.Tensor):
            # check if the tensor is on the specified device
if tensor.device != self.device:
tensor = tensor.to(self.device)
else:
raise TypeError("Unhandled input type {}".format(type(tensor)))
# synchronize and reduce
dist.barrier()
dist.all_reduce(tensor)
if tensor_to_number:
return tensor.item()
return tensor
def started(self):
self.reset()
# @torch.no_grad()
def iteration_completed(self):
self.update()
# @torch.no_grad()
# def iteration_completed(self):
# output = self._output_transform(self._input_ref.caller_name)
# self.update(output)
def completed(self):
result = self.compute()
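        # Collapse 0-dim tensors to plain Python numbers before publishing the result to `self.output`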
if torch.is_tensor(result) and len(result.shape) == 0:
result = result.item()
self.output = result
def _set_argument_or_default_refs(self, started_ref, iteration_completed_ref, completed_ref):
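        """
        Resolve the `*_ref` constructor arguments: arguments passed as the default sentinel strings
        (e.g. `'engine_run_started_ref'`) are replaced by the corresponding references of the single
        engine found in the metric's preceding dataflow path; all other arguments pass through unchanged.
        """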
preceding_bso_ctrs = list(state.dataflow.get_preceding_paths_of_bso_ctr(self.name)[0])
# Get engine out of preceding bso-ctrs path
engine = [getattr(state, bso_ctr_name) for bso_ctr_name in preceding_bso_ctrs
if isinstance(getattr(state, bso_ctr_name), BaseEngine)]
if len(engine) != 1:
raise FeatureRuntimeError(
'Automatic default reference setting of `{}` failed because {} engines have been detected. Automatic '
'reference setting (of e.g. `started_ref`) requires exactly 1 engine in the preceding dataflow path, '
'but detected: `{}`'.format(self.__class__.__name__, len(engine), engine))
engine = engine[0]
arg_refs = [started_ref, iteration_completed_ref, completed_ref]
default_strings = ['engine_run_started_ref', 'n_iteration_completed_ref', 'n_epoch_completed_ref']
default_refs = [engine.engine_run_started_ref, engine.n_iteration_completed_ref, engine.n_epoch_completed_ref]
return [default_ref if arg_ref == default_str else arg_ref for arg_ref, default_str, default_ref
in zip(arg_refs, default_strings, default_refs)]
def _set_callbacks(self, refs):
"""
        Append the metric's callback methods (`started`, `iteration_completed`, `completed`) to the
        callback lists of the given references; a reference that is `None` skips the corresponding callback.
"""
callback_methods = [self.started, self.iteration_completed, self.completed]
for callback_method, ref in zip(callback_methods, refs):
            if ref is None:
# Do NOT append callback method at all if reference `ref` is `None`
continue
else:
# Append callback method to reference (state variable/parameter)
ref.caller_name = ('`state.{}.{}()`'.format(self.name, callback_method.__name__), callback_method)
# def run(self):
# raise_run_not_implemented_and_appendable_to_callback_error(transition=self)
# # Set the method callers (events) either as
# # ref_caller_names = ['model_output', 'started', 'iteration_completed', 'completed']
# ref_names = ['caller_name', 'started_ref', 'iteration_completed_ref', 'completed_ref']
# arg_refs = [engine.output_ref, started_ref, iteration_completed_ref, completed_ref]
# default_refs = [engine.engine_run_started, engine.n_iteration_completed, engine.engine_run_completed]
# refs = [arg_ref if arg_ref is not None else default_ref for arg_ref, default_ref in zip(arg_refs, default_refs)]
# for ref_name, arg_ref, default_ref in zip(ref_names, arg_refs, default_refs):
# setattr(self, ref_name, arg_ref if arg_ref is not None else default_ref)
# self.started_ref.caller_name = self.started
# self.iteration_completed_ref.caller_name = self.iteration_completed
# self.completed_ref.caller_name = self.iteration_completed
def __add__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name + y.caller_name, self, other)
def __radd__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name + y.caller_name, other, self)
def __sub__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name - y.caller_name, self, other)
def __rsub__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name - y.caller_name, other, self)
def __mul__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name * y.caller_name, self, other)
def __rmul__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name * y.caller_name, other, self)
def __pow__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name ** y.caller_name, self, other)
def __rpow__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name ** y.caller_name, other, self)
def __mod__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name % y.caller_name, self, other)
def __div__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name.__div__(y.caller_name), self, other)
def __rdiv__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name.__div__(y.caller_name), other, self)
def __truediv__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name.__truediv__(y.caller_name), self, other)
def __rtruediv__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name.__truediv__(y.caller_name), other, self)
def __floordiv__(self, other):
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x, y: x.caller_name // y.caller_name, self, other)
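    # The arithmetic dunders above compose metrics lazily via `ignite.metrics.MetricsLambda`, so derived
    # quantities can be written directly in terms of metric instances. A rough sketch, assuming
    # `Precision` and `Recall` subclasses of `BaseMetric` exist in the surrounding framework:
    #
    #     precision = Precision(metric_name='precision', input_ref=state.trainer)
    #     recall = Recall(metric_name='recall', input_ref=state.trainer)
    #     f1 = precision * recall * 2 / (precision + recall)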
    # TODO: - re-integrate `__getattr__`
def __getattr__(self, attr):
        # # HACK: "integrating" the `super().__getattr__` of bso-container for metric's `StateVariable`,
# if any([bso_name in attr for bso_name in [*self._variable_names[cls_name], *type(self.__class__)._parameter_names[self.__class__.__name__],
# *type(self.__class__)._other_bso_names[self.__class__.__name__]]]):
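        # Unknown attributes fall back to a wrapper that builds a lazy `MetricsLambda`, i.e. `attr` is
        # called on the metric's computed value, which allows method-style chaining on metrics.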
try:
return super().__getattr__(attr)
except KeyError:
from ignite.metrics import MetricsLambda
def fn(x, *args, **kwargs):
return getattr(x, attr)(*args, **kwargs)
def wrapper(*args, **kwargs):
return MetricsLambda(fn, self, *args, **kwargs)
return wrapper
def __getitem__(self, index):
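        # Indexing composes lazily as well, e.g. `metric[0]` yields a `MetricsLambda` selecting element 0
        # of the computed value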
from ignite.metrics import MetricsLambda
return MetricsLambda(lambda x: x.caller_name[index], self)
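# `sync_all_reduce(*attrs)` returns a decorator typically applied to `compute()` of distributed-aware
# metrics: before the wrapped call runs, the named attributes are summed across processes once via
# `_sync_all_reduce` (guarded by `self._is_reduced`). A hedged usage sketch, with attribute names chosen
# purely for illustration:
#
#     @sync_all_reduce("_num_correct", "_num_examples")
#     def compute(self):
#         return self._num_correct / self._num_examples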
def sync_all_reduce(*attrs):
def wrapper(func):
@wraps(func)
def another_wrapper(self, *args, **kwargs):
if not isinstance(self, BaseMetric):
                raise RuntimeError("Decorator sync_all_reduce should be used on "
                                   "BaseMetric class methods only")
if len(attrs) > 0 and not self._is_reduced:
for attr in attrs:
t = getattr(self, attr, None)
if t is not None:
t = self._sync_all_reduce(t)
self._is_reduced = True
setattr(self, attr, t)
return func(self, *args, **kwargs)
return another_wrapper
wrapper._decorated = True
return wrapper
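# `reinit__is_reduced` is the counterpart for `reset()` and `update()`: it clears the `_is_reduced` flag
# after the wrapped call so the next reduction runs again, and its `_decorated` marker is what the
# distributed-support check in `__init__` looks for on `reset`/`update`.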
def reinit__is_reduced(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
self._is_reduced = False
wrapper._decorated = True
return wrapper