import sys
from argparse import ArgumentParser
import logging
from ignite_framework.states import state
from ignite_framework.engines import Engine
from ignite_framework.metrics import Accuracy, Loss
from ignite_framework.output_handlers import OutputHandler
from ignite_framework.tensorboard import ScalarChart, ScalarsChart
from ignite_framework.utils import convert_tensor
import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from typing import Sequence, Union, Optional
## LOGGER NOT INTEGRATED YET ##
# BUT integrating it is very easy, see the docu "Quickunderstanding State Customization".
# NOTE: General loggers could also be integrated in a `state.loggers` container, so users could add their preferred
#       loggers to their `custom_states`.
# Engine logger setup: build the formatter and the stream handler first, then attach the handler to the engine
# logger and let INFO (and above) through.
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger("ignite.engine.engine.Engine")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
### CONFIGURATIONS ###
#=====================
# Set the default config `state.configs.default_configs.tensorboard_log_dir` via its `state` shortcut.
state.tensorboard_log_dir = './tensorboard_log_dir'
# NOTE:
# - This config is a default config of `state.configs.default_configs`, therefore a shortcut exists, which is used here.
# - If `tensorboard_log_dir` did not exist yet, you would call:
#   * either:
#     `state.configs.tensorboard_log_dir = 'tensorboard_log_dir'`
#   * or:
#     `with state.configs as c:
#          c.tensorboard_log_dir = 'tensorboard_log_dir'`
#   * both would result in a user-defined state object (`StateConfigurationVariable`)
#     `state.configs.user_defined_configs.tensorboard_log_dir`
### USER DEFINED PARAMETERS & HELPER FUNCTIONS ###
#==================================================
# NOTE:
# - (non-callable) values are assigned to `state.params.user_defined_params` as `StateParameter`s
#   (the container is created if it does not exist yet)
# - callables are assigned to `state.params.helpers` as `StateFunction`s (created if it does not exist yet)
with state.params as p:
    p.train_batch_size = 64
    p.n_eval_epochs = 10
    p.eval_batch_size = 1000
    p.n_xval_step_samples = 640
    p.xval_batch_size = 100
    p.n_xval_samples = 200
    p.n_metric_logging_step_samples = 100
    p.non_blocking = False

    # Example of assigning a function to `state.params.helpers`
    # NOTE: It is arguable how useful adding the functions to `state.params.helpers` is, but it's definitely a feature
    #       for writing cleaner training scripts that define all (or most) functions in one segment. It is not
    #       required to assign any function to a `helpers` category.
    @p
    def check_required_params():
        """Verify that the parameters this script reads are attached to `state.params`.

        Prints a confirmation when all are present.

        Raises:
            RuntimeError: if at least one required parameter is missing; the message lists the missing names.
        """
        required_params = ['train_batch_size', 'eval_batch_size', 'n_eval_epochs',
                           'n_xval_step_samples', 'n_xval_samples', 'non_blocking']
        # Query the registered names once instead of once per required parameter.
        attached_params = set(state.user_defined_params.get_bso_names())
        missing_params = [name for name in required_params if name not in attached_params]
        if missing_params:
            raise RuntimeError('Not all required params have been assigned to `state.params`. '
                               'Missing: {}'.format(missing_params))
        print('All params attached!!!!!')
### APPEND TO CALLBACK ###
#=========================
# Check all params at the end of the state initialization by appending `check_required_params` to the
# callbacks of `state.state_init_completed`.
# NOTE:
# - You can inspect all callbacks in the debugger in `state.engines.state_status.state_init_completed_callbacks` or...
# - print them with the shortcut `print(state.state_init_completed_callbacks)`
# - Assigning a callable to the `state_init_completed` event appends it to the event's callbacks.
state.state_init_completed = state.check_required_params
# Model
class Net(nn.Module):
    """Two-convolution, two-linear-layer CNN returning log-probabilities over 10 classes.

    The hard-coded flatten size 320 (= 20 channels * 4 * 4) corresponds to single-channel 28x28 inputs:
    28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Classifier head
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 1, 28, 28) to log-probabilities of shape (N, 10)."""
        h = F.relu(F.max_pool2d(self.conv1(x), 2))
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        h = F.relu(self.fc1(h.view(-1, 320)))
        # Functional dropout: only active while the module is in training mode.
        h = F.dropout(h, training=self.training)
        return F.log_softmax(self.fc2(h), dim=-1)
### ALL MODULES INTEGRATION ###
#==============================
# NOTE:
# - all `Module`s are INTEGRATED as `(Meta)StateComponent`s in `state.modules`, meaning:
#   - important module parameters are integrated as `StateParameters`
#   - the module instance is assigned unchanged to e.g. `state.modules.model.component`, the `component` attribute
#   - `state.modules.model` behaves identically to the instance by forwarding commands to
#     `state.modules.model.component`
with state.modules as m:
    m.model = Net()
    # `Module.to` accepts any device (CPU or GPU), whereas `Module.cuda` only targets CUDA devices and fails when
    # `device_for_modules` is a CPU device.
    m.model.to(state.device_for_modules)
    # NOTE: `Net.forward` already returns log-probabilities. `CrossEntropyLoss` applies `log_softmax` again, which
    #       leaves log-probabilities unchanged, so this loss equals `NLLLoss` on the model output.
    m.x_entropy = nn.CrossEntropyLoss()

### SYNC PARAMETERS ###
#======================
# Keep the modules' `device` parameters synced with the default config `device_for_modules` through its reference.
# NOTE: Non-shortcut form of the first command below:
#       `state.modules.model.device = state.configs.default_configs.device_for_modules_ref`
state.model.device = state.device_for_modules_ref
state.x_entropy.device = state.device_for_modules_ref
### HYPER PARAMETERS & ENGINE HELPER FUNCTIONS ###
#=================================================
# NOTE:
# - works like `state.configs` and `state.params`, BUT...
# - (non-callable) values are assigned to `state.engines.hyperparams` as `StateParameter`s (created if not yet
#   existing)
# - callables are assigned to `state.engines.helpers`
with state.engines as e:
    e.lr = 0.01
    e.momentum = 0.5

    # NOTE: It is arguable how useful adding the functions to `state.engines.helpers` is, but it's definitely a
    #       feature for writing cleaner training scripts and structuring your thoughts.
    @e
    def _prepare_batch(batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None,
                       non_blocking: bool = False):
        """Unpack an ``(inputs, targets)`` batch and convert both items with ``convert_tensor``.

        Args:
            batch: two-item sequence ``(inputs, targets)``.
            device: target device forwarded to ``convert_tensor``.
            non_blocking: forwarded to ``convert_tensor``.

        Returns:
            tuple: ``(converted_inputs, converted_targets)``.
        """
        inputs, targets = batch
        return tuple(convert_tensor(item, device=device, non_blocking=non_blocking) for item in (inputs, targets))

    @e
    def get_labels(value):
        """Pick ``(y_pred, y)`` out of an engine output ``(x, y, y_pred, ...)``.

        The reversed slice ``[2:0:-1]`` yields index 2 then index 1 and keeps the sequence type of ``value``.
        """
        prediction_and_target = value[2:0:-1]
        return prediction_and_target

    @e
    def _update(engine):
        """Run one SGD training step on ``engine.batch`` and return ``(x, y, y_pred, loss)``."""
        state.model.train()
        state.sgd.zero_grad()
        # NOTE: The helper is reached through `state` (`state._prepare_batch` is the shortcut of
        #       `state.engines.helpers._prepare_batch`). Alternatively, split the `with state.engines` statement
        #       in two: define the helpers in the first, then use them via `state` in the second.
        inputs, targets = state._prepare_batch(engine.batch,
                                               device=state.device_for_modules,
                                               non_blocking=state.non_blocking)
        predictions = state.model(inputs)
        loss = state.x_entropy(predictions, targets)
        loss.backward()
        state.sgd.step()
        return inputs, targets, predictions, loss

    @e
    def _infer(engine):
        """Evaluate the model on ``engine.batch`` without gradients and return ``(x, y, y_pred)``."""
        state.model.eval()
        with torch.no_grad():
            inputs, targets = state._prepare_batch(engine.batch,
                                                   device=state.device_for_modules,
                                                   non_blocking=state.non_blocking)
            return inputs, targets, state.model(inputs)
### ALL DATALOADERS ###
#======================
# NOTE:
# - `with`-statement useful for multiple state component/dataloader assignments
# - alternative: `state.dataloaders.trainer_loader = DataLoader(...)`
with state.dataloaders as d:
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    # Trainer dataloader
    # `download=True` fetches the MNIST files into `root` if they are missing. The test-split dataset below uses
    # `download=False` and relies on that download having fetched the test files too (torchvision's MNIST downloads
    # both splits at once).
    d.trainer_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                                  batch_size=state.train_batch_size, shuffle=True)
    # The x-validator and the evaluator both read the MNIST test split: share one dataset instance instead of
    # constructing and loading it twice.
    test_dataset = MNIST(download=False, root=".", transform=data_transform, train=False)
    # Xvalidator dataloader
    # Uses the dedicated `xval_batch_size`. With `n_xval_samples = 200`, batches of 100 cover each x-validation run
    # in exactly 2 iterations. `eval_batch_size` (1000) would presumably overshoot the 200-sample run with a single
    # batch.
    d.xvalidator_loader = DataLoader(test_dataset, batch_size=state.xval_batch_size, shuffle=False)
    # Evaluator dataloader
    d.evaluator_loader = DataLoader(test_dataset, batch_size=state.eval_batch_size, shuffle=False)
### PARAMETER SYNCHRONIZATION OF INTEGRATED MODULE ###
#=====================================================
# NOTE:
# - The suffix '_ref' of a state object returns a reference object to that state object.
# - If you set a state object to a reference of another state object, it is kept in sync with the value of that
#   referenced state object.
# This is the same sync as `state.model.device = state.device_for_modules_ref` in the SYNC PARAMETERS section,
# written here without any shortcut.
state.modules.model.device = state.configs.default_configs.device_for_modules_ref
### OPTIMIZER(S) ###
#===================
# NOTE:
# - `Optimizer`s are assigned to `state.optimizers`
# - can also be assigned using the `with`-statement
# - NO ADDITIONAL PARAMETERS OR HELPER FUNCTIONS ARE SET IN `state.optimizers`
# Use the `state.model` shortcut (as `_update` and `_infer` do) instead of relying on the `m` variable that leaks
# out of the `with state.modules as m:` statement.
state.optimizers.sgd = SGD(state.model.parameters(), lr=state.lr, momentum=state.momentum)
### ENGINES ###
#==============
# Trainer
# NOTE:
# - Below, default values are used as much as possible.
# - The `Engine` is assigned to `state.engines.trainer`.
# - Passing `engine_run_started_ref=None` would mean that the calling state object is not set, so the `Engine`
#   would NOT be called. (Default values are always placeholdered by ``''``)
# - Here the trainer is started by `state.state_run_started`, presumably fired by `state.run()`.
Engine(name='trainer',
       process=state._update, dataloader=state.trainer_loader,
       engine_run_started_ref=state.state_run_started_ref)
# X-Validator
# NOTE:
# - Started every `n_xval_step_samples` trainer samples and configured with `n_samples=state.n_xval_samples`.
# - To parametrize the engine starting callback use:
#   `engine_run_started_ref=state.trainer.get('n_samples_every', state.n_xval_step_samples, 'ref')`
Engine(name='xvalidator',
       process=state._infer,
       dataloader=state.xvalidator_loader,
       engine_run_started_ref=state.trainer.get('n_samples_every', state.n_xval_step_samples, 'ref'),
       n_samples=state.n_xval_samples)
# Evaluator
# NOTE: This demonstrates how arguments can be skipped during initialization and provided afterwards,
#       namely `dataloader` and `engine_run_started_ref`.
# NOTE(review): `n_eval_epochs` is 10, i.e. the unshuffled test set is evaluated 10 times per evaluator run —
#               confirm this is intended.
Engine(name='evaluator',
       process=state._infer,
       n_epochs=state.n_eval_epochs)
# Set up `state.evaluator`'s dataloader and callback after initialization: append `state.evaluator.run` to the
# trainer's `engine_run_completed` callbacks. The tuple is presumably (description, callback, kwargs passed to the
# callback) — confirm against the framework docs.
state.trainer.engine_run_completed = ('starte manually appended `state.evaluator.run` with `state.evaluator_loader',
                                      state.evaluator.run, {'dataloader': state.evaluator_loader})
### OUTPUT HANDLERS ###
#======================
# One output handler per engine, each turning the engine output into `(y_pred, y)` via `get_labels`.
# NOTE:
# - `output_handler_name` is not set and defaults to `''`, which turns on auto-naming
# - the output handlers are assigned automatically to `state.output_handlers`
# - Instead of bothering about the exact auto-generated name (easily identified in the debugger), fetch it by
#   appending `.name` and store it in a variable, see below
#
# EXTRA NOTE: `caller_ref`, or in current Ignite terms `event attaching`
#   Most features have the argument `caller_ref`, which refers to the event that calls/triggers e.g. calculating
#   an output or logging a metric. Here the default value of `caller_ref` is `input_ref`, so each time a new input
#   value is given by e.g. `state.trainer.output`, the trainer's output handler is automatically called/triggered
#   and calculates its own new output. If the trainer metric's `caller_ref` is also left at its default, the new
#   output handler output triggers the new calculation of the metric... and of course, same game with the
#   chart/tensorboard logger.
trainer_output_handler_name = OutputHandler(
    input_refs=state.trainer.output_ref, transform_input_func=get_labels).name
xvalidator_output_handler_name = OutputHandler(
    input_refs=state.xvalidator.output_ref, transform_input_func=get_labels).name
evaluator_output_handler_name = OutputHandler(
    input_refs=state.evaluator.output_ref, transform_input_func=get_labels).name
# Report the auto-generated names, one line per handler.
for prefix, handler_name in (('trainer_output_handler = ', trainer_output_handler_name),
                             ('xvalidator_output_handler = ', xvalidator_output_handler_name),
                             ('evaluator_output_handler = ', evaluator_output_handler_name)):
    print(prefix + handler_name)
### METRICS ###
#==============
# NOTE: Currently, the output handlers are not automated in finding the correct inputs and providing their outputs
#       in a standardized way, so that a metric could automatically pick the correct output values from the output
#       handlers, or directly from the engine output. This has to be implemented, see issue #???
# Trainer
# The trainer metrics are completed every `n_metric_logging_step_samples` trainer samples. The ref is built with
# `state.trainer.get('n_samples_every', <n>, 'ref')`, the same pattern the x-validator engine uses, so the param
# actually controls the logging step instead of a hard-coded `n_samples_every_100_ref`.
trainer_loss_name = Loss(metric_name='loss',
                         loss_fn=state.x_entropy,
                         input_ref=state.get(trainer_output_handler_name).output_ref,
                         completed_ref=state.trainer.get('n_samples_every',
                                                         state.n_metric_logging_step_samples, 'ref')).name
# NOTE: If you know the output handler name, you can access it directly as a shortcut attribute of `state` (see
#       below) or as a (real) attribute of `state.output_handlers`.
trainer_accuracy_name = Accuracy(metric_name='accuracy',
                                 input_ref=state.trainer_get_labels.output_ref,
                                 completed_ref=state.trainer.get('n_samples_every',
                                                                 state.n_metric_logging_step_samples, 'ref')).name
# X-validator
# NOTE: Here we assume we know the xvalidator output handler name `xvalidator_get_labels`.
xvalidator_loss_name = Loss(metric_name='loss',
                            loss_fn=state.x_entropy,
                            input_ref=state.xvalidator_get_labels.output_ref,
                            completed_ref=state.xvalidator.engine_run_completed_ref).name
xvalidator_accuracy_name = Accuracy(metric_name='accuracy',
                                    input_ref=state.xvalidator_get_labels.output_ref,
                                    completed_ref=state.xvalidator.engine_run_completed_ref).name
# Evaluator
evaluator_loss_name = Loss(metric_name='loss',
                           loss_fn=state.x_entropy,
                           input_ref=state.get(evaluator_output_handler_name).output_ref,
                           completed_ref=state.evaluator.engine_run_completed_ref).name
evaluator_accuracy_name = Accuracy(metric_name='accuracy',
                                   input_ref=state.get(evaluator_output_handler_name).output_ref,
                                   completed_ref=state.evaluator.engine_run_completed_ref).name
### CHARTS ###
#=============
# NOTE: The following charts manually implement what 2 high-level APIs with 1-liner commands accomplish in
#       `mnist_with_tensorboard_and_high_level_apis.py`, just to demonstrate the possibilities of framework automation.
# MANUALLY SET CHARTS IDENTICAL TO THE COMMAND:
# `EnginesMetricsComparisonCharts(x_axis_ref=state.trainer.n_samples_ref, n_identical_metric_name_suffixes=1)`
# -------------------------------------------------------------------------------------------------------------
# Accuracy chart of all engines, 3 metrics in 1 chart
# NOTE: Here default values are used as much as possible, setting as little as necessary.
# EXTRA NOTE: Leaving out the `chart_name` in charts with many `y_axes_refs` usually leads to very long chart names.
# The trainer curve is triggered by the trainer accuracy output; the x-validator/evaluator curves are triggered
# when the respective engine run completes.
ScalarsChart(x_axis_ref=state.trainer.n_samples_ref,
             y_axes_refs=[state.get(trainer_accuracy_name).output_ref,
                          state.get(xvalidator_accuracy_name).output_ref,
                          state.get(evaluator_accuracy_name).output_ref],
             caller_refs=[state.get(trainer_accuracy_name).output_ref,
                          state.xvalidator.engine_run_completed_ref,
                          state.evaluator.engine_run_completed_ref])
# Loss chart of all engines, 3 metrics in 1 chart
# NOTE:
# - Here maximum manual individualization is used. Some arguments are redundant in the sense that they
#   equal (or behave equally to) the default value, e.g. the `caller_refs` would result in the same chart.
# - The `walltime_ref` uses the state timer `state.maintenance.state_timer.state_run_started_time_in_sec`
ScalarsChart(x_axis_ref=state.trainer.n_samples_ref,
             y_axes_refs=[state.get(trainer_loss_name).output_ref,
                          state.get(xvalidator_loss_name).output_ref,
                          state.get(evaluator_loss_name).output_ref],
             caller_refs=[state.trainer.n_samples_every_100_ref,
                          state.xvalidator.engine_run_completed_ref,
                          state.evaluator.engine_run_completed_ref],
             y_names=['trainer_loss', 'xvalidator_loss', 'evaluator_loss'],
             chart_name='engines_loss_comparison',
             bso_ctr_name_suffix='manually_set_chart',
             summary_writer=state.tensorboard.summary_writer,
             summary_description='manually written description for summary adding',
             walltime_ref=state.state_run_started_time_in_sec_ref)
# MANUALLY SET CHARTS IDENTICAL TO THE COMMAND:
#
# `EnginesMetricsCharts(x_axes_refs=state.trainer.n_samples_ref, n_identical_metric_name_suffixes=1)`
# ---------------------------------------------------------------------------------------------------
# One chart per engine metric, all plotted against the trainer's sample count.
# Trainer
ScalarChart(x_axis_ref=state.trainer.n_samples_ref, y_axis_ref=state.get(trainer_accuracy_name).output_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref, y_axis_ref=state.get(trainer_loss_name).output_ref)
# Xvalidator
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(xvalidator_accuracy_name).output_ref,
            caller_refs=state.xvalidator.engine_run_completed_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(xvalidator_loss_name).output_ref,
            caller_refs=state.xvalidator.engine_run_completed_ref)
# Evaluator
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(evaluator_accuracy_name).output_ref,
            caller_refs=state.evaluator.engine_run_completed_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(evaluator_loss_name).output_ref,
            caller_refs=state.evaluator.engine_run_completed_ref)
# NOTE:
# - Did you notice... we got rid of the `other_engine`, even though we were using 3 engines :-)
# - Additionally, for these charts one would need to implement `another_x_axis` in the current Ignite, too.
### HARDWARE RESOURCE TRACKING TO BE IMPLEMENTED SOON ###
#========================================================
# NOTE: The current GPU tracker may have the issue that it is triggered by `Event.ITERATION_COMPLETED`, logging
#       the GPU usage at that point (not even the max usage). Unfortunately, when the iteration is completed the
#       GPU calculations are also completed, so `GpuInfo` would only be measuring the resource usage during downtime.
#       Therefore the new implementation will run on a separate thread with slightly randomized logging timing.
# if sys.version_info > (3,):
#     from ignite.contrib.metrics.gpu_info import GpuInfo
#     try:
#         GpuInfo().attach(state.trainer)
#     except RuntimeError:
#         print("INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
#               "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
#               "install it : `pip install pynvml`")
# Start the state; this presumably fires `state_run_started`, which the trainer's `engine_run_started_ref` points to.
state.run()