Shortcuts

Engine

from abc import abstractmethod

from ignite_framework.engines import BaseEngine
from ignite_framework.feature_dev_tools.state_objects import StateBooleanVariable, StateObject, StateVariable, \
    StateIntCounter, StateIterationCounter
from ignite_framework.exceptions import FeatureValueError, FeatureTypeError


class Engine(BaseEngine):
    """
    The default engine: calls ``process`` once per batch of a dataloader, for a number of epochs.

    Progress is published through the class-level state objects below (epoch/iteration counters,
    sample counter, ``output``, ``batch`` and the run/termination flags). Presumably other features
    attach their callbacks to these (TODO confirm against `feature_dev_tools.state_objects`).

    Args:
        name: name of the engine, passed on to `BaseEngine.__init__`.
        process: function taking the engine as its only argument. It is bound to the instance, called once
            per iteration, and its return value is stored in `self.output`.
        dataloader: dataloader registered in `state.dataloaders`, i.e. exposing `.component` and `.batch_size`.
            It may be omitted here and passed to `run` instead.
        engine_run_started_ref: `state object` whose callbacks `Engine().run` is added/attached to.
        n_samples: optional limit on the number of processed samples. The run terminates as soon as the
            processed-sample counter reaches it. `None` means no limit.
        n_epochs: number of epochs to run (default 4).
    """
    engine_run_started = StateBooleanVariable(initial_value=False)
    n_epoch_started = StateIterationCounter()
    n_iteration_started = StateIterationCounter()
    n_samples = StateIntCounter()
    n_iteration_completed = StateIterationCounter()
    n_epoch_completed = StateIterationCounter()
    engine_run_completed = StateBooleanVariable()
    output = StateVariable()
    batch = StateObject()
    terminating_run = StateBooleanVariable()
    terminating_epoch = StateBooleanVariable()

    def __init__(self, name, process, dataloader=None, engine_run_started_ref=None,
                 n_samples=None, n_epochs=4):
        # No automatic or assisted naming for `Engine`: the name is passed through as given
        super().__init__(name)

        # Bind the plain function `process(engine)` to this instance, so that `self.process()` calls it
        # with `engine` -> `self`. The instance attribute shadows the abstract class-level `process`.
        self.process = process.__get__(self, self.__class__)
        self._set_dataloader_and_params(dataloader)
        self.n_total_samples = n_samples
        self.n_total_epochs = n_epochs

        # Attach `self.run` to the callbacks of `engine_run_started_ref` (if given)
        self._set_callbacks(caller_ref=engine_run_started_ref)

    # NOTE(review): if `BaseEngine` uses `ABCMeta`, this abstract method would make `Engine` impossible to
    # instantiate even though `__init__` provides `process`. Presumably it does not; confirm.
    @abstractmethod
    def process(self):
        """
        Process function executed in each iteration step.

        Placeholder only: `__init__` shadows it with the `process` argument bound to the instance.

        Returns:
            The per-iteration result. `run` assigns it to `self.output` (a function without `return`
            therefore sets `self.output` to `None`).
        """

    def run(self, dataloader=None):
        """
        Start the process loop (start the engine).

        `run` is attached to the callbacks of `engine_run_started_ref` during initialization.

        Args:
            dataloader: only required if no dataloader was set during initialization, or if the initial
                dataloader should be replaced.

        Raises:
            FeatureValueError: if no dataloader is set.
        """
        if dataloader is not None:
            self._set_dataloader_and_params(dataloader)
        if self.dataloader is None:
            raise FeatureValueError(
                'The `dataloader` of the engine `{name}` is not set, yet. Please add a dataloader to either the '
                'initialization arguments or to `{name}.run(dataloader)`.'.format(name=self.name))
        # NOTE(review): `terminating_run` is never reset here. After a terminated run, a later `run()`
        # stops after its first iteration. Confirm that this is intended.
        self.engine_run_started = True
        for self.n_epoch_started in range(self.n_total_epochs):
            # The `for`-loop sets `self.n_iteration_started` to the current iteration index, so it restarts
            # at `0` in every epoch
            for self.n_iteration_started, self.batch in enumerate(self.dataloader):
                self.output = self.process()
                # NOTE(review): if the dataloader keeps a smaller final batch, this overcounts `n_samples`
                # by the missing samples. Confirm.
                self.n_samples += self.batch_size
                self.n_iteration_completed += 1
                if self._sample_limit_reached() or self.terminating_run or self.terminating_epoch:
                    break
            self.terminating_epoch = False
            self.n_epoch_completed += 1
            # Terminate the engine run once `self.n_total_samples` samples have been processed
            if self._sample_limit_reached():
                self.terminate_run()
            if self.terminating_run:
                break
        # No `for ... else` here. Every finished epoch is already counted above, so incrementing
        # `n_epoch_completed` again after the last epoch would count it twice.
        self.engine_run_completed = True

    def terminate_run(self):
        """Request termination of the run. `run` stops after the current iteration."""
        self.terminating_run = True

    def terminate_epoch(self):
        """Request termination of the current epoch. `run` moves on to the next epoch after the current iteration."""
        self.terminating_epoch = True

    def _sample_limit_reached(self):
        """Return whether the optional sample limit `n_total_samples` has been reached."""
        return self.n_total_samples is not None and self.n_samples >= self.n_total_samples

    def _set_dataloader_and_params(self, dataloader):
        """
        Set `self.dataloader`, `self.batch_size` and `self.n_dataloader_samples` from `dataloader`.

        If `dataloader` is `None`, all three are reset to `None`. The attributes are only assigned after
        all lookups succeed, so a failure leaves the previous values untouched.

        Raises:
            FeatureTypeError: if `dataloader` lacks the `component`/`batch_size` attributes of a dataloader
                registered in `state.dataloaders`.
        """
        if dataloader is None:
            self.dataloader = None
            self.batch_size = None
            self.n_dataloader_samples = None
            return
        try:
            component = dataloader.component
            # NOTE: 'batch_size' of dataloader cannot be changed once set,
            #       so a `batch_size_ref` would not make sense here
            batch_size = dataloader.batch_size
        except AttributeError as err:
            raise FeatureTypeError(
                'For full integration please attach the dataloader to `state.dataloaders` simply writing '
                'e.g. `state.dataloaders.trainer_dataloader = trainer_dataloader`. You can alternatively use '
                'a standard `with state.dataloaders as d: ...` in case you have more dataloaders.') from err
        self.dataloader = component
        self.batch_size = batch_size
        # NOTE(review): for e.g. a torch `DataLoader`, `len()` is the number of batches, not samples. Confirm.
        self.n_dataloader_samples = len(component)

    def _set_callbacks(self, caller_ref):
        """Attach `self.run` to the callbacks of `caller_ref`, if one is given."""
        # `is not None` instead of `!=`: state-object refs may overload comparison operators
        if caller_ref is not None:
            # Note: setting a bso or ref value to a bso-ctr with a `run`-method equals setting <bso-ctr>.run
            caller_ref.caller_name = self