# Shortcuts

# BaseMetric

from ignite_framework.states import state

from ignite_framework.framework.framework import FrameMetric

from ignite_framework.engines import BaseEngine

from ignite_framework.exceptions import FeatureRuntimeError
from ignite_framework.feature_dev_tools.state_objects import StateDevice, StateParameter, StateVariable
from ignite_framework.feature_dev_tools import create_bso_ctr_name_prefix_from_input_ref


from abc import abstractmethod
from functools import wraps
import numbers
import torch
import torch.distributed as dist
import warnings


class BaseMetric(FrameMetric):
    """
    Base class for all metrics.

    On construction the metric registers its `started()`, `iteration_completed()` and `completed()` methods as
    callbacks of `started_ref`, `iteration_completed_ref` and `completed_ref` respectively. These call `reset()`,
    `update()` and `compute()`; the computed result is stored in `self.output`.

    Args:
        metric_name (str):
            name of the metric. It is prefixed with a name derived from `input_ref` (see
            `create_bso_ctr_name_prefix_from_input_ref`), separated by `_`.
        input_ref (StateObjectsReference):
            reference of the output variable from which the metric is calculated, normally the output of an
            engine, e.g. `state.trainer` (i.e. `state.engines.trainer`).
        started_ref, iteration_completed_ref, completed_ref (StateObjectsReference or str or None, optional):
            references to which the callbacks are attached. When left at their default strings
            (`'engine_run_started_ref'`, `'n_iteration_completed_ref'`, `'n_epoch_completed_ref'`), the
            attribute of the same name of the single engine found in the preceding dataflow path is used.
            When `None`, the corresponding callback is not registered at all.
        device (str or torch.device, optional): device used for the distributed reduction in
            `_sync_all_reduce`. If `''` or `None`, `state.configs.default_configs.device_for_metrics_ref` is used.
    """
    # TODO:
    # - integrate distributed processing, not tested at all so far!!!

    output = StateVariable()
    device = StateDevice()

    # NOTE(review): `torch.no_grad()` only disables gradient tracking while `__init__` runs. It was originally
    # added with respect to model outputs reachable through `input_ref`; confirm whether it is still needed.
    @torch.no_grad()
    def __init__(self, metric_name, input_ref, started_ref='engine_run_started_ref',
                 iteration_completed_ref='n_iteration_completed_ref', completed_ref='n_epoch_completed_ref',
                 device=state.configs.default_configs.device_for_metrics_ref):

        # Automatic/assisted naming: `<prefix derived from input_ref>_<metric_name>`
        name_prefix = create_bso_ctr_name_prefix_from_input_ref(input_ref=input_ref)
        name = '_'.join([name_prefix, metric_name])
        super().__init__(name=name)

        self._input_ref = input_ref
        # Fall back to the configured default metric device when no device is given (`''` or `None`).
        if device is None or device == '':
            device = state.configs.default_configs.device_for_metrics_ref
        self.device = device

        # Either an engine in the preceding dataflow path supplies the default references, or all references
        # must be given explicitly.
        refs = self._set_argument_or_default_refs(started_ref, iteration_completed_ref, completed_ref)
        self._set_callbacks(refs)

        # TODO: - INTEGRATE INTO FRAMEWORK
        if dist.is_available() and dist.is_initialized():
            # `reset` and `update` must carry the `_decorated` marker (e.g. via `reinit__is_reduced`) for results
            # to be reduced across processes; `compute` does not need to be decorated.
            if not (hasattr(self.reset, "_decorated") and hasattr(self.update, "_decorated")):
                warnings.warn("{} class does not support distributed setting. Computed result is not collected "
                              "across all computing devices".format(self.__class__.__name__),
                              RuntimeWarning)

        self._is_reduced = False
        self.reset()

    @abstractmethod
    def reset(self):
        """
        Reset the metric to its initial state.

        Called by `started()` (registered on `started_ref`, by default the engine's run start) and once at the
        end of `__init__`.
        """
        pass

    @abstractmethod
    def update(self):
        """
        Update the metric's state using the current batch output given by `self._input_ref.caller_name`.

        Called by `iteration_completed()` (registered on `iteration_completed_ref`, by default once per
        completed iteration). Takes no arguments.
        """
        pass

    @abstractmethod
    def compute(self):
        """
        Compute the metric based on its accumulated state.

        Called by `completed()` (registered on `completed_ref`, by default once per completed epoch).

        Returns:
            Any: the actual quantity of interest.
        Raises:
            NotComputableError: raised when the metric cannot be computed.
        """
        pass

    def _sync_all_reduce(self, tensor):
        """
        Reduce `tensor` across all processes of the initialized distributed process group.

        If torch.distributed is unavailable or not initialized, `tensor` is returned unchanged. A Python number is
        converted to a tensor on `self.device` for the reduction and converted back with `.item()`.

        Raises:
            TypeError: if `tensor` is neither a `numbers.Number` nor a `torch.Tensor`.
        """
        if not (dist.is_available() and dist.is_initialized()):
            # Nothing to reduce
            return tensor

        is_number = isinstance(tensor, numbers.Number)
        if is_number:
            tensor = torch.tensor(tensor, device=self.device)
        elif not isinstance(tensor, torch.Tensor):
            raise TypeError("Unhandled input type {}".format(type(tensor)))

        # `Tensor.to` returns the tensor itself when it already lives on the target device, so no explicit
        # device comparison is needed.
        tensor = tensor.to(self.device)

        # synchronize and reduce (in place; `all_reduce` uses its default reduce op)
        dist.barrier()
        dist.all_reduce(tensor)

        return tensor.item() if is_number else tensor

    def started(self):
        """Callback for `started_ref`: reset the accumulated metric state."""
        self.reset()

    def iteration_completed(self):
        """Callback for `iteration_completed_ref`: update the metric state via `update()`."""
        self.update()

    def completed(self):
        """
        Callback for `completed_ref`: compute the metric and store the result in `self.output`.

        0-dim tensor results are converted to Python scalars first.
        """
        result = self.compute()
        if torch.is_tensor(result) and result.dim() == 0:
            result = result.item()
        self.output = result

    def _set_argument_or_default_refs(self, started_ref, iteration_completed_ref, completed_ref):
        """
        Resolve the callback references.

        Each argument still equal to its default string is replaced by the engine attribute of that name of
        the single engine preceding this metric in the dataflow. All other arguments are returned unchanged.
        The engine is only looked up when at least one default is needed.

        Returns:
            list: `[started_ref, iteration_completed_ref, completed_ref]` after resolution.

        Raises:
            FeatureRuntimeError: if a default is needed but not exactly one preceding engine is found.
        """
        arg_refs = [started_ref, iteration_completed_ref, completed_ref]
        # The default strings coincide with the names of the engine attributes holding the default references.
        default_strings = ['engine_run_started_ref', 'n_iteration_completed_ref', 'n_epoch_completed_ref']
        use_default = [arg_ref == default_str for arg_ref, default_str in zip(arg_refs, default_strings)]
        if not any(use_default):
            # All references were given explicitly, so no engine is required to provide defaults.
            return arg_refs
        engine = self._get_preceding_engine()
        return [getattr(engine, default_str) if is_default else arg_ref
                for arg_ref, default_str, is_default in zip(arg_refs, default_strings, use_default)]

    def _get_preceding_engine(self):
        """
        Return the single `BaseEngine` among the bso-containers of this metric's first preceding dataflow path.

        Raises:
            FeatureRuntimeError: if not exactly one engine is found.
        """
        preceding_bso_ctrs = state.dataflow.get_preceding_paths_of_bso_ctr(self.name)[0]
        engine = []
        for bso_ctr_name in preceding_bso_ctrs:
            bso_ctr = getattr(state, bso_ctr_name)
            if isinstance(bso_ctr, BaseEngine):
                engine.append(bso_ctr)
        if len(engine) != 1:
            raise FeatureRuntimeError(
                'Automatic default reference setting of `{}` failed because {} engines have been detected. Automatic '
                'reference setting (of e.g. `started_ref`) requires exactly 1 engine in the preceding dataflow path, '
                'but detected: `{}`'.format(self.__class__.__name__, len(engine), engine))
        return engine[0]

    def _set_callbacks(self, refs):
        """
        Register `started`, `iteration_completed` and `completed` as callbacks of the given references.

        For each non-`None` reference, a `(description, bound_method)` tuple is assigned to `ref.caller_name`.
        Presumably the reference's `caller_name` setter appends it to its callback list (confirm in
        `state_objects`). `None` references are skipped.

        Args:
            refs (list): `[started_ref, iteration_completed_ref, completed_ref]`, entries may be `None`.
        """
        callback_methods = [self.started, self.iteration_completed, self.completed]
        for callback_method, ref in zip(callback_methods, refs):
            if ref is None:
                # Do NOT register the callback method at all if the reference is `None`
                continue
            ref.caller_name = ('`state.{}.{}()`'.format(self.name, callback_method.__name__), callback_method)

    # Arithmetic on metrics builds `MetricsLambda` objects.
    # NOTE(review): these helpers import `MetricsLambda` from the original `ignite` package (not
    # `ignite_framework`), and their lambdas read `.caller_name` from both operands, which presumably fails when
    # one operand is a plain number (reflected variants). Confirm before relying on them. `__div__`/`__rdiv__`
    # are Python 2 hooks and are never invoked by Python 3.
    def __add__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name + y.caller_name, self, other)

    def __radd__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name + y.caller_name, other, self)

    def __sub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name - y.caller_name, self, other)

    def __rsub__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name - y.caller_name, other, self)

    def __mul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name * y.caller_name, self, other)

    def __rmul__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name * y.caller_name, other, self)

    def __pow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name ** y.caller_name, self, other)

    def __rpow__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name ** y.caller_name, other, self)

    def __mod__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name % y.caller_name, self, other)

    def __div__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name.__div__(y.caller_name), self, other)

    def __rdiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name.__div__(y.caller_name), other, self)

    def __truediv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name.__truediv__(y.caller_name), self, other)

    def __rtruediv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name.__truediv__(y.caller_name), other, self)

    def __floordiv__(self, other):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x, y: x.caller_name // y.caller_name, self, other)

    # TODO: - re-integrate `__getattr__`
    def __getattr__(self, attr):
        """
        Delegate unknown attribute lookups to the base class; if it raises `KeyError`, return a factory that
        wraps a call of method `attr` on the metric value into a `MetricsLambda`.

        NOTE(review): only `KeyError` from the base lookup triggers the fallback; an `AttributeError`
        propagates unchanged. Confirm this matches the base class's `__getattr__` contract.
        """
        try:
            return super().__getattr__(attr)
        except KeyError:
            from ignite.metrics import MetricsLambda

            def fn(x, *args, **kwargs):
                return getattr(x, attr)(*args, **kwargs)

            def wrapper(*args, **kwargs):
                return MetricsLambda(fn, self, *args, **kwargs)

            return wrapper

    def __getitem__(self, index):
        from ignite.metrics import MetricsLambda
        return MetricsLambda(lambda x: x.caller_name[index], self)

def sync_all_reduce(*attrs):
    """
    Decorator factory for metric methods (typically `compute`) that reduces the given attributes across processes
    before the method body runs.

    If `attrs` is non-empty and `self._is_reduced` is False, each attribute named in `attrs` whose value is not
    `None` is passed through `self._sync_all_reduce` and written back. Any such reduction sets
    `self._is_reduced` to True, so repeated calls do not reduce the same values twice.

    Args:
        *attrs (str): names of the instance attributes to reduce.

    Returns:
        callable: a decorator for methods of `BaseMetric` subclasses. The decorated method carries the
        `_decorated` marker.

    Raises:
        RuntimeError: when the decorated method is called on an object that is not a `BaseMetric`.
    """
    def wrapper(func):
        @wraps(func)
        def another_wrapper(self, *args, **kwargs):
            if not isinstance(self, BaseMetric):
                raise RuntimeError("Decorator sync_all_reduce should be used on "
                                   "ignite.metric.Metric class methods only")

            if attrs and not self._is_reduced:
                for attr in attrs:
                    t = getattr(self, attr, None)
                    if t is not None:
                        t = self._sync_all_reduce(t)
                        self._is_reduced = True
                        setattr(self, attr, t)

            return func(self, *args, **kwargs)

        # Mark the decorated method itself, like `reinit__is_reduced` does, so that
        # `hasattr(method, "_decorated")` detects it.
        another_wrapper._decorated = True
        return another_wrapper

    # Kept for backward compatibility: the factory has always carried the marker as well.
    wrapper._decorated = True
    return wrapper

def reinit__is_reduced(func):
    """
    Decorator for metric methods that change the accumulated state (typically `reset` and `update`).

    After the wrapped method has run, `self._is_reduced` is set back to False. The next method decorated with
    `sync_all_reduce` then reduces the newly accumulated values again. The returned wrapper carries the
    `_decorated` marker that `BaseMetric.__init__` looks for when a distributed process group is initialized.

    Args:
        func (callable): the method to wrap.

    Returns:
        callable: the wrapped method, which returns whatever `func` returns.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Propagate the wrapped method's return value instead of silently discarding it.
        result = func(self, *args, **kwargs)
        self._is_reduced = False
        return result

    wrapper._decorated = True
    return wrapper