# • Docs >
#   • quickunderstanding_app.py
# Shortcuts
#
# quickunderstanding_app.py

# First step: Always import the `state`, either the default `state` or any desired customized `state`
from ignite_framework.states import state

from ignite_framework.engines import Engine
from ignite_framework.feature_dev_tools.state_objects import StateObjectsReference
import torch


# ==========================
# STATE
# ==========================

## INSTANTIATE AN `ENGINE` ##

# Define the `Engine` process method
def average_batch(engine):
    """Engine process function: return the mean of the current batch.

    `engine.batch` may be any sized iterable of numbers, e.g. a list or (presumably,
    via the `DataLoader` defined further below) a 1-D tensor.

    Summing before dividing yields a single scalar (a 0-d tensor for tensor batches).
    Dividing the batch itself by its length would only rescale every element, which
    is not an average.

    Raises:
        ZeroDivisionError: if the batch is empty.
    """
    return sum(engine.batch) / len(engine.batch)

## Automatic categorization into state object containers ##

# e.g. `Engine` in `state.engines`, `OutputHandler` in
# `state.output_handlers`, `Chart` in `state.tensorboard` etc.

# Initialize the `Engine`; the instance does not need to be stored...
Engine(name='trainer', process=average_batch)
# ... because it is automatically attached to `state` in the
# correct state container `state.engines`
print(state.engines.trainer)


## Shortcut functionality of `state` ##

# `state` automatically detects unique object names and creates shortcuts
# in `state` and in the owning state container, e.g. for `trainer` and
# its bsos (base state objects)

# Full calling expression of the base state object (bso) `n_epoch_started`
print(state.engines.trainer.n_epoch_started)
# Shortcut via `state.trainer`
print(state.trainer.n_epoch_started)
# Shorter shortcut directly on `state`
print(state.n_epoch_started)
# Check that the shortcut returns the very same object as the full expression.
# `is` keeps both operands alive during the comparison. Comparing the `id()`s of
# two temporaries can report True even for distinct objects, because objects with
# non-overlapping lifetimes may share the same id.
print(state.engines.trainer.n_epoch_started is state.n_epoch_started)


# ========================================
# BASE STATE OBJECT FEATURES
# ========================================

## CALLBACKS (EVENT SYSTEM) ##

# Define a callable
def print_triggered():
    """Callback used throughout the callback demos: report that `n_epoch_started` fired."""
    notice = '`state.n_epoch_started` was triggered!'
    print(notice)

# Attach a callable to a bso: the overloaded setter detects the type of the
# assigned value, so a callable is attached as a callback rather than set as the value
state.n_epoch_started = print_triggered

# Define and attach with decorator
# NOTE(review): the module-level name `print_another_callback` is bound to whatever
# `CallbackList.append` returns. A plain `list.append` returns None, so this only keeps the
# function if `CallbackList.append` returns its argument. The name is reused further below
# (`state.trainer.set(..., print_another_callback)`) — confirm.
@state.n_epoch_started_callbacks.append
def print_another_callback():
    """Second callback on `n_epoch_started`, attached through the decorator syntax."""
    announcement = '`state.n_epoch_started` has another callback now!'
    print(announcement)

# Set state object value and trigger callbacks
state.n_epoch_started += 1

# Value types accepted as callbacks: (description, func), (func, args, kwargs),
# (func, kwargs) and, at most, (description, func, args, kwargs)
state.n_epoch_started = ('manual description', print_triggered)
state.n_epoch_started = (print_triggered, [], {})
state.n_epoch_started = (print_triggered, {})
state.n_epoch_started = ('max. callback with description, func, args, kwargs', print_triggered, [], {})

# `CallbackList` is a subclass of `list`, so list methods such as `clear()` work;
# this removes all callbacks attached to `n_epoch_started` so far
state.n_epoch_started_callbacks.clear()

def print_triggered_every_2():
    """Callback for the `_every_2` suffix demo: announce that it fired."""
    text = '`state.n_epoch_completed_every_2` was triggered!'
    print(text)


## ONCE / EVERY SUFFIXES ##

# Once/Every suffix overload: the `_every_2` suffix in the attribute name attaches the callback
# to a transition that presumably fires on every 2nd change of `n_epoch_completed`
state.n_epoch_completed_every_2 = print_triggered_every_2

# Alternative with the `_once` suffix
# NOTE(review): what assigning the bare int 42 does is not visible in this file —
# presumably it sets up a `_once` transition without a callback; confirm
state.n_epoch_started_once = 42

# Drive the `_every_2` demo: increment `n_epoch_completed` ten times
for n in range(10):
    print(n)
    state.n_epoch_completed += 1


## REFERENCE ##

# Refs give indirect access to a state object, e.g. across state object containers
ref = state.n_epoch_started_ref

# Equivalent: `state.n_epoch_started`
print(ref.caller_name)

# Equivalent: `state.n_epoch_started_callbacks`
print(ref.caller_name_callbacks)

# Equivalent: `state.n_epoch_started = ('attached with ref', print_triggered)`
ref.caller_name = ('attached with ref', print_triggered)

# Equivalent: `state.n_epoch_started += 1`
ref.caller_name += 1

# Current syntax limitation: a ref does not support the `_once`/`_every` suffixes
# NOTE: This can be fixed but 3 framework modules have to be merged together
# NOTE(review): so, unlike `state.n_epoch_started_once = ...`, the next line presumably does not
# create a `_once` transition; the printed names of `state.transitions` allow checking that — confirm
ref.caller_name_once = 9
print(state.transitions.get_bso_ctr_names())

# Reset state object to initial value
del state.n_epoch_completed


## BSO SYNCHRONIZATION ##

# Synchronize parameters: assigning the ref of `n_epoch_completed` makes `n_epoch_started`
# follow it. Only in one direction, to avoid a recursion error
state.n_epoch_started = state.n_epoch_completed_ref

# Change bso 'in sync-direction'
state.n_epoch_completed += 1
# Results in changing value of `state.n_epoch_started` and triggering its callbacks
print(state.n_epoch_started)
# But does not sync in the other direction: `state.n_epoch_completed` is not changed here
state.n_epoch_started += 1
print(state.n_epoch_completed)


### PARAMETRIZED STATE OBJECT GETTER/SETTER ###

n_every = 2
n_once = 55
# Named `filter_suffix` rather than `filter` to avoid shadowing the builtin `filter()`
filter_suffix = 'every'
state_object_name = 'n_epoch_started'
overload_feature = 'callbacks'

# Regular use case: the name parts are (presumably, as for `.set()` below) joined with `'_'`,
# i.e. this accesses `n_epoch_started_every_2` of the trainer; the result is discarded here
state.trainer.get('n_epoch_started_every', n_every)
# No limits to number of parameters
# NOTE:
# - The `.set()` understands the last argument as value and all preceding arguments as state object name parts to be
#   joined to one string with `'_'`
# - This here below creates a new transition `state.transitions.trainer_n_epoch_started_once_55`
#   and appends the function `print_another_callback`
#   to `state.trainer_n_epoch_started_once_55.n_conditional_counts_callbacks`
state.trainer.set('n_epoch_started_once', n_once, print_another_callback)
# Print the callbacks list of `n_epoch_started_every_2`, retrieved with `.get()` chained from `state`
# via `state.engines` to `state.engines.trainer`
print(state.get('engines').get('trainer').get(state_object_name, filter_suffix, n_every, overload_feature))
# Append another function as callback of a parametrized state object name (`n_epoch_started_every_55`)
state.trainer.set(state_object_name, filter_suffix, n_once, print_another_callback)
# Print the same callbacks with the plain attribute syntax: parametrized and plain access are interchangeable
print(state.trainer.n_epoch_started_every_55_callbacks)


# =====================
# STATE FEATURES
# =====================


## SHORTCUTS ##

# Shortcuts are deleted automatically when a name becomes ambiguous, e.g. after adding
# another component that owns a state object with the same name
print(state.n_iteration_completed)
Engine(name='another_engine', process=average_batch)
try:
    print(state.n_iteration_completed)
except AttributeError:
    print('`state` shortcut `state.n_iteration_completed` was automatically deleted due to ambiguity.')
# But unique shortcuts remain
print(state.trainer.n_iteration_completed)
print(state.another_engine.n_iteration_completed)


## ORGANIZATIONAL ASSISTANCE ##

with state.configs as c:
    c.hardware_config_value = 123
    c.n_gpus = 199

    # `@c` registers the function in `state.configs` (it is later called as
    # `state.check_all_configs_were_set()`)
    @c
    def check_all_configs_were_set():
        """Print every user-defined config and the name of every registered config helper.

        Values are looked up on the full container path rather than through `state`
        shortcuts, because shortcuts are deleted once a name becomes ambiguous.
        """
        user_configs = state.configs.user_defined_configs
        for config_name in user_configs.get_bso_names():
            print('state.configs.user_defined_configs.{} = {}'.format(
                config_name, getattr(user_configs, config_name)))
        for helper_func_name in state.configs.helpers.get_bso_names():
            print('state.configs.helpers.{} is added.'.format(helper_func_name))

# Single configs can also be assigned without a `with` block
state.configs.extra_path = 'this/is/a/extra/path'
# Analogous assignment into the `state.params` container
state.params.more_user_params = 199

# Assign directly into the user-defined configs sub-container
state.configs.user_defined_configs.directly_assigned = 51

# Call the config helper registered with `@c` through its `state` shortcut
state.check_all_configs_were_set()

# Every value assigned inside this `with` block is attached to `state.engines`
# (NOTE(review): possibly a sub-container of it — confirm) as a `StateParameter`,
# unless it is explicitly defined as a different state object
with state.engines as e:

    # Any non-state-object value is wrapped as `StateParameter(initial_value=value)`
    e.evaluator_batch_size = 20
    e.engine_param = 42

    # Functions decorated with `@e` become reachable through `state` (e.g. `state.max_process` is used below)
    @e
    def max_process(engine):
        """Engine process function: return the largest element of the current batch as a float."""
        largest = max(engine.batch)
        return float(largest)

    @e
    def double_value(value):
        """Return ``2 * value``, or ``value`` itself if that multiplication raises TypeError."""
        try:
            result = 2 * value
        except TypeError:
            result = value
        return result

    @e
    def quadro_value(value):
        """Return ``4 * value``, or ``value`` itself if that multiplication raises TypeError."""
        try:
            result = 4 * value
        except TypeError:
            result = value
        return result

    @e
    def half_value(value):
        """Return ``value / 2``, or ``value`` itself if that division raises TypeError."""
        try:
            result = value / 2
        except TypeError:
            result = value
        return result


## AUTOMATIC CATEGORIZATION ##

from ignite_framework.output_handlers import OutputHandler, ScalarsChartOutputHandler
from ignite_framework.metrics import AverageOutput, OutputMetric, RunningAverage

# Evaluator engine processing batches with `max_process` (registered via `@e` above). Its run is
# started by the trainer's `n_iteration_completed_every_100` transition, presumably every 100
# trainer iterations
Engine(name='evaluator',
       process=state.max_process,
       engine_run_started_ref=state.trainer.n_iteration_completed_every_100_ref)

# Feed the trainer output through `double_value`. The handler is named automatically and is
# referenced below as `state.trainer_double_value`
# NOTE(review): the bare names `double_value`/`half_value` rely on `@e` returning the decorated
# function unchanged; `state.max_process` above uses the `state` lookup instead — confirm
OutputHandler(input_refs=state.trainer.output_ref,
              transform_input_func=double_value)

# Note: `state.get(name)` is just a nicer call for `getattr(state, name)`, also works for state containers
# (it is used further below in the naming section)
# Average of the doubled trainer output; both the start and the completion of the averaging are
# tied to the trainer's `n_iteration_completed_every_100` transition
AverageOutput(metric_name='reg_average_loss', input_ref=state.trainer_double_value.output_ref,
              started_ref=state.trainer.n_iteration_completed_every_100_ref,
              completed_ref=state.trainer.n_iteration_completed_every_100_ref)
# Halve the averaged metric, and chart it as a scalar
OutputHandler(input_refs=state.trainer_double_value_reg_average_loss.output_ref,
              transform_input_func=half_value)
ScalarsChartOutputHandler(input_refs=state.trainer_double_value_reg_average_loss.output_ref)

# Further metrics on the doubled trainer output
RunningAverage(metric_name='double_lc_running_loss',
               input_ref=state.trainer_double_value.output_ref)
OutputMetric(metric_name='reg_normal_loss',
             input_ref=state.trainer_double_value.output_ref)


## EXTERNAL CLASS INTEGRATION ##

# Initiate a (random sample) dataloader: 10000 random ints in [0, 999], in batches of 10
from torch.utils.data.dataloader import DataLoader
from random import randint

state.dataloaders.trainer_dataloader = \
    DataLoader(dataset=[randint(0, 999) for i in range(10000)],
               batch_size=10)

# For adding multiple components, or bsos: the `with` form of the container
# (here 10 random ints in [0, 99], i.e. a single batch of 10)
with state.dataloaders as d:
    d.evaluator_dataloader = \
        DataLoader(dataset=[randint(0, 99) for i in range(10)],
                   batch_size=10)

# The container entry exposes the wrapped object as `.component`; the shortcut
# `state.trainer_dataloader` gives access to the same dataset
print(state.dataloaders.trainer_dataloader.component.dataset == state.trainer_dataloader.dataset)

# Assigning a `Module` to the dataloaders container is expected to be rejected; the broad
# `except` is deliberate here so the demo prints the framework's error message and continues
try:
    state.dataloaders.model = torch.nn.Module()
except Exception as e:
    print(e)

# A `Module` belongs to the `state.modules` container instead
state.modules.model = torch.nn.Module()

# NEW CLASS INTEGRATION

from ignite_framework.feature_dev_tools.state_objects import StateParameter
from ignite_framework.feature_dev_tools.utils import get_component_attr, set_component_attr

# Define a new subclass of `Module`
class UserModule(torch.nn.Module):
    """Minimal `torch.nn.Module` subclass used to demonstrate integrating a new component class.

    It owns no parameters, buffers or submodules. Its only custom attribute is
    `user_module_param`, the attribute integrated into `state` below.
    """

    def __init__(self):
        # Let `nn.Module` set up all of its internal state (`_parameters`, `_buffers`,
        # `_modules`, hooks, `training`, ...) instead of creating only a subset by hand.
        super().__init__()
        # Store the value on the instance; a bare local assignment would be discarded
        # when `__init__` returns.
        self.user_module_param = 24

# Define which attribute should additionally be integrated as a state object. Each entry maps the
# attribute name to (state object class, value — presumably its initial value, getter, setter)
integrated_attrs = {'user_module_param': (StateParameter, 11, get_component_attr, set_component_attr)}
# NOTE: `get_component_attr` and `set_component_attr` are the default getter/setter for class integration, so they
#       can be substituted by `''`:
#       `integrated_attrs = {'user_module_param': (StateParameter, 11, '', '')}`

# Integrate `UserModule` as a new component class of `state.modules`
state.modules.integrate_new_state_component_class(component_class=UserModule,
                                                  integrated_attrs_args_dict=integrated_attrs)

# Assign an instance of the newly integrated class
state.modules.a_user_module = UserModule()


## NAME ASSISTANCE & AUTOMATIC NAMING ##

# Each feature's automatically generated name is kept in a variable and used to look the
# feature up with `state.get(name)` (equivalent to `getattr(state, name)`) when wiring the
# input of the next feature
t_oh = OutputHandler(input_refs=state.trainer.output_ref,
                     transform_input_func=quadro_value).name
t_oh_m = RunningAverage(metric_name='quadro_lc_running_loss',
                        input_ref=state.get(t_oh).output_ref).name
# Last feature in the chain: its generated name is not needed, so it is not read
ScalarsChartOutputHandler(input_refs=state.get(t_oh_m).output_ref,
                          transform_input_func=double_value)


# End-of-demo marker, presumably a convenient spot for a debugger breakpoint
print('breakpoint')