• Docs >
  • quickunderstanding_feature_dev.py
Shortcuts

quickunderstanding_feature_dev.py

# STATE
from ignite_framework.states import state

# BASE CLASS
from ignite_framework.transitions import BaseTransition

# UTILS
from ignite_framework.feature_dev_tools.state_objects import StateDevice, StateVariable
from ignite_framework.feature_dev_tools.utils import get_caller_refs, create_bso_ctr_name_prefix_from_input_ref


# SELECT THE BASE CLASS
# NOTE:
# - The base class defines the state container to which
#   the transition/feature is automatically attached.
# - Here `BaseTransition` will assign `DoubleTheInput` to `state.transitions`

class DoubleTheInput(BaseTransition):
    """Demo transition that writes twice its input value into the ``output`` state variable.

    ``run`` is registered on every caller reference (see ``_set_callbacks``) and
    stores ``self._double_factor * <input>`` (factor fixed to 2) in ``self.output``.
    Because it derives from ``BaseTransition``, the instance is attached to
    ``state.transitions`` by the framework.
    """
    # Set up state object (current Ignite: "custom event")
    output = StateVariable()
    device = StateDevice()

    def __init__(self, double_name, input_ref, caller_refs='',
                 device=state.configs.default_configs.device_for_modules_ref):
        """Create the transition, register it with the framework and hook up callbacks.

        Args:
            double_name: suffix of the transition name; the full name is
                ``<prefix derived from input_ref>_<double_name>``.
            input_ref: reference to the state value that ``run`` doubles.
            caller_refs: reference(s) whose callbacks receive ``self.run``; the
                default ``''`` is normalized by ``get_caller_refs`` (per the note
                below, the default is normally ``input_ref``).
            device: either a reference (keeps ``self.device`` synchronized with it)
                or a fixed value such as ``'cpu'``. NOTE(review): the default
                reference is looked up once, when the class is defined.
        """
        # ARGUMENT CHECKING
        # NOTE:
        # - mainly `input_ref(s)` and `caller_ref(s)` are reformatted
        #   (e.g. as list) or set to default values etc.
        # - the default value of `caller_ref(s)` is normally the `input_ref(s)`
        caller_refs = get_caller_refs(caller_refs=caller_refs, input_refs=input_ref)

        # AUTOMATIC OR ASSISTED NAMING
        # The name is the prefix derived from `input_ref`, joined to `double_name` with '_'.
        name = '_'.join([create_bso_ctr_name_prefix_from_input_ref(input_ref=input_ref), double_name])

        # BASE CLASS INITIALIZATION
        # NOTE: Initializing the base class starts the framework support, therefore
        #       the order in which it is initialized (before the assignments below) MATTERS!
        super().__init__(name=name)

        # ASSIGN REFERENCES
        # Assigned input references which will be tracked by the framework
        self._input_ref = input_ref
        # NOTE:
        # - `device` could also be a fixed value, e.g. `'gpu'` or `'cpu'`
        # - passing in a reference synchronizes `self.device` with the ref, otherwise a fixed value is set
        self.device = device

        # Everything else
        self._double_factor = 2

        # APPEND METHODS TO CALLER REFS CALLBACKS
        # NOTE: The callback appending can be individualized as required
        self._set_callbacks(caller_refs=caller_refs)

    def run(self):
        """Store ``self._double_factor`` times the input value in ``self.output``.

        NOTE(review): ``caller_name`` looks like a placeholder for the actual
        attribute of the referenced state object; confirm against the real ref.
        """
        self.output = self._double_factor * self._input_ref.caller_name

    def _set_callbacks(self, caller_refs):
        """Register ``self.run`` (with a short description) on every caller reference."""
        for caller_ref in caller_refs:
            # Append the methods as desired to the callbacks
            caller_ref.caller_name = ('short description is helpful', self.run)
            # NOTE:
            # - if the bso-container has a method `run`, it is detected
            #   and appended when simply `self` is assigned (next bullet)
            # - alternative implementation with automatic description
            #   generation: `caller_ref.caller_name = self`


# BASE CLASS
from ignite_framework.pipelines import BasePipeline

# FEATURE CLASSES
from ignite_framework.output_handlers import OutputHandler
from ignite_framework.metrics import OutputMetric


class CurrentIgniteMetricWithOutputHandler(BasePipeline):
    """
    Demo pipeline doing nothing but piping e.g. the engine output through an output handler and a metric.

    Args:
        name: name of the pipeline; it is also used as the output handler name and,
            with the suffix ``'_metric'``, as the metric name.
        input_ref: reference to the value piped through the output handler and the metric.
        caller_refs: reference(s) that trigger the output handler; the default ``''`` is
            normalized by ``get_caller_refs`` (presumably falling back to ``input_ref`` --
            confirm against ``get_caller_refs``).
        metric_device: device forwarded to the ``BasePipeline`` initializer.
    """
    def __init__(self, name, input_ref, caller_refs='', metric_device=''):
        # No automatic or assisted naming: the pipeline uses `name` exactly as given.

        # Argument checking: normalize `caller_refs` (e.g. to a list or default value).
        caller_refs = get_caller_refs(caller_refs=caller_refs, input_refs=input_ref)

        # Initialize base class
        super().__init__(name=name, device=metric_device)

        # SAVE NAME OF PIPELINE BSO-CONTAINERS TO `self.composed_bso_ctrs` ordered dictionary
        # NOTE: The name of each bso-container must be related to the pipeline `name`
        self.composed_bso_ctrs['output_handler'] = OutputHandler(input_refs=input_ref,
                                                                 output_handler_name=name,
                                                                 caller_refs=caller_refs).name
        # Store the metric's `.name` as well, so that both entries hold container
        # names (not one name and one instance), as the note above requires.
        self.composed_bso_ctrs['metric'] = OutputMetric(metric_name=name + '_metric',
                                                        input_ref=input_ref,
                                                        caller_refs='').name

        # All callbacks are set by `OutputHandler` and `OutputMetric`