# First step: Always import the `state`, either the default `state` or any desired customized `state`
from ignite_framework.states import state
from ignite_framework.engines import Engine
from ignite_framework.feature_dev_tools.state_objects import StateObjectsReference
import torch
# ==========================
# STATE
# ==========================
## INSTANCIATE AN `ENGINE` ##
# Define the `Engine` process method
def average_batch(engine):
    """Demo `Engine` process function: divide the current batch by its size.

    NOTE(review): assumes `engine.batch` supports elementwise division
    (e.g. a tensor); for a plain list this would raise — confirm with the
    dataloader's collate behavior.
    """
    current = engine.batch
    return current / len(current)
## Automatic categorization of frame state object containers ##
# e.g. `Engine` in `state.engines`, `OutputHandler` in
# `state.output_handlers`, `Chart` in `state.tensorboard` etc.
# Initialize the `Engine`...
Engine(name='trainer', process=average_batch)
# ... which will be automatically attached to `state` in the
# correct state container `state.engines`
print(state.engines.trainer)
## Shortcut functionality of `state` ##
# `state` automatically detects unique object names and creates
# shortcuts in `state` and in the owning state container, e.g. for
# `trainer` and its bsos (frame state objects)
# Full calling expression of the frame state object (bso) `n_epoch_started`
print(state.engines.trainer.n_epoch_started)
# Shortcut
print(state.trainer.n_epoch_started)
# Shorter shortcut
print(state.n_epoch_started)
# Both expressions resolve to the very same object (identity check)
print(id(state.engines.trainer.n_epoch_started) == id(state.n_epoch_started))
# ========================================
# BASE STATE OBJECT FEATURES
# ========================================
## CALLBACKS (EVENT SYSTEM) ##
# Define a callable
def print_triggered():
    """Demo callback: announce that `state.n_epoch_started` fired."""
    message = '`state.n_epoch_started` was triggered!'
    print(message)
# Attach the callable to a bso: the setter overload detects that the
# assigned value is callable and registers it as a callback
state.n_epoch_started = print_triggered
# Define and attach a callback with a decorator (appends to the
# bso's `CallbackList`)
@state.n_epoch_started_callbacks.append
def print_another_callback():
    print('`state.n_epoch_started` has another callback now!')
# Set the state object value, which triggers its callbacks
state.n_epoch_started += 1
# Value types accepted as callbacks (description/func/args/kwargs overloads)
state.n_epoch_started = ('manual description', print_triggered)
state.n_epoch_started = (print_triggered, [], {})
state.n_epoch_started = (print_triggered, {})
state.n_epoch_started = ('max. callback with description, func, args, kwargs', print_triggered, [], {})
# `CallbackList` is a subclass of `list`, so regular list methods work
state.n_epoch_started_callbacks.clear()
def print_triggered_every_2():
    """Demo callback for the `_every_2` suffix of `n_epoch_completed`."""
    text = '`state.n_epoch_completed_every_2` was triggered!'
    print(text)
## ONCE / EVERY SUFFIXES ##
# `_every_2` suffix-overload: the assigned function is a callback that
# (presumably) fires on every 2nd value change — confirm in framework docs
state.n_epoch_completed_every_2 = print_triggered_every_2
# Alternative: `_once` suffix with a value
state.n_epoch_started_once = 42
# Drive the demo: each increment may trigger the suffixed callbacks
for n in range(10):
    print(n)
    state.n_epoch_completed += 1
## REFERENCE ##
# References point at state objects owned by other state object containers
ref = state.n_epoch_started_ref
# Equivalent: `state.n_epoch_started`
print(ref.caller_name)
# Equivalent: `state.n_epoch_started_callbacks`
print(ref.caller_name_callbacks)
# Equivalent: `state.n_epoch_started = ('attached with ref', print_triggered)`
ref.caller_name = ('attached with ref', print_triggered)
# Equivalent: `state.n_epoch_started += 1`
ref.caller_name += 1
# Current syntax limitation: ref does not take `_once`/`_every` suffixes
# NOTE: This can be fixed but 3 framework modules have to be merged together
ref.caller_name_once = 9
print(state.transitions.get_bso_ctr_names())
# Reset the state object to its initial value
del state.n_epoch_completed
## BSO SYNCHRONIZATION ##
# Synchronize parameters: only in one direction to avoid recursion errors
state.n_epoch_started = state.n_epoch_completed_ref
# Change the bso 'in sync-direction'
state.n_epoch_completed += 1
# Results in changing the value of `state.n_epoch_started` and triggering its callbacks
print(state.n_epoch_started)
# But does not sync in the other direction
state.n_epoch_started += 1
print(state.n_epoch_completed)
### PARAMETRIZED STATE OBJECT GETTER/SETTER ###
n_every = 2
n_once = 55
# Renamed from `filter` — never shadow the builtin `filter`
suffix_filter = 'every'
state_object_name = 'n_epoch_started'
overload_feature = 'callbacks'
# Regular use case
state.trainer.get('n_epoch_started_every', n_every)
# No limits to the number of parameters
# NOTE:
# - `.set()` understands the last argument as the value and all preceding
#   arguments as state object name parts to be joined into one string with `'_'`
# - The call below creates a new transition `state.transitions.trainer_n_epoch_started_once_55`
#   and appends the function `print_another_callback`
#   to `state.trainer_n_epoch_started_once_55.n_conditional_counts_callbacks`
state.trainer.set('n_epoch_started_once', n_once, print_another_callback)
# Print the callbacks list with `.get()` from `state`, `state.engines` and `state.engines.trainer`
print(state.get('engines').get('trainer').get(state_object_name, suffix_filter, n_every, overload_feature))
# Append another function to a callback list using a parametrized state object name
state.trainer.set(state_object_name, suffix_filter, n_once, print_another_callback)
# Print the callbacks (here without parameters, but any form works)
print(state.trainer.n_epoch_started_every_55_callbacks)
# =====================
# STATE FEATURES
# =====================
## SHORTCUTS ##
# Automatic deletion of shortcuts when an object name becomes ambiguous,
# e.g. after adding a new feature that owns a bso with the same name
print(state.n_iteration_completed)
Engine(name='another_engine', process=average_batch)
try:
    print(state.n_iteration_completed)
except AttributeError:
    # BUGFIX: the message previously misspelled the shortcut as `n_iteartion_completed`
    print('`state` shortcut `state.n_iteration_completed` was automatically deleted due to ambiguity.')
# But unique shortcuts remain
print(state.trainer.n_iteration_completed)
print(state.another_engine.n_iteration_completed)
## ORGANIZATIONAL ASSISTANCE ##
# Values assigned inside this context are collected under `state.configs`;
# decorated functions are registered as helpers
with state.configs as c:
    c.hardware_config_value = 123
    c.n_gpus = 199
    @c
    def check_all_configs_were_set():
        # Print every user-defined config value ...
        for config_name in state.user_defined_configs.get_bso_names():
            print('state.configs.user_defined_configs.{} = {}'.format(config_name, state.get(config_name)))
        # ... and every registered helper function
        for helper_func_name in state.configs.helpers.get_bso_names():
            print('state.configs.helper.{} is added.'.format(helper_func_name))
# Configs can also be assigned directly, outside the context manager
state.configs.extra_path = 'this/is/a/extra/path'
state.params.more_user_params = 199
state.configs.user_defined_configs.directly_assigned = 51
# Call the helper registered above via its `state` shortcut
state.check_all_configs_were_set()
# All indented values will be attached to `state.engines.hyperparams` as
# `StateParameter` (if not defined explicitly as a different state object)
with state.engines as e:
    # Any non-state-object will be assigned as `StateParameters(initial_value=value)`
    e.evaluator_batch_size = 20
    e.engine_param = 42
    # Decorated functions are registered on `state.engines`
    @e
    def max_process(engine):
        # Process function: largest element of the current batch, as float
        return float(max(engine.batch))
    @e
    def double_value(value):
        # Return twice the value; fall back to the value itself on TypeError
        try:
            return 2 * value
        except TypeError:
            return value
    @e
    def quadro_value(value):
        # Return four times the value; fall back to the value itself on TypeError
        try:
            return 4 * value
        except TypeError:
            return value
    @e
    def half_value(value):
        # Return half the value; fall back to the value itself on TypeError
        try:
            return value / 2
        except TypeError:
            return value
## AUTOMATIC CATEGORIZATION ##
from ignite_framework.output_handlers import OutputHandler, ScalarsChartOutputHandler
from ignite_framework.metrics import AverageOutput, OutputMetric, RunningAverage
# Evaluator engine: its run start is referenced to every 100th trainer iteration
Engine(name='evaluator',
       process=state.max_process,
       engine_run_started_ref=state.trainer.n_iteration_completed_every_100_ref)
# Transform the trainer output with the `double_value` helper registered above
OutputHandler(input_refs=state.trainer.output_ref,
              transform_input_func=double_value)
# Note: `state.get(name)` is just a nicer call for `getattr(state, name)`, also works for state containers
AverageOutput(metric_name='reg_average_loss', input_ref=state.trainer_double_value.output_ref,
              started_ref=state.trainer.n_iteration_completed_every_100_ref,
              completed_ref=state.trainer.n_iteration_completed_every_100_ref)
# Chain further handlers/metrics off the auto-named upstream outputs
OutputHandler(input_refs=state.trainer_double_value_reg_average_loss.output_ref,
              transform_input_func=half_value)
ScalarsChartOutputHandler(input_refs=state.trainer_double_value_reg_average_loss.output_ref)
RunningAverage(metric_name='double_lc_running_loss',
               input_ref=state.trainer_double_value.output_ref)
OutputMetric(metric_name='reg_normal_loss',
             input_ref=state.trainer_double_value.output_ref)
## EXTERNAL CLASS INTEGRATION ##
# Instantiate a (random sample) dataloader
from torch.utils.data.dataloader import DataLoader
from random import randint
# Single-component assignment attaches directly to the container
state.dataloaders.trainer_dataloader = \
    DataLoader(dataset=[randint(0, 999) for i in range(10000)],
               batch_size=10)
# For adding multiple components, or bsos, use the context manager
with state.dataloaders as d:
    d.evaluator_dataloader = \
        DataLoader(dataset=[randint(0, 99) for i in range(10)],
                   batch_size=10)
# `.component` exposes the wrapped raw `DataLoader`; the shortcut forwards its attributes
print(state.dataloaders.trainer_dataloader.component.dataset == state.trainer_dataloader.dataset)
# Assigning a non-dataloader component to `state.dataloaders` raises (demo of type checking)
try:
    state.dataloaders.model = torch.nn.Module()
except Exception as e:
    print(e)
# Alternatively for single components
state.modules.model = torch.nn.Module()
# NEW CLASS INTEGRATION
from ignite_framework.feature_dev_tools.state_objects import StateParameter
from ignite_framework.feature_dev_tools.utils import get_component_attr, set_component_attr
# Define a new subclass of `Module`
class UserModule(torch.nn.Module):
    """Minimal `torch.nn.Module` subclass used to demo class integration."""

    # Demo attribute integrated into the state below (assumed class-level;
    # the original listing's indentation is ambiguous — TODO confirm)
    user_module_param = 24

    def __init__(self):
        # BUGFIX: call the base initializer instead of hand-assigning
        # `_modules`/`_parameters`/`_buffers`. `nn.Module.__init__` sets those
        # dicts plus the remaining internals (hooks, `training`, ...) that
        # other `nn.Module` methods rely on.
        super().__init__()
# Define which attribute should additionally be integrated as a state object:
# {attr_name: (state object class, value (presumably initial — confirm), getter, setter)}
integrated_attrs = {'user_module_param': (StateParameter, 11, get_component_attr, set_component_attr)}
# NOTE: `get_component_attr` and `set_component_attr` are the default getter/setter for class integration, so they
# can be substituted by `''`:
# `integrated_attrs = {'user_module_param': (StateParameter, 11, '', '')}`
# Register the new component class with the `state.modules` container
state.modules.integrate_new_state_component_class(component_class=UserModule,
                                                  integrated_attrs_args_dict=integrated_attrs)
# Assign an instance of the newly integrated subclass
state.modules.a_user_module = UserModule()
## NAME ASSISTANCE & AUTOMATIC NAMING ##
# The assisted/automatically generated name (`.name`) is captured in a
# variable and used to define the input of the successive feature
t_oh = OutputHandler(input_refs=state.trainer.output_ref,
                     transform_input_func=quadro_value).name
t_oh_m = RunningAverage(metric_name='quadro_lc_running_loss',
                        input_ref=state.get(t_oh).output_ref).name
ScalarsChartOutputHandler(input_refs=state.get(t_oh_m).output_ref,
                          transform_input_func=double_value).name
print('breakpoint')