from abc import abstractmethod
from ignite_framework.engines import BaseEngine
from ignite_framework.feature_dev_tools.state_objects import StateBooleanVariable, StateObject, StateVariable, \
StateIntCounter, StateIterationCounter
from ignite_framework.exceptions import FeatureValueError, FeatureTypeError
class Engine(BaseEngine):
    """
    The default engine.

    Each call of `run` iterates `n_epochs` times over the dataloader. For every batch, the bound `process` function is
    called, and the state objects below (counters, `batch`, `output`, termination flags) are updated.

    Args:
        name: name of the engine, passed on to `BaseEngine.__init__`.
        process: function with signature `process(engine)`. It is bound to the instance as `self.process` and called
            once per batch; its return value is stored in `self.output`.
        dataloader: dataloader wrapper exposing `component` (the iterable of batches, which must support `len`) and
            `batch_size`, e.g. a dataloader attached to `state.dataloaders`. Optional here if it is passed to `run`.
        engine_run_started_ref: `state object` reference whose callbacks `Engine().run` will be added/attached to.
        n_samples: optional sample budget. The run is terminated once at least this many samples (counted as
            `batch_size` per iteration) have been processed. `None` (default) means no budget.
        n_epochs: number of passes over the dataloader (default: 4).
    """
    # Lifecycle flags and counters updated by `run`
    engine_run_started = StateBooleanVariable(initial_value=False)
    n_epoch_started = StateIterationCounter()
    n_iteration_started = StateIterationCounter()
    n_samples = StateIntCounter()
    n_iteration_completed = StateIterationCounter()
    n_epoch_completed = StateIterationCounter()
    engine_run_completed = StateBooleanVariable()
    # Result of the current iteration and the batch it was computed from
    output = StateVariable()
    batch = StateObject()
    # Termination requests, set by `terminate_run` / `terminate_epoch`
    terminating_run = StateBooleanVariable()
    terminating_epoch = StateBooleanVariable()

    def __init__(self, name, process, dataloader=None, engine_run_started_ref=None,
                 n_samples=None, n_epochs=4):
        # No automatic or assisted naming for `Engine`
        # Initialize base class
        super().__init__(name)
        # Set instance attributes
        # Bind arg `process(engine)` as a bound method so it becomes `self.process()` with argument `engine` -> `self`
        self.process = process.__get__(self, self.__class__)
        self._set_dataloader_and_params(dataloader)
        self.n_total_samples = n_samples
        self.n_total_epochs = n_epochs
        # Set callback
        self._set_callbacks(caller_ref=engine_run_started_ref)

    @abstractmethod
    def process(self):
        """
        Process function executed in each iteration step.

        `self.process` is overwritten during initialization by the `process` argument, bound to this instance.

        Returns:
            the result of the iteration. `run` assigns it to `self.output`, so a value written to `self.output`
            inside `process` is overwritten by the return value.
        """

    def run(self, dataloader=None):
        """
        Start the process loop (start the engine). `run` is added to the `engine_run_started_ref` callbacks during
        initialization.

        Args:
            dataloader: Only required if `dataloader` was not set during initialization or the initial dataloader
                should be overwritten.

        Raises:
            FeatureValueError: if no dataloader is set, neither at initialization nor via this argument.
            FeatureTypeError: if `dataloader` lacks the `component`/`batch_size` attributes of an attached
                dataloader.
        """
        # Compare against `None` explicitly: a dataloader wrapper defining `__len__` would otherwise be falsy when
        # empty and silently ignored.
        if dataloader is not None:
            self._set_dataloader_and_params(dataloader)
        if self.dataloader is None:
            raise FeatureValueError(
                'The `dataloader` of the engine `{name}` is not set, yet. Please add a dataloader to either the '
                'initialization arguments or to `{name}.run(dataloader)`.'.format(name=self.name))

        # NOTE(review): nothing visible here resets `terminating_run` to `False`, so a second `run` after
        # `terminate_run()` stops after its first iteration - confirm whether re-running an engine is supported.
        self.engine_run_started = True

        for self.n_epoch_started in range(self.n_total_epochs):
            # `self.n_iteration_started` is set by the `for` loop to the index of the current iteration, so it
            # restarts at `0` every epoch
            for self.n_iteration_started, self.batch in enumerate(self.dataloader):
                self.output = self.process()
                # NOTE(review): adds a full `batch_size` even if the last batch is smaller - confirm intended
                self.n_samples += self.batch_size
                self.n_iteration_completed += 1
                if self._sample_budget_exhausted() or self.terminating_run or self.terminating_epoch:
                    break

            # An epoch termination request only applies to the current epoch
            self.terminating_epoch = False
            # Counted exactly once per epoch, also when the epoch was cut short by a termination request or by the
            # sample budget (the former extra `else: n_epoch_completed += 1` double-counted epochs)
            self.n_epoch_completed += 1

            # Terminate the engine run once `self.n_total_samples` samples were processed
            if self._sample_budget_exhausted():
                self.terminate_run()
            if self.terminating_run:
                break

        self.engine_run_completed = True

    def terminate_run(self):
        """Request termination of the whole run; `run` stops after the current iteration."""
        self.terminating_run = True

    def terminate_epoch(self):
        """Request termination of the current epoch; `run` continues with the next epoch."""
        self.terminating_epoch = True

    def _sample_budget_exhausted(self):
        """
        Return `True` if a sample budget (`self.n_total_samples`) is set and at least that many samples were
        processed. A budget of `0` is honored (it is compared with `is not None`, not by truthiness).
        """
        return self.n_total_samples is not None and self.n_samples >= self.n_total_samples

    def _set_dataloader_and_params(self, dataloader):
        """
        Store `dataloader` and its derived parameters on the engine, or reset them all to `None`.

        Args:
            dataloader: wrapper exposing `component` (the actual iterable, which must support `len`) and
                `batch_size`, or `None` to clear the dataloader.

        Raises:
            FeatureTypeError: if `dataloader` lacks the `component` or `batch_size` attribute. In that case the
                previously set dataloader parameters are left unchanged.
        """
        if dataloader is None:
            self.dataloader = None
            self.batch_size = None
            self.n_dataloader_samples = None
            return

        # Keep the `try` minimal so only a missing wrapper attribute is reported as a wrong dataloader type
        try:
            component = dataloader.component
            # NOTE: 'batch_size' of dataloader cannot be changed once set,
            # so a `batch_size_ref` would not make sense here
            batch_size = dataloader.batch_size
        except AttributeError as err:
            raise FeatureTypeError(
                'For full integration please attach the dataloader to `state.dataloaders` simply writing '
                'e.g. `state.dataloaders.trainer_dataloader = trainer_dataloader`. You can alternatively use '
                'a standard `with state.dataloaders as d: ...` in case you have more dataloaders.') from err

        # Compute `len` before assigning so a failure leaves the engine's dataloader parameters consistent.
        # NOTE(review): for batch-yielding loaders `len` is presumably the number of batches, not samples - confirm.
        n_dataloader_samples = len(component)
        self.dataloader = component
        self.batch_size = batch_size
        self.n_dataloader_samples = n_dataloader_samples

    def _set_callbacks(self, caller_ref):
        """
        Attach this engine to `caller_ref`, if one was given.

        Args:
            caller_ref: state object reference (`engine_run_started_ref` in `__init__`) or `None`.
        """
        if caller_ref is not None:
            # Note: setting a bso or ref value to a bso-ctr with a `run`-method equals setting <bso-ctr>.run
            caller_ref.caller_name = self