• Docs >
  • mnist.mnist_with_tensorboard_logger_and_low_level_apis.py
Shortcuts

mnist.mnist_with_tensorboard_logger_and_low_level_apis.py

import sys
from argparse import ArgumentParser
import logging

from ignite_framework.states import state

from ignite_framework.engines import Engine

from ignite_framework.metrics import Accuracy, Loss
from ignite_framework.output_handlers import OutputHandler
from ignite_framework.tensorboard import ScalarChart, ScalarsChart
from ignite_framework.utils import convert_tensor

import torch
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from typing import Sequence, Union, Optional


## LOGGER NOT INTEGRATED YET ##
# The engine logger below is plain `logging` and is not part of `state` yet. Integrating it is straightforward,
# see the docs section "Quickunderstanding State Customization".
# NOTE: General loggers could also be integrated in a `state.loggers` container, so users could add their preferred
# loggers to their `custom_states`.

# Engine logger: INFO level, every record formatted with timestamp, logger name and level by a stream handler
formatter = logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger("ignite.engine.engine.Engine")
logger.setLevel(logging.INFO)
logger.addHandler(handler)


### CONFIGURATIONS ###
#=====================

# Set default config `state.configs.default_configs.tensorboard_log_dir` with shortcut
state.tensorboard_log_dir = './tensorboard_log_dir'
# NOTE:
#   - This config is a default config of `state.configs.default_configs`, therefore a shortcut exists which is used here
#   - If `tensorboard_log_dir` did not exist yet, you would call:
#       * either:
#           `state.configs.tensorboard_log_dir = 'tensorboard_log_dir'`
#       * or:
#           `with state.configs as c:
#                c.tensorboard_log_dir = 'tensorboard_log_dir'`
#       * both would result in a user defined state object (`StateConfigurationVariable`)
#         `state.configs.user_defined_configs.tensorboard_log_dir`

###  USER DEFINED PARAMETERS & HELPER FUNCTIONS ###
#==================================================

# NOTE:
#   - (non-callable) parameter values are assigned to `state.params.user_defined_params` as `StateParameter`s
#     (the container will be created if it does not exist yet)
#   - callables are assigned to `state.params.helpers` as `StateFunction` (will be created if it does not exist yet)
with state.params as p:
    p.train_batch_size = 64
    p.n_eval_epochs = 10
    p.eval_batch_size = 1000
    p.n_xval_step_samples = 640
    p.xval_batch_size = 100
    p.n_xval_samples = 200
    p.n_metric_logging_step_samples = 100
    p.non_blocking = False

    # Example of assigning a function to `state.params.helpers`
    # NOTE: It is arguable how useful adding the functions to `state.params.helpers` is, but it's definitely a feature
    #       for writing cleaner training scripts defining all (most) functions in one segment. It is not required to
    #       assign any function to a `helpers` category
    @p
    def check_required_params():
        """Verify that all required user parameters have been assigned to `state.params`.

        Registered below as a callback of `state.state_init_completed`.

        Raises:
            RuntimeError: if any required parameter is missing; the message lists the missing names.
        """
        required_params = ['train_batch_size', 'eval_batch_size', 'n_eval_epochs']
        # Fetch the assigned names once (instead of once per required param); `set` also guards against
        # `get_bso_names()` returning a one-shot iterator
        assigned_params = set(state.user_defined_params.get_bso_names())
        missing_params = [name for name in required_params if name not in assigned_params]
        if missing_params:
            raise RuntimeError('Not all required params have been assigned to `state.params`. '
                               'Missing: {}'.format(', '.join(missing_params)))
        print('All params attached!!!!!')


### APPEND TO CALLBACK ###
#=========================

# Check all params at the end of the state initialization by appending `check_required_params` to the
# callbacks of `state.state_init_completed`
# NOTE:
#   - You can inspect all callbacks in the debugger in `state.engines.state_status.state_init_completed_callbacks` or...
#   - print all callbacks with the shortcut `print(state.state_init_completed_callbacks)`
state.state_init_completed = state.check_required_params


# Model
class Net(nn.Module):
    """LeNet-style CNN that returns log-probabilities over 10 classes.

    The flattened feature size of 320 (= 20 channels * 4 * 4) corresponds to 1x28x28 inputs such as MNIST.
    """

    def __init__(self):
        super().__init__()
        # NOTE: keep this creation order; it fixes the parameter registration order and the random initialization
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # Stage 1: conv -> 2x2 max-pool -> ReLU
        h = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU
        h = self.conv2_drop(self.conv2(h))
        h = F.relu(F.max_pool2d(h, 2))
        # Classifier head on the flattened features (dropout only active in training mode)
        h = h.view(-1, 320)
        h = F.dropout(F.relu(self.fc1(h)), training=self.training)
        logits = self.fc2(h)
        return F.log_softmax(logits, dim=-1)


### ALL MODULES INTEGRATION ###
#==============================

# NOTE:
# - all `Module`s are INTEGRATED as `(Meta)StateComponent` in `state.modules`, meaning:
#       - important module parameters are integrated as `StateParameters`
#       - the module instance is assigned unchanged to e.g. `state.modules.model.component`, the `component` attribute
#       - `state.modules.model` behaves identically to the instance by forwarding commands to
#         `state.modules.model.component`
with state.modules as m:
    m.model = Net()
    # `.to()` instead of `.cuda()`: `.cuda()` only accepts CUDA devices and raises when `device_for_modules`
    # is a CPU device (or CUDA is unavailable); `.to()` handles CPU and CUDA devices alike
    m.model.to(state.device_for_modules)
    m.x_entropy = nn.CrossEntropyLoss()


### SYNC PARAMETERS ###
#======================

# NOTE: Non-shortcut form of the commands below:
#       `state.modules.model.device = state.configs.default_configs.device_for_modules_ref`
state.model.device = state.device_for_modules_ref
state.x_entropy.device = state.device_for_modules_ref


### HYPER PARAMETERS & ENGINE HELPER FUNCTIONS ###
#=================================================

# NOTE:
#   - same as with `state.configs` and `state.params` BUT...
#   - (non-callable) values are assigned to `state.engines.hyperparams` as `StateParameter` (will be created if not ...)
#   - functions decorated with `@e` are assigned to `state.engines.helpers` and are reached through `state`,
#     e.g. `state._prepare_batch`
with state.engines as e:
    e.lr = 0.01
    e.momentum = 0.5

    # NOTE: It is arguable how useful adding the functions to `state.engines.helpers` is, but it's definitely a feature
    #       for writing cleaner training scripts and structuring your thoughts
    @e
    def _prepare_batch(batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None,
                       non_blocking: bool = False):
        """Unpack an ``(x, y)`` batch and pass both tensors through `convert_tensor` with the given device options.

        Returns the converted pair as a 2-tuple ``(x, y)``.
        """
        inputs, targets = batch
        transfer_kwargs = {'device': device, 'non_blocking': non_blocking}
        return convert_tensor(inputs, **transfer_kwargs), convert_tensor(targets, **transfer_kwargs)

    @e
    def get_labels(value):
        """Return ``value[2]`` followed by ``value[1]`` (the reversed slice ``value[2:0:-1]``).

        For the engine outputs ``(x, y, y_pred, loss)`` of `_update` and ``(x, y, y_pred)`` of `_infer`
        this yields ``(y_pred, y)``.
        """
        y_pred_then_y = slice(2, 0, -1)
        return value[y_pred_then_y]

    @e
    def _update(engine):
        """Run one training step on ``engine.batch``.

        Puts `state.model` in train mode, computes `state.x_entropy`, back-propagates and steps `state.sgd`.
        Returns ``(x, y, y_pred, loss)``.
        """
        state.model.train()
        state.sgd.zero_grad()
        # Helpers registered with `@e` are called through `state` (shortcut of `state.engines.helpers`);
        # alternatively, split the `with state.engines` statement in two and call the helpers in the second one
        inputs, targets = state._prepare_batch(engine.batch,
                                               device=state.device_for_modules,
                                               non_blocking=state.non_blocking)
        predictions = state.model(inputs)
        batch_loss = state.x_entropy(predictions, targets)
        batch_loss.backward()
        state.sgd.step()
        return inputs, targets, predictions, batch_loss

    @e
    def _infer(engine):
        """Run `state.model` in eval mode without gradient tracking on ``engine.batch``; return ``(x, y, y_pred)``."""
        state.model.eval()
        with torch.no_grad():
            inputs, targets = state._prepare_batch(engine.batch,
                                                   device=state.device_for_modules,
                                                   non_blocking=state.non_blocking)
            return inputs, targets, state.model(inputs)


### ALL DATALOADERS ###
#======================

# NOTE:
# - `with`-statement useful for multiple state component/dataloader assignments
# - alternative: `state.dataloaders.trainer_loader = DataLoader(...)`
# - only the trainer dataset sets `download=True`; torchvision's `MNIST` download fetches the files of both
#   splits, so the test-split datasets below can use `download=False`
with state.dataloaders as d:
    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    # Trainer dataloader
    d.trainer_loader = DataLoader(MNIST(download=True, root=".", transform=data_transform, train=True),
                                  batch_size=state.train_batch_size, shuffle=True)
    # Xvalidator dataloader: batched with its own `xval_batch_size` parameter (not `eval_batch_size`)
    d.xvalidator_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                                     batch_size=state.xval_batch_size, shuffle=False)
    # Evaluator dataloader
    d.evaluator_loader = DataLoader(MNIST(download=False, root=".", transform=data_transform, train=False),
                                    batch_size=state.eval_batch_size, shuffle=False)



### PARAMETER SYNCHRONIZATION OF INTEGRATED MODULE ###
#=====================================================

# NOTE:
# - The suffix '_ref' of a state object returns a reference object of itself
# - If you set a state object with a reference of another state object, this state object is synced by the value
#   of the referenced state object
# - This is the same device sync as `state.model.device = state.device_for_modules_ref`, written with the longer
#   `state.default_configs` path
state.model.device = state.default_configs.device_for_modules_ref


### OPTIMIZER(S) ###
#===================

# NOTE:
# - `Optimizer`s are assigned to `state.optimizers`
# - can also be assigned using the `with`-statement
# - NO ADDITIONAL PARAMETERS OR HELPER FUNCTIONS SETTING IN `state.optimizers`
# The model is accessed via `state` (like everywhere else) rather than via the `m` alias that merely leaks out of
# the earlier `with state.modules as m:` statement
state.optimizers.sgd = SGD(state.model.parameters(), lr=state.lr, momentum=state.momentum)


### ENGINES ###
#==============

# Trainer
# NOTE:
# - Below, default values are used wherever possible.
# - The `Engine` is assigned to `state.engines.trainer`.
# - Passing e.g. `engine_run_started_ref=None` would mean that the calling state object is not set, so the `Engine`
#   would NOT be called. (Default values are always placeholdered by ``''``)
# - Here the trainer is attached to `state.state_run_started_ref`; presumably fired by `state.run()` — confirm.
Engine(name='trainer',
       process=state._update, dataloader=state.trainer_loader,
       engine_run_started_ref=state.state_run_started_ref)

# X-Validator
# NOTE:
# - To parametrize the engine starting callback use:
#   `engine_run_started_ref=state.trainer.get('n_samples_every', state.n_xval_step_samples, 'ref')`
# - NOTE(review): presumably this starts the xvalidator every `n_xval_step_samples` trainer samples and limits
#   each run to `n_xval_samples` samples — confirm against the `Engine` implementation.
Engine(name='xvalidator',
       process=state._infer,
       dataloader=state.xvalidator_loader,
       engine_run_started_ref=state.trainer.get('n_samples_every', state.n_xval_step_samples, 'ref'),
       n_samples=state.n_xval_samples)

# Evaluator
# NOTE: Here it is demonstrated how arguments can also be skipped during initialization and provided afterwards,
#       namely `dataloader` and `engine_run_started_ref`
Engine(name='evaluator',
       process=state._infer,
       n_epochs=state.n_eval_epochs)
# Setting up `state.evaluator`'s dataloader and callback after initialization: appends `state.evaluator.run` to the
# trainer's `engine_run_completed` callbacks; the tuple is presumably (description, callable, kwargs for the call)
state.trainer.engine_run_completed = ('starte manually appended `state.evaluator.run` with `state.evaluator_loader',
                                     state.evaluator.run, {'dataloader': state.evaluator_loader})

### OUTPUT HANDLERS ###
#======================

# One output handler per engine (trainer, xvalidator, evaluator), each transforming its engine's output
# with `get_labels`.
# NOTE:
# - `output_handler_name` is not set and defaults to `''`, which turns on auto-naming
# - of course the output handler is assigned automatically to `state.output_handlers`
# - To avoid bothering about the exact auto-generated name (which you can easily identify in the debugger), fetch it
#   through the `.name` attribute and store it in a variable, see below
#
# EXTRA NOTE: `caller_ref`, or in current Ignite terms `event attaching`
# Most features have the argument `caller_ref`, which refers to the event that calls/triggers e.g. calculating
# an output or logging a metric. Here the default value of `caller_ref` is `input_ref`, so each time a new input
# value is given by e.g. `state.trainer.output`, the trainer's output handler is automatically called/triggered
# and calculates its own new output. If the trainer metric's `caller_ref` is also left at its default, the new
# output handler output in turn calls/triggers the new calculation of the metric... and, of course, the same
# applies to the chart/tensorboard logger.
trainer_output_handler_name, xvalidator_output_handler_name, evaluator_output_handler_name = (
    OutputHandler(input_refs=engine.output_ref, transform_input_func=get_labels).name
    for engine in (state.trainer, state.xvalidator, state.evaluator)
)

print('trainer_output_handler = ' + trainer_output_handler_name)
print('xvalidator_output_handler = ' + xvalidator_output_handler_name)
print('evaluator_output_handler = ' + evaluator_output_handler_name)

### METRICS ###
#==============

# NOTE: Currently, the output handlers are not automated in finding the correct inputs and providing their outputs
#       in a standardized way, so that a metric could automatically pick the correct output values from the output
#       handlers, or directly from the engine output. This has to be implemented, see issue #???

# Trainer metrics complete every `n_metric_logging_step_samples` trainer samples. The ref is derived from the user
# parameter (same `get` pattern as the xvalidator's starting callback) instead of the hard-coded
# `state.trainer.n_samples_every_100_ref`, so changing the parameter actually changes the logging interval.
trainer_metric_completed_ref = state.trainer.get('n_samples_every', state.n_metric_logging_step_samples, 'ref')

# Trainer
trainer_loss_name = Loss(metric_name='loss',
                         loss_fn=state.x_entropy,
                         input_ref=state.get(trainer_output_handler_name).output_ref,
                         completed_ref=trainer_metric_completed_ref).name

# NOTE: If you know the output handler name you can access it directly as a shortcut attribute of `state` (see below)
#       or as a (real) attribute of `state.output_handlers`
trainer_accuracy_name = Accuracy(metric_name='accuracy',
                                 input_ref=state.trainer_get_labels.output_ref,
                                 completed_ref=trainer_metric_completed_ref).name

# X-validator
# NOTE: Here we assume we know the xvalidator output handler name `xvalidator_get_labels`
xvalidator_loss_name = Loss(metric_name='loss',
                            loss_fn=state.x_entropy,
                            input_ref=state.xvalidator_get_labels.output_ref,
                            completed_ref=state.xvalidator.engine_run_completed_ref).name

xvalidator_accuracy_name = Accuracy(metric_name='accuracy',
                                    input_ref=state.xvalidator_get_labels.output_ref,
                                    completed_ref=state.xvalidator.engine_run_completed_ref).name

# Evaluator
evaluator_loss_name = Loss(metric_name='loss',
                           loss_fn=state.x_entropy,
                           input_ref=state.get(evaluator_output_handler_name).output_ref,
                           completed_ref=state.evaluator.engine_run_completed_ref).name

evaluator_accuracy_name = Accuracy(metric_name='accuracy',
                                   input_ref=state.get(evaluator_output_handler_name).output_ref,
                                   completed_ref=state.evaluator.engine_run_completed_ref).name

### CHARTS ###
#=============

# NOTE: The following charts manually implement what 2 high-level-APIs with 1-liner-commands accomplish in
#       `mnist_with_tensorboard_and_high_level_apis.py`, just to demonstrate the possibilities of framework automation.

# MANUALLY SET CHART IDENTICAL TO THE COMMAND:
# `EnginesMetricsComparisonCharts(x_axis_ref=state.trainer.n_samples_ref, n_identical_metric_name_suffixes=1)`
# -------------------------------------------------------------------------------------------------------------

# Accuracy charts of all engines, 3 metrics in 1 chart
# NOTE: Here default values are used as much as possible, setting as little as necessary
# EXTRA NOTE: Leaving out the `chart_name` in charts with many `y_axes_refs` usually leads to very long chart names
ScalarsChart(x_axis_ref=state.trainer.n_samples_ref,
             y_axes_refs=[state.get(trainer_accuracy_name).output_ref,
                          state.get(xvalidator_accuracy_name).output_ref,
                          state.get(evaluator_accuracy_name).output_ref],
             caller_refs=[state.get(trainer_accuracy_name).output_ref,
                          state.xvalidator.engine_run_completed_ref,
                          state.evaluator.engine_run_completed_ref])

# Loss charts of all engines, 3 metrics in 1 chart
# NOTE:
# - Here the maximum of manual individualization is used; some arguments are redundant in that they
#   equal (or behave equally to) the default value, e.g. the `caller_refs` would result in the same chart.
# - The trainer caller ref is derived from `n_metric_logging_step_samples` (instead of the hard-coded
#   `state.trainer.n_samples_every_100_ref`), so it follows the same parameter as the trainer metrics' logging step.
# - The `walltime_ref` uses the state timer `state.maintenance.state_timer.state_run_started_time_in_sec`
ScalarsChart(x_axis_ref=state.trainer.n_samples_ref,
             y_axes_refs=[state.get(trainer_loss_name).output_ref,
                          state.get(xvalidator_loss_name).output_ref,
                          state.get(evaluator_loss_name).output_ref],
             caller_refs=[state.trainer.get('n_samples_every', state.n_metric_logging_step_samples, 'ref'),
                          state.xvalidator.engine_run_completed_ref,
                          state.evaluator.engine_run_completed_ref],
             y_names=['trainer_loss', 'xvalidator_loss', 'evaluator_loss'],
             chart_name='engines_loss_comparison',
             bso_ctr_name_suffix='manually_set_chart',
             summary_writer=state.tensorboard.summary_writer,
             summary_description='manually written description for summary adding',
             walltime_ref=state.state_run_started_time_in_sec_ref)


# MANUALLY SET CHART IDENTICAL TO THE COMMAND:
#
# `EnginesMetricsCharts(x_axes_refs=state.trainer.n_samples_ref, n_identical_metric_name_suffixes=1)`
# ---------------------------------------------------------------------------------------------------

# Trainer
ScalarChart(x_axis_ref=state.trainer.n_samples_ref, y_axis_ref=state.get(trainer_accuracy_name).output_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref, y_axis_ref=state.get(trainer_loss_name).output_ref)
# Xvalidator
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(xvalidator_accuracy_name).output_ref,
            caller_refs=state.xvalidator.engine_run_completed_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(xvalidator_loss_name).output_ref,
            caller_refs=state.xvalidator.engine_run_completed_ref)
# Evaluator
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(evaluator_accuracy_name).output_ref,
            caller_refs=state.evaluator.engine_run_completed_ref)
ScalarChart(x_axis_ref=state.trainer.n_samples_ref,
            y_axis_ref=state.get(evaluator_loss_name).output_ref,
            caller_refs=state.evaluator.engine_run_completed_ref)

# NOTE:
# - Did you notice... we got rid of the `other_engine`, even though we were using 3 engines :-)
# - Additionally, for these charts one would need to implement `another_x_axis` in the current Ignite, too.

### HARDWARE RESOURCE TRACKING TO BE IMPLEMENTED SOON ###
#========================================================
# NOTE: The current GPU tracker has the issue that it is triggered by `Event.ITERATION_COMPLETED`, logging
#       the GPU usage at that point (not even the max usage). Unfortunately, when the iteration is completed the
#       GPU calculations are also completed, so `GpuInfo` only measures the resource usage during downtime.
#       Therefore the new implementation will run on a separate thread with slightly randomized logging
#       timing.

# if sys.version_info > (3,):
#     from ignite.contrib.metrics.gpu_info import GpuInfo
#     try:
#         GpuInfo().attach(state.trainer)
#     except RuntimeError:
#         print("INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
#               "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
#               "install it : `pip install pynvml`")

# Start the run.
# NOTE(review): presumably this fires `state.state_run_started`, which the trainer `Engine` is attached to — confirm.
state.run()