• Docs >
  • automatic_metric_chart_creation.py
Shortcuts

automatic_metric_chart_creation.py

from ignite_framework.states import state
from ignite_framework.engines import Engine
from ignite_framework.output_handlers import OutputHandler, ScalarsChartOutputHandler
from ignite_framework.metrics import AverageOutput, OutputMetric, RunningAverage
from ignite_framework.tensorboard import EnginesMetricsCharts, EnginesMetricsComparisonCharts

from torch.utils.data.dataloader import DataLoader
from random import randint

# Training data: 10000 random integers drawn from [0, 99], served in batches of 10.
state.dataloaders.trainer_dataloader = DataLoader(
    dataset=[randint(0, 99) for _ in range(10000)],
    batch_size=10,
)

# Evaluation data: 10 random integers drawn from [0, 99], i.e. a single batch of 10.
state.dataloaders.evaluator_dataloader = DataLoader(
    dataset=[randint(0, 99) for _ in range(10)],
    batch_size=10,
)

with state.engines as e:
    # Every function below is passed through the `state.engines` decorator `e`.
    # The names stay usable afterwards, e.g. `process=average_process` for the
    # engines and `transform_input_func=double_value` for the output handlers.

    @e
    def average_process(engine):
        """Return the mean of the current batch as a Python float.

        The sum is divided by the number of samples actually present in the batch,
        not by the configured ``engine.batch_size``. A final batch that is smaller
        (dataset size not a multiple of the batch size) is therefore still averaged
        correctly.
        """
        return float(sum(engine.batch) / len(engine.batch))

    @e
    def max_process(engine):
        """Return the largest sample of the current batch as a Python float."""
        return float(max(engine.batch))

    @e
    def double_value(value):
        """Return ``2 * value``; a value for which that raises TypeError is returned unchanged."""
        try:
            return 2 * value
        except TypeError:
            return value

    @e
    def quadro_value(value):
        """Return ``4 * value``; a value for which that raises TypeError is returned unchanged."""
        try:
            return 4 * value
        except TypeError:
            return value

    @e
    def transform(value):
        """Identity transform: return ``value`` unchanged."""
        # Returning the argument cannot raise, so no TypeError guard is needed here.
        return value

    @e
    def half_value(value):
        """Return ``value / 2``; a value for which that raises TypeError is returned unchanged."""
        try:
            return value / 2
        except TypeError:
            return value

    @e
    def quarter_value(value):
        """Return ``value / 4``; a value for which that raises TypeError is returned unchanged."""
        try:
            return value / 4
        except TypeError:
            return value

# Drop the short-lived decorator alias so it does not linger as a module global.
del e

# The Engine objects are not assigned to variables here. Later code reaches them
# through `state.<name>`, e.g. `state.trainer` in the evaluator below.

# Trainer: processes each batch with `average_process`; its run is started by
# `state.state_run_started_ref`.
Engine(
    name='trainer',
    process=average_process,
    dataloader=state.trainer_dataloader,
    engine_run_started_ref=state.state_run_started_ref,
)

# Evaluator: processes each batch with `max_process`. Its run is started by the
# trainer's `n_iteration_completed_every_100_ref`, presumably every 100 completed
# trainer iterations.
Engine(
    name='evaluator',
    process=max_process,
    dataloader=state.evaluator_dataloader,
    engine_run_started_ref=state.trainer.n_iteration_completed_every_100_ref,
)

# Trainer branch 1: an output handler applying `double_value` to the trainer output.
# `.name` is kept so the handler can be looked up again with `state.get(...)`.
t_oh1 = OutputHandler(input_refs=state.trainer.output_ref, transform_input_func=double_value).name

# Note: `state.get(name)` is just a nicer call for `getattr(state, name)`, also works for state containers.
# The handler created above is accessed directly as `state.trainer_double_value`
# (presumably the name generated from engine name + transform function).
t_oh1_m1 = AverageOutput(metric_name='reg_average_loss', input_ref=state.trainer_double_value.output_ref,
                         started_ref=state.trainer.n_iteration_completed_every_100_ref,
                         completed_ref=state.trainer.n_iteration_completed_every_100_ref).name
# NOTE(review): `n_iteration_completed_every_100` now carries three callbacks:
#   - the evaluator's run start (its `engine_run_started_ref`);
#   - this metric's started callback;
#   - this metric's completed callback.
# They are presumably in that registration order. `[2, 1, 0]` reverses them, so the
# average is completed before it is restarted, and both happen before the evaluator
# runs. Confirm against `reorder`'s index semantics.
state.trainer.n_iteration_completed_every_100_callbacks.reorder(order_idcs=[2, 1, 0])
t_oh1_m1_oh1 = OutputHandler(input_refs=state.get(t_oh1_m1).output_ref, transform_input_func=half_value).name
t_oh1_m1_oh2 = ScalarsChartOutputHandler(input_refs=state.get(t_oh1_m1).output_ref).name

# Two more metrics on the doubled trainer output. Each metric name gets its own
# variable so neither is overwritten.
t_oh1_m2 = RunningAverage(metric_name='double_lc_running_loss', input_ref=state.trainer_double_value.output_ref).name
t_oh1_m3 = OutputMetric(metric_name='reg_normal_loss', input_ref=state.trainer_double_value.output_ref).name

# Trainer branch 2: a handler applying `quadro_value`, a running average of its
# output, and a scalar chart of that running average with `double_value` applied.
t_oh2 = OutputHandler(input_refs=state.trainer.output_ref, transform_input_func=quadro_value).name
t_oh2_m3 = RunningAverage(metric_name='quadro_lc_running_loss', input_ref=state.get(t_oh2).output_ref).name
t_oh2_m3_oh = ScalarsChartOutputHandler(input_refs=state.get(t_oh2_m3).output_ref,
                                        transform_input_func=double_value).name

# Evaluator branch: an output handler applying `half_value` to the evaluator output,
# with two metrics built on that handler's output.
e_oh1 = OutputHandler(input_refs=state.evaluator.output_ref, transform_input_func=half_value).name
# Metric 1 ('reg_normal_loss'): an OutputMetric over the handler's output. Its own
# output feeds a ScalarsChartOutputHandler that applies `quarter_value`.
e_oh1_m1 = OutputMetric(metric_name='reg_normal_loss', input_ref=state.get(e_oh1).output_ref).name
e_oh1_m1_oh1 = ScalarsChartOutputHandler(input_refs=state.get(e_oh1_m1).output_ref,
                                         transform_input_func=quarter_value).name
# Metric 2 ('reg_average_loss'): an AverageOutput whose started and completed refs are
# both the evaluator's `n_iteration_completed_every_50_ref`.
e_oh1_m2 = AverageOutput(metric_name='reg_average_loss', input_ref=state.get(e_oh1).output_ref,
                         started_ref=state.evaluator.n_iteration_completed_every_50_ref,
                         completed_ref=state.evaluator.n_iteration_completed_every_50_ref).name
# NOTE(review): presumably the metric above registered two callbacks on this event,
# (started, completed), in that order. `[1, 0]` swaps them so the average is completed
# before it is restarted. Confirm the index semantics of `reorder`.
state.evaluator.n_iteration_completed_every_50_callbacks.reorder(order_idcs=[1, 0])
# NOTE(review): the name `e_oh1_m1_oh2` suggests a second handler on metric `e_oh1_m1`
# (compare `t_oh1_m1_oh1` on the trainer side). It actually reads from the chart
# handler `e_oh1_m1_oh1`. Confirm this chaining is intended.
e_oh1_m1_oh2 = OutputHandler(input_refs=state.get(e_oh1_m1_oh1).output_ref, transform_input_func=half_value).name

# Automatic chart creation over the engines' metrics. The comparison charts use the
# trainer's sample count (`state.trainer.n_samples_ref`) as their x axis.
EnginesMetricsComparisonCharts(x_axis_ref=state.trainer.n_samples_ref)

EnginesMetricsCharts()

# Graph construction above stays at module level. Inspection and the (long) run
# only happen when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    # Get the full data flow graph of state
    print(state.get_graphs_in_state())

    # Get all data flow paths of state
    print(state.get_paths_in_state())

    state.run()