from inspect import currentframe, signature, Signature, Parameter
from collections import OrderedDict
from functools import wraps
from types import FunctionType, MethodType, BuiltinFunctionType, BuiltinMethodType
from ignite_framework.exceptions import FrameworkNotImplementedError, FrameworkRestrictionError, FrameworkTypeError
from ignite_framework.exceptions import FeatureNotImplementedError
# ==================================================================================================================
### CONSTANTS ###
# ==================================================================================================================
# Types treated as "callables" by `isinstance` checks in this module. The last entry, `type(print)`
# (the `builtin_function_or_method` type), is the same object as `BuiltinFunctionType`; it is kept so
# the tuple keeps its original five entries.
CALLABLE_TYPES = (
    FunctionType,
    MethodType,
    BuiltinFunctionType,
    BuiltinMethodType,
    type(print),
)
# Name suffixes for "FSO" overloads (exact meaning of FSO is defined elsewhere in the framework).
FSO_OVERLOAD_SUFFIXES = ['ref', 'callbacks', 'once', 'every']
# ==================================================================================================================
### CLASSES ###
# ==================================================================================================================
# class AttributeToItemDictionary(dict):
# def __init__(self):
# self['decorator_args'] = {}
#
# def __getattr__(self, key):
# return self[key]
#
# def __setattr__(self, key, value):
# self[key] = value
class _AttributeKeyError(KeyError, AttributeError):
    """Lookup failure raised by attribute-style access on ``DecoratorAndAttributeToItemDictionary``.

    It is a ``KeyError`` (what ``__getattr__`` historically raised, so existing ``except KeyError``
    handlers keep working) AND an ``AttributeError`` (what ``hasattr``, ``getattr(obj, name, default)``
    and the ``copy`` module expect from a failing ``__getattr__``).
    """


class DecoratorAndAttributeToItemDictionary(dict):
    """Dictionary whose items are also readable as attributes and which doubles as a decorator.

    Functions registered through ``__call__`` (decorator usage) or assigned as attributes with a
    value of one of ``CALLABLE_TYPES`` are stored in the nested dict ``self['decorator_args']``
    under their name; every other attribute assignment becomes a plain item of the dictionary.
    """

    def __init__(self):
        self['decorator_args'] = {}

    def __call__(self, func):
        """Register ``func`` in ``decorator_args`` under ``func.__name__`` and return it unchanged."""
        self['decorator_args'][func.__name__] = func
        return func

    def __getattr__(self, key):
        """Resolve ``self.<key>`` from the items first, then from ``self['decorator_args']``.

        Only reached when normal attribute lookup fails, so regular ``dict`` methods are unaffected.

        Raises:
            _AttributeKeyError: (both ``KeyError`` and ``AttributeError``) if ``key`` is found in
                neither place (also when ``decorator_args`` itself is missing, e.g. on an instance
                created without ``__init__``).
        """
        try:
            return self[key]
        except KeyError:
            pass
        try:
            return self['decorator_args'][key]
        except KeyError:
            raise _AttributeKeyError(key) from None

    def __setattr__(self, key, value):
        """Store callables (``CALLABLE_TYPES``) in ``decorator_args`` and any other value as an item."""
        if isinstance(value, CALLABLE_TYPES):
            self['decorator_args'][key] = value
        else:
            self[key] = value
class CallbacksList(list):
    """List of callbacks that validates/normalizes entries inserted via ``append`` or item assignment.

    The per-entry hook ``self._type_check_and_format(object)`` is not defined in this class; it is
    presumably provided by a subclass or elsewhere in the framework (TODO confirm) and is expected
    to return the (possibly reformatted) entry that gets stored.
    """

    def __setitem__(self, key, value):
        """Validate/format ``value`` exactly like ``append`` does, then assign the result at ``key``.

        For slice assignment every element of ``value`` is validated/formatted individually.
        """
        # `_type_check_and_format` is already bound to `self`; passing `self` again (as before)
        # supplied the `object` argument twice. The formatted entry is stored, like in `append`.
        if isinstance(key, slice):
            formatted = [self._type_check_and_format(item) for item in value]
        else:
            formatted = self._type_check_and_format(value)
        list.__setitem__(self, key, formatted)

    def append(self, object):
        """Validate/format ``object`` and append the result.

        Type checking is repeated here (the owning state object presumably checks as well) because
        the callback list is exposed to the user with full list functionality.

        Args:
            object: callback entry to append.

        Returns:
            The original, unformatted ``object``.
        """
        super().append(self._type_check_and_format(object))
        # TODO:
        # - This return is required for appending callbacks with decorator `@state.<fso_name>_callbacks`
        # - Discuss if necessary
        return object

    def reorder(self, order_idcs):
        """Reorder the callbacks in place: new position ``i`` receives the old entry ``order_idcs[i]``.

        Args:
            order_idcs: sequence of indices into the current list, one per callback.

        Raises:
            FrameworkTypeError: if ``len(order_idcs)`` differs from the number of callbacks.
            IndexError: if an index is out of range; the list is left unchanged in that case.
        """
        list_length = len(self)
        if len(order_idcs) != list_length:
            raise FrameworkTypeError(
                'To reorder the callbacks the length of `order_idcs` must be equal to the callbacks\' length of {}, '
                'but given length of `order_idcs` is {}.'.format(list_length, len(order_idcs)))
        # Snapshot before mutating: indexing while appending made negative indices refer to freshly
        # appended entries and left a half-mutated list on IndexError. The entries were validated
        # when inserted, so they are written back without running the type check again.
        reordered = [self[idx] for idx in order_idcs]
        list.__setitem__(self, slice(None), reordered)

    def _raise_error(self, object):
        """Raise ``FrameworkRestrictionError`` explaining why ``object`` is not a valid callback entry."""
        try:
            got = tuple(type(obj) for obj in object)
        except TypeError:
            # `object` is not iterable: report its own type instead of masking the error.
            got = type(object)
        raise FrameworkRestrictionError(
            'Cannot append `object` to the callback list, because its format cannot be converted '
            'into a callback tuple. A callback tuple is of length `4` '
            'with `object[0]` of type `{}` and `object[1]` of type `dict`, but got: '
            '`{}`'.format(CALLABLE_TYPES, got))
class SingleKeyAssignmentDictMixin:
    """Mixin for mappings whose keys can be assigned only once (no silent overwriting).

    ``del d['_delete_item_' + key]`` removes ``key`` from the mapping. A plain ``del d[key]`` does
    NOT remove the item; it calls ``object.__delattr__(d[key], key)`` on the stored value instead
    (presumably the value exposes a same-named descriptor whose deleter performs the prefixed
    deletion — TODO confirm against the state objects using this mixin).
    """

    # Prefix that marks a deletion request for the mapping item itself.
    _DELETE_ITEM_PREFIX = '_delete_item_'

    def __setitem__(self, key, value):
        """Assign ``value`` to a key that is not present yet.

        Raises:
            FrameworkRestrictionError: if ``key`` already exists.
        """
        if key in self:
            raise FrameworkRestrictionError(
                'The key value `{}` of `{}()` is read-delete-only, overwriting is forbidden. To change the '
                'value explicitly delete the item and reassign the key to the new value.'.format(
                    key, self.__class__.__name__))
        super().__setitem__(key, value)

    def __delitem__(self, key):
        """Delete the item for prefixed keys, otherwise delegate to ``object.__delattr__`` on the value."""
        prefix = self._DELETE_ITEM_PREFIX
        if isinstance(key, str) and key.startswith(prefix):
            # `super()` rather than `dict.__delitem__` so that `OrderedDict`-based subclasses also
            # update their internal ordering (calling `dict.__delitem__` directly corrupts it).
            super().__delitem__(key[len(prefix):])
        else:
            object.__delattr__(self[key], key)


class SingleKeyAssignmentDict(SingleKeyAssignmentDictMixin, dict):
    """``dict`` with write-once keys (see ``SingleKeyAssignmentDictMixin``)."""


class OrderedSingleKeyAssignmentDict(SingleKeyAssignmentDictMixin, OrderedDict):
    """``OrderedDict`` with write-once keys (see ``SingleKeyAssignmentDictMixin``)."""
# ==================================================================================================================
### DECORATORS ###
# ==================================================================================================================
# #TODO: - remove argument `name` if not required till PR
# def namespace_decorator(name, bases, namespace):
# def conditional_namespace_item(method=None, only_under_condition=True):
# # Indirect `base.__dict__.keys()`-checking instead of `hasattr(base, method.__name__)` to override
# # non-intentional/implicit methods of base classes, e.g. '__init__' from `object` which is not listed in
# # `__dict__`
# def namespace_item(method):
# if not (any(
# [method.__name__ in base.__dict__.keys() for base in bases ]) or method.__name__ in namespace.keys()):
# namespace[method.__name__] = method
# return method
#
# # # Always add given `method` to `namespace` if no `only_for_containers` is given
# # if only_under_condition:
# # if method is None:
# # # Provoke excepted argument error because no `method` was given
# # namespace_item()
# # else:
# # # Add given `method` to `namespace`
# # namespace_item(method)
# # return method
# # # Add given `method` to `namespace` only if the current container is in the given `only_for_containers`
# # # else:
# # # If current container (class) in `only_for_containers`
# if only_under_condition:
# # If arguments only partially filled
# if method is None:
# # Return decorator for standard method decoration
# return namespace_item
# else:
# # Directly add `method` to `namespace`
# namespace_item(method)
# return method
# else:
# # Do NOT add `method` to `namespace` because current container should not include this method
# # Return identity-decorator, whoch does nothing
# return lambda x:x
#
# return conditional_namespace_item
def namespace_decorator(name, bases, namespace):
    """Build a decorator that adds methods to a class namespace without overriding existing members.

    A method is added to ``namespace`` only if neither ``namespace`` itself nor the OWN ``__dict__``
    of any class in ``bases`` already defines that name. Checking ``base.__dict__`` instead of
    ``hasattr`` lets implicit members of base classes (e.g. ``object.__init__``, which is not in a
    subclass's ``__dict__``) be provided.

    Args:
        name: name of the class (container) being built; compared against ``only_for_ctrs``.
        bases: base classes of the class being built.
        namespace: mutable mapping of the class namespace; selected methods are inserted into it.

    Returns:
        ``conditional_namespace_item(method=None, only_for_ctrs=None)``:

        * ``@deco`` / ``deco(method)``: add ``method`` (if not already defined) and return it.
        * ``@deco(only_for_ctrs=[...])`` / ``deco(method, only_for_ctrs=[...])``: same, but only if
          ``name`` is in ``only_for_ctrs``; otherwise ``method`` is returned without being added.
          A single string is treated as one container name (not substring-matched).
    """
    def namespace_item(method):
        # staticmethod/classmethod wrappers carry the name on the wrapped function.
        if isinstance(method, (staticmethod, classmethod)):
            method_name = method.__func__.__name__
        else:
            method_name = method.__name__
        defined_in_bases = any(method_name in base.__dict__ for base in bases)
        if not (defined_in_bases or method_name in namespace):
            namespace[method_name] = method
        return method

    def conditional_namespace_item(method=None, only_for_ctrs=None):
        # Unconditional form: a method is mandatory.
        if only_for_ctrs is None:
            if method is None:
                raise TypeError('`method` is required when `only_for_ctrs` is not given.')
            return namespace_item(method)
        # `name in 'SomeString'` would be a substring test; treat a string as a single container name.
        if isinstance(only_for_ctrs, str):
            only_for_ctrs = (only_for_ctrs,)
        if name in only_for_ctrs:
            # Decorator factory form (`method is None`) or direct call.
            return namespace_item if method is None else namespace_item(method)
        # Current container is excluded: do not touch `namespace`, but still hand back the method
        # (or an identity decorator in the factory form).
        return (lambda x: x) if method is None else method

    return conditional_namespace_item
# def namespace_decorator(bases, namespace):
# def namespace_item(method):
# # Indirect `base.__dict__.keys()`-checking instead of `hasattr(base, method.__name__)` to override
# # non-intentional/implicit methods of base classes, e.g. '__init__' from `object` which is not listed in
# # `__dict__`
# if not (any(
# [method.__name__ in base.__dict__.keys() for base in bases]) or method.__name__ in namespace.keys()):
# namespace[method.__name__] = method
#
# return namespace_item
# ==================================================================================================================
### FUNCTIONS ###
# ==================================================================================================================
def bases_have_attr(bases, name):
    """Return ``True`` if any class in ``bases`` defines ``name`` in its OWN ``__dict__``.

    Unlike ``hasattr``, members merely inherited by a base (e.g. ``object.__init__``) do not count.
    """
    for base in bases:
        if name in base.__dict__:
            return True
    return False
def get_argument_name_from_frame_inspection(name, obj, default_name='obj'):
    """Return the global variable name under which ``obj`` is bound in the CALLER's module.

    Args:
        name: name of the argument being resolved; used only in the error message.
        obj: object whose variable name is looked up (compared by identity).
        default_name: unused; kept for backward compatibility.

    Returns:
        The single name in the calling frame's globals bound to ``obj``; if there is not exactly
        one such name, ``obj.__name__``.

    Raises:
        RuntimeError: if no unique global name exists and ``obj`` has no ``__name__``.
    """
    # Must be evaluated directly in this function: `f_back` is the frame of our caller.
    caller_globals = currentframe().f_back.f_globals
    # Dict keys are already unique, so no de-duplication is needed.
    argument_names = [key for key, value in caller_globals.items() if value is obj]
    if len(argument_names) == 1:
        return argument_names[0]
    try:
        return obj.__name__
    except AttributeError:
        raise RuntimeError('Argument name of `{}` could not be identified correctly. List of identified '
                           'object keys: {}'.format(name, argument_names)) from None
def raise_run_not_implemented_and_appendable_to_callback_error(transition):
    """Raise because the unimplemented ``transition.run`` was appended to a callback list.

    Args:
        transition: the instance whose ``run`` method was appended.

    Raises:
        FrameworkTypeError: always. (Previously a ``.fomrat`` typo raised ``AttributeError`` instead.)
    """
    raise FrameworkTypeError(
        'The method `{inst}.run()` of class `{cls}` is not implemented but was appended to a callback '
        'list. All required methods of `{cls}` have already been appended to callbacks during '
        'initialization. Please delete the code line that appends `{inst}.run` to a callback list.'
        .format(inst=transition, cls=transition.__class__.__name__))
def raise_or_warn_abstract_method_not_implemented_error(instance, method, error_or_warn=FeatureNotImplementedError):
    """Report that the abstract ``method`` was not overridden before its class was initialized.

    Args:
        instance: object whose class should override ``method``; ``None`` is rendered as ``cls``.
        method: the abstract function/method; ``staticmethod``/``classmethod`` objects are unwrapped
            to obtain the name.
        error_or_warn: an exception class (raised with the message) or any other callable
            (called with the message, e.g. a warning/logging function).

    Raises:
        error_or_warn: when it is an exception class (the default).
    """
    msg = '`{c}.{f}` is an abstract method that must be overridden before ' \
          '`{c}` is initialized.'.format(
              c='cls' if instance is None else instance.__class__.__name__,
              f=method.__func__.__name__ if isinstance(method, (staticmethod, classmethod)) else method.__name__)
    # `error_or_warn` is normally a CLASS: `isinstance(cls, BaseException)` is always False, which
    # made the default silently construct and discard the exception; `issubclass` is required.
    if isinstance(error_or_warn, type) and issubclass(error_or_warn, BaseException):
        raise error_or_warn(msg)
    else:
        error_or_warn(msg)
# ==================================================================================================================
### STATE COMPONENT INTEGRATION GETTER & SETTER: `instance_attr__get__`, `instance_attr__set__` ###
# ==================================================================================================================
### For simple attributes ###
def get_component_attr(self, instance, owner):
    """``instance_attr__get__`` hook: read attribute ``self.name`` from ``instance.component``.

    Used to integrate a simple attribute of a wrapped external object (e.g. a PyTorch
    ``Module.training``) as a state parameter of a state component such as ``state.modules.model``.

    Args:
        self: state object (e.g. ``StateVariable()``) whose ``name`` is the attribute to read.
        instance: state-object container (e.g. ``state.engines``) holding the wrapped ``component``.
        owner: owner class of ``self``; not used.

    Returns:
        The current value of ``instance.component.<self.name>``.
    """
    wrapped = instance.component
    attr_name = self.name
    return getattr(wrapped, attr_name)
def set_component_attr(self, instance, value):
    """``instance_attr__set__`` hook: write ``value`` to attribute ``self.name`` of ``instance.component``.

    Args:
        self: state object (e.g. ``StateVariable()``) whose ``name`` is the attribute to write.
        instance: state-object container (e.g. ``state.engines``) holding the wrapped ``component``.
        value: value to assign.
    """
    target = instance.component
    setattr(target, self.name, value)
### For module device ###
def get_module_component_device(self, instance, owner):
    """``instance_attr__get__`` hook for the ``device`` parameter (``self.name``) of a module component.

    The wrapped module offers no getter for its device, so the setter records the last device in
    ``instance.__dict__`` and this getter simply returns that recorded value.

    Args:
        self: state object whose ``name`` is the key of the recorded device.
        instance: state-object container whose ``__dict__`` holds the recorded device.
        owner: owner class of ``self``; not used.

    Returns:
        The device last stored under ``self.name`` (``KeyError`` if none was stored yet).
    """
    recorded = instance.__dict__
    return recorded[self.name]
def set_module_component_device(self, instance, value):
    """``instance_attr__set__`` hook for the ``device`` parameter (``self.name``) of a module component.

    There is no direct ``device`` attribute on e.g. a PyTorch ``Module``, so the component's
    ``.to(device=...)`` method is called instead and the value is recorded for the getter.

    Args:
        self: state object whose ``name`` is the key under which the device is recorded.
        instance: state-object container holding the wrapped ``component``.
        value: the device to move the component to.
    """
    component = instance.component
    component.to(device=value)
    # The module exposes no single "current device", so remember the last one in
    # `instance.__dict__` for the matching getter.
    instance.__dict__[self.name] = value