from ignite_framework.states import state
from ignite_framework.engines import Engine
from ignite_framework.output_handlers import OutputHandler, ScalarsChartOutputHandler
from ignite_framework.metrics import AverageOutput, OutputMetric, RunningAverage
from ignite_framework.tensorboard import EnginesMetricsCharts, EnginesMetricsComparisonCharts
from torch.utils.data.dataloader import DataLoader
from random import randint
# Trainer data: 10000 random integers drawn with randint(0, 99), batched by 10.
state.dataloaders.trainer_dataloader = DataLoader(
    dataset=[randint(0, 99) for _ in range(10000)],
    batch_size=10,
)
# Evaluator data: 10 random integers drawn the same way, batched by 10.
state.dataloaders.evaluator_dataloader = DataLoader(
    dataset=[randint(0, 99) for _ in range(10)],
    batch_size=10,
)
# Engine process functions and value transforms. Each one is decorated with
# `e`, the object bound by `with state.engines as e`.
# NOTE(review): what the `@e` decorator does is defined inside
# ignite_framework and is not visible here — confirm before relying on it.
with state.engines as e:
    @e
    def average_process(engine):
        """Return ``sum(engine.batch) / engine.batch_size`` as a float."""
        return float(sum(engine.batch) / engine.batch_size)

    @e
    def max_process(engine):
        """Return the largest element of ``engine.batch`` as a float."""
        return float(max(engine.batch))

    @e
    def double_value(value):
        """Return ``2 * value``; return ``value`` unchanged on ``TypeError``."""
        try:
            return 2 * value
        except TypeError:
            return value

    @e
    def quadro_value(value):
        """Return ``4 * value``; return ``value`` unchanged on ``TypeError``."""
        try:
            return 4 * value
        except TypeError:
            return value

    @e
    def transform(value):
        """Return ``value`` unchanged (identity transform)."""
        # A bare `return value` cannot raise TypeError, so the former
        # try/except wrapper around it was unreachable and has been dropped.
        return value

    @e
    def half_value(value):
        """Return ``value / 2``; return ``value`` unchanged on ``TypeError``."""
        try:
            return value / 2
        except TypeError:
            return value

    @e
    def quarter_value(value):
        """Return ``value / 4``; return ``value`` unchanged on ``TypeError``."""
        try:
            return value / 4
        except TypeError:
            return value

# The decorator alias is no longer needed once all functions are defined.
del e
# Trainer engine: processes batches from the trainer dataloader with
# `average_process`; its `engine_run_started_ref` is the state's
# `state_run_started_ref`.
Engine(
    name='trainer',
    process=average_process,
    dataloader=state.trainer_dataloader,
    engine_run_started_ref=state.state_run_started_ref,
)
# Evaluator engine: processes batches from the evaluator dataloader with
# `max_process`; its `engine_run_started_ref` is the trainer's
# `n_iteration_completed_every_100_ref`.
Engine(
    name='evaluator',
    process=max_process,
    dataloader=state.evaluator_dataloader,
    engine_run_started_ref=state.trainer.n_iteration_completed_every_100_ref,
)
# --- Trainer data flow -------------------------------------------------------
# Every variable below stores the `.name` of the object that was just
# constructed, so that the object can be looked up on `state` again later.
# The statement order is significant: the `reorder(...)` call below addresses
# callbacks by their current position.

# Output handler on the trainer output, transformed with `double_value`.
t_oh1 = OutputHandler(
    input_refs=state.trainer.output_ref,
    transform_input_func=double_value,
).name

# Note: `state.get(name)` is just a nicer call for `getattr(state, name)`, also works for state containers
# NOTE(review): `state.trainer_double_value` is presumably the handler created
# as `t_oh1` above — confirm it is the same object as `state.get(t_oh1)`.
t_oh1_m1 = AverageOutput(
    metric_name='reg_average_loss',
    input_ref=state.trainer_double_value.output_ref,
    started_ref=state.trainer.n_iteration_completed_every_100_ref,
    completed_ref=state.trainer.n_iteration_completed_every_100_ref,
).name

# Reverse the order of the first three callbacks on the trainer's
# `n_iteration_completed_every_100` event.
# NOTE(review): presumably these are the evaluator start plus the metric's
# started/completed callbacks registered above — confirm.
state.trainer.n_iteration_completed_every_100_callbacks.reorder(order_idcs=[2, 1, 0])

# Two consumers of the average metric: a halving handler and a scalars chart.
t_oh1_m1_oh1 = OutputHandler(
    input_refs=state.get(t_oh1_m1).output_ref,
    transform_input_func=half_value,
).name
t_oh1_m1_oh2 = ScalarsChartOutputHandler(
    input_refs=state.get(t_oh1_m1).output_ref,
).name

# Two further metrics on the doubled trainer output. The second one used to be
# assigned to `t_oh1_m2` as well, silently overwriting the RunningAverage
# handle; it now has its own name.
t_oh1_m2 = RunningAverage(
    metric_name='double_lc_running_loss',
    input_ref=state.trainer_double_value.output_ref,
).name
t_oh1_m3 = OutputMetric(
    metric_name='reg_normal_loss',
    input_ref=state.trainer_double_value.output_ref,
).name

# Second handler on the trainer output, transformed with `quadro_value`,
# feeding a running average that is charted after `double_value`.
t_oh2 = OutputHandler(
    input_refs=state.trainer.output_ref,
    transform_input_func=quadro_value,
).name
t_oh2_m3 = RunningAverage(
    metric_name='quadro_lc_running_loss',
    input_ref=state.get(t_oh2).output_ref,
).name
t_oh2_m3_oh = ScalarsChartOutputHandler(
    input_refs=state.get(t_oh2_m3).output_ref,
    transform_input_func=double_value,
).name
# --- Evaluator data flow -----------------------------------------------------
# Each variable stores the `.name` of the object just constructed; later lines
# fetch the object back with `state.get(<name>)`.
# Statement order matters here: the `reorder(...)` call below addresses
# callbacks by their current position.

# Output handler on the evaluator output, transformed with `half_value`.
e_oh1 = OutputHandler(input_refs=state.evaluator.output_ref, transform_input_func=half_value).name
# Metric 'reg_normal_loss' reading the handler output, charted after `quarter_value`.
e_oh1_m1 = OutputMetric(metric_name='reg_normal_loss', input_ref=state.get(e_oh1).output_ref).name
e_oh1_m1_oh1 = ScalarsChartOutputHandler(input_refs=state.get(e_oh1_m1).output_ref,
transform_input_func=quarter_value).name
# Metric 'reg_average_loss' on the same handler output; both its started and
# completed references are the evaluator's `n_iteration_completed_every_50_ref`.
e_oh1_m2 = AverageOutput(metric_name='reg_average_loss', input_ref=state.get(e_oh1).output_ref,
started_ref=state.evaluator.n_iteration_completed_every_50_ref,
completed_ref=state.evaluator.n_iteration_completed_every_50_ref).name
# Swap the first two callbacks of that event.
# NOTE(review): presumably the started/completed callbacks of the metric
# registered just above — confirm.
state.evaluator.n_iteration_completed_every_50_callbacks.reorder(order_idcs=[1, 0])
# NOTE(review): this handler reads the output of `e_oh1_m1_oh1` (the chart
# handler), although its name suggests a sibling that reads `e_oh1_m1` —
# confirm which input is intended.
e_oh1_m1_oh2 = OutputHandler(input_refs=state.get(e_oh1_m1_oh1).output_ref, transform_input_func=half_value).name
# Create the chart objects imported from `ignite_framework.tensorboard`; the
# comparison charts receive `state.trainer.n_samples_ref` as x-axis reference.
EnginesMetricsComparisonCharts(x_axis_ref=state.trainer.n_samples_ref)
EnginesMetricsCharts()
# Print the full data flow graph of state
print(state.get_graphs_in_state())
# Print all data flow paths of state
print(state.get_paths_in_state())
# Start the run; the trainer engine was given `state.state_run_started_ref`
# as its `engine_run_started_ref`.
state.run()