Shortcuts

Source code for ignite_framework.framework.framework_utils

from inspect import currentframe, signature, Signature, Parameter
from collections import OrderedDict
from functools import wraps
from types import FunctionType, MethodType, BuiltinFunctionType, BuiltinMethodType

from ignite_framework.exceptions import FrameworkNotImplementedError, FrameworkRestrictionError, FrameworkTypeError

from ignite_framework.exceptions import FeatureNotImplementedError



# ==================================================================================================================
### CONSTANTS ###
# ==================================================================================================================



# Types that are treated as plain callables by the type checks in this module.
# NOTE: `type(print)` is the builtin function type and is listed explicitly in addition to
# `BuiltinFunctionType`/`BuiltinMethodType`.
CALLABLE_TYPES = (
    FunctionType,
    MethodType,
    BuiltinFunctionType,
    BuiltinMethodType,
    type(print),
)

# NOTE(review): presumably the name suffixes of the overloaded attributes of framework state objects
# (`<fso_name>_ref`, `<fso_name>_callbacks`, ...) - confirm against the users of this constant.
FSO_OVERLOAD_SUFFIXES = ['ref', 'callbacks', 'once', 'every']

# ==================================================================================================================
### CLASSES ###
# ==================================================================================================================



# class AttributeToItemDictionary(dict):
#     def __init__(self):
#         self['decorator_args'] = {}
#
#     def __getattr__(self, key):
#         return self[key]
#
#     def __setattr__(self, key, value):
#         self[key] = value


class DecoratorAndAttributeToItemDictionary(dict):
    """
    Dictionary whose items can be accessed as attributes and which can be used as a decorator.

    Callables are collected in the nested dictionary stored under the key `'decorator_args'`, either by
    decorating them with the instance (`@instance`) or by assigning them as attributes. All other attribute
    assignments are stored as regular dictionary items.
    """

    def __init__(self):
        # Registry of all callables that were added by decoration or attribute assignment
        self.update(decorator_args={})

    def __call__(self, func):
        """Register `func` under its `__name__` in `self['decorator_args']` and return it (decorator usage)."""
        registry = self['decorator_args']
        func_name = func.__name__
        registry[func_name] = func
        return registry[func_name]

    def __getattr__(self, key):
        """
        Resolve unknown attributes as dictionary items first and as registered callables second.

        Raises:
            KeyError: if `key` is neither an item nor a registered callable.
        """
        try:
            return self[key]
        except KeyError as error_msg:
            # Not a regular item, fall back to the callables registry
            try:
                registry = self['decorator_args']
                return registry[key]
            except KeyError:
                raise KeyError(error_msg)

    def __setattr__(self, key, value):
        """Store callables in `self['decorator_args']` and every other value as regular dictionary item."""
        target = self['decorator_args'] if isinstance(value, CALLABLE_TYPES) else self
        target[key] = value
class CallbacksList(list):
    """
    List of callbacks that normalizes its items to tuples `(description, callable, args, kwargs)`.

    Items passed to `append` and `__setitem__` are type checked and reformatted by
    `_type_check_and_format`. Accepted are plain callables (see `CALLABLE_TYPES`), objects offering a `run`
    method and a `name` attribute, and tuples accepted by `check_if_tuple_is_callback_and_reformat`.
    """

    def __setitem__(self, key, value):
        """
        Type check and format `value` before assigning it.

        Args:
            key: index or slice to assign to
            value: a single callback for an index `key`, an iterable of callbacks for a slice `key`

        Raises:
            FrameworkRestrictionError: if a callback cannot be converted into a callback tuple.
        """
        if isinstance(key, slice):
            value = [self._type_check_and_format(item) for item in value]
        else:
            value = self._type_check_and_format(value)
        list.__setitem__(self, key, value)

    def append(self, object):
        """
        Type check and format `object` and append the resulting callback tuple.

        The type checking is redundant to the checks of the framework's callback overloads, but is
        required as the callback list is exposed to the user with full list functionality.

        Args:
            object: callable, object with `run` method or callback tuple

        Returns:
            the unchanged `object`

        Raises:
            FrameworkRestrictionError: if `object` cannot be converted into a callback tuple.
        """
        super().append(self._type_check_and_format(object))
        # TODO:
        #  - This return is required for appending callbacks with decorator `@state.<fso_name>_callbacks`
        #  - Discuss if necessary
        return object

    def _type_check_and_format(self, object):
        """
        Convert `object` into a callback tuple `(description, callable, args, kwargs)`.

        Args:
            object: callable, object with `run` method (and `name` attribute) or callback tuple

        Returns:
            the callback tuple

        Raises:
            FrameworkRestrictionError: if `object` has none of the supported formats.
        """
        # Plain callables are checked first, so a callable with a `run` attribute counts as callable
        if isinstance(object, CALLABLE_TYPES):
            return ('', object, [], {})
        if hasattr(object, 'run'):
            return (object.name + '.run', object.run, [], {})
        if isinstance(object, tuple):
            new_callback = check_if_tuple_is_callback_and_reformat(value=object)
            if new_callback is not None:
                return new_callback
        self._raise_error(object)

    def reorder(self, order_idcs):
        """
        Reorder the callbacks in place so that the new item `i` is the old item `order_idcs[i]`.

        The new order is assembled completely before the list is modified, so the list stays unchanged if
        an index is invalid.

        Args:
            order_idcs: sequence of indices with the same length as the callback list

        Raises:
            FrameworkTypeError: if the length of `order_idcs` differs from the length of the list.
        """
        list_length = len(self)
        if len(order_idcs) != list_length:
            raise FrameworkTypeError(
                'To reorder the callbacks the length of `order_idcs` must be equal to the callbacks\' length of {}, '
                'but given length of `order_idcs` is {}.'.format(list_length, len(order_idcs)))
        reordered = [self._type_check_and_format(self[idx]) for idx in order_idcs]
        # Bypass the overridden `__setitem__` as the items have just been checked
        list.__setitem__(self, slice(None), reordered)

    def _raise_error(self, object):
        """Raise the `FrameworkRestrictionError` describing the expected callback format and the given types."""
        if isinstance(object, tuple):
            received = tuple(type(item) for item in object)
        else:
            received = type(object)
        raise FrameworkRestrictionError(
            'Cannot append `object` to the callback list, because its format cannot be converted '
            'into a callback tuple. A callback tuple is of the form `(description, callable, args, kwargs)` '
            'with `description` of type `str`, `callable` of one of the types `{}`, `args` of type `list` '
            'and `kwargs` of type `dict`, but got: `{}`'.format(CALLABLE_TYPES, received))
class SingleKeyAssignmentDictMixin:
    """
    Mixin for dictionary classes whose keys may be assigned only once.

    Must be listed before the dictionary class in the bases, e.g.
    `class MyDict(SingleKeyAssignmentDictMixin, dict)`, as the item handling is delegated with `super()`.
    """

    def __setitem__(self, key, value):
        """
        Assign `value` to the new key `key`.

        Raises:
            FrameworkRestrictionError: if `key` already exists.
        """
        if key in self:
            raise FrameworkRestrictionError(
                'The key value `{}` of `{}()` is read-delete-only, overwriting is forbidden. To change the '
                'value explicitly delete the item and reassign the key to the new value.'.format(
                    key, self.__class__.__name__))
        super().__setitem__(key, value)

    def __delitem__(self, key):
        """
        Delete the item named by `key`.

        A key with the prefix `'_delete_item_'` removes the item stored under the remaining key from the
        dictionary. For any other key the dictionary is left unchanged and the attribute named `key` is
        deleted from the object stored under `key`.
        """
        prefix = '_delete_item_'
        if key[:len(prefix)] == prefix:
            # `super()` instead of `dict.__delitem__`: the latter bypasses the bookkeeping of an
            # `OrderedDict` base and leaves it in an inconsistent state
            super().__delitem__(key[len(prefix):])
        else:
            object.__delattr__(self[key], key)
class SingleKeyAssignmentDict(SingleKeyAssignmentDictMixin, dict):
    """`dict` with the item assignment and deletion rules of `SingleKeyAssignmentDictMixin`."""
class OrderedSingleKeyAssignmentDict(SingleKeyAssignmentDictMixin, OrderedDict):
    """`OrderedDict` with the item assignment and deletion rules of `SingleKeyAssignmentDictMixin`."""
# ================================================================================================================== ### DECORATORS ### # ================================================================================================================== # #TODO: - remove argument `name` if not required till PR # def namespace_decorator(name, bases, namespace): # def conditional_namespace_item(method=None, only_under_condition=True): # # Indirect `base.__dict__.keys()`-checking instead of `hasattr(base, method.__name__)` to override # # non-intentional/implicit methods of base classes, e.g. '__init__' from `object` which is not listed in # # `__dict__` # def namespace_item(method): # if not (any( # [method.__name__ in base.__dict__.keys() for base in bases ]) or method.__name__ in namespace.keys()): # namespace[method.__name__] = method # return method # # # # Always add given `method` to `namespace` if no `only_for_containers` is given # # if only_under_condition: # # if method is None: # # # Provoke excepted argument error because no `method` was given # # namespace_item() # # else: # # # Add given `method` to `namespace` # # namespace_item(method) # # return method # # # Add given `method` to `namespace` only if the current container is in the given `only_for_containers` # # # else: # # # If current container (class) in `only_for_containers` # if only_under_condition: # # If arguments only partially filled # if method is None: # # Return decorator for standard method decoration # return namespace_item # else: # # Directly add `method` to `namespace` # namespace_item(method) # return method # else: # # Do NOT add `method` to `namespace` because current container should not include this method # # Return identity-decorator, whoch does nothing # return lambda x:x # # return conditional_namespace_item #TODO: - remove argument `name` if not required till PR
def namespace_decorator(name, bases, namespace):
    """
    Create a decorator that adds methods to `namespace` unless they are already defined.

    A method is added to `namespace` only if its name is neither a key of `namespace` nor a key of the
    `__dict__` of any class in `bases`.

    Args:
        name: name of the current container (class), compared with `only_for_ctrs` of the decorator
        bases: iterable of base classes whose `__dict__` is searched for already defined methods
        namespace: dictionary the methods are added to

    Returns:
        the decorator `conditional_namespace_item(method=None, only_for_ctrs=None)`:

        - `decorator(method)` conditionally adds `method` and returns it
        - `decorator(only_for_ctrs=...)` returns a decorator that conditionally adds the decorated method if
          `name` is in `only_for_ctrs`, otherwise a decorator that does nothing
        - `decorator(method, only_for_ctrs=...)` conditionally adds `method` if `name` is in
          `only_for_ctrs` and returns `method` in any case
    """

    def namespace_item(method):
        # `staticmethod`/`classmethod` objects provide the name of the wrapped function via `__func__`
        if isinstance(method, (staticmethod, classmethod)):
            method_name = method.__func__.__name__
        else:
            method_name = method.__name__
        # Indirect `base.__dict__`-checking instead of `hasattr(base, method_name)` to override
        # non-intentional/implicit methods of base classes, e.g. `__init__` from `object` which is not listed
        # in `__dict__`
        defined_in_bases = any(method_name in base.__dict__ for base in bases)
        if not defined_in_bases and method_name not in namespace:
            namespace[method_name] = method
        return method

    def conditional_namespace_item(method=None, only_for_ctrs=None):
        # Current container should not include this method: return identity-decorator which does nothing
        if only_for_ctrs is not None and name not in only_for_ctrs:
            return lambda x: x
        if method is None:
            if only_for_ctrs is None:
                # Neither a method nor a container restriction was given
                raise TypeError('namespace_item() missing 1 required positional argument: \'method\'')
            # Arguments only partially filled: return decorator for standard method decoration
            return namespace_item
        # Directly add `method` to `namespace`
        namespace_item(method)
        return method

    return conditional_namespace_item
# def namespace_decorator(bases, namespace): # def namespace_item(method): # # Indirect `base.__dict__.keys()`-checking instead of `hasattr(base, method.__name__)` to override # # non-intentional/implicit methods of base classes, e.g. '__init__' from `object` which is not listed in # # `__dict__` # if not (any( # [method.__name__ in base.__dict__.keys() for base in bases]) or method.__name__ in namespace.keys()): # namespace[method.__name__] = method # # return namespace_item # ================================================================================================================== ### FUNCTIONS ### # ==================================================================================================================
def bases_have_attr(bases, name):
    """
    Check if any class in `bases` defines `name` directly in its own `__dict__`.

    Args:
        bases: iterable of classes
        name: attribute name to look up

    Returns:
        `True` if `name` is a key of the `__dict__` of at least one class in `bases`, otherwise `False`.
        Attributes that a base only inherits are not part of its `__dict__` and are therefore not found.
    """
    for base in bases:
        if name in base.__dict__.keys():
            return True
    return False
def reformat_as_list(*args):
    """
    Convert every given iterable into a new list.

    Args:
        *args: iterables to convert, given lists are copied

    Returns:
        the new list if exactly one argument is given, otherwise a tuple containing one new list per
        argument (an empty tuple if no argument is given).
    """
    as_lists = [list(arg) for arg in args]
    if len(as_lists) == 1:
        return as_lists[0]
    return tuple(as_lists)
def check_if_tuple_is_callback_and_reformat(value):
    """
    Identify if `value` is a new callback before it is appended to a callback list of e.g. a `StateVariable` or
    `StateParameter` and convert it into the callback format. Implemented to return quickly in case `value` is
    not a new callback.

    Accepted layout of `value`, every element but the callable is optional:
    `([description,] callable [, args] [, kwargs])` with `description` of type `str`, `args` of type `list` and
    `kwargs` of type `dict`. Instead of the callable an object with a `run` method may be given, which is then
    replaced by its `run` method.

    Args:
        value: the sequence (usually a tuple) to check

    Returns:
        The callback tuple `(description, callable, args, kwargs)` or `None` if `value` is not a valid callback.
        Missing `args`/`kwargs` are replaced by `[]`/`{}`, a missing description is built from the name of the
        callable.
    """
    # check if callable is in tuple at correct position
    # Note: `n_identified_vals` starts with `1` to include identified callable
    n_identified_vals = 1
    # Partially build possibly missing description
    description = ''
    callable_idx = [idx for idx, val in enumerate(value) if isinstance(val, CALLABLE_TYPES)]
    # If no callable types is given...
    if not callable_idx:
        # ... check for transitions (state object containers) with `run` method
        callable_idx = [idx for idx, val in enumerate(value) if hasattr(val, 'run')]
        # Exchange any fso-container with its method `<fso_ctr>.run`
        if len(callable_idx) == 1:
            fso_ctr_idx = callable_idx[0]
            # Partially build possibly missing description: prefix with the name of the container
            description += 'state.{}.'.format(value[fso_ctr_idx].name)
            # Tuples are immutable, therefore exchange the element in a temporary list
            value = list(value)
            value.insert(fso_ctr_idx, value.pop(fso_ctr_idx).run)
            value = tuple(value)
    # Exactly one callable is required and it must be located at one of the first positions
    if len(callable_idx) != 1 or callable_idx[0] > 2:
        return None
    else:
        # Check if callback description is in tuple
        description_idx = [idx for idx, val in enumerate(value) if isinstance(val, str)]
        # Only 1 description allowed
        if len(description_idx) > 1:
            return None
        # Check if args is in tuple
        args_idx = [idx for idx, val in enumerate(value) if isinstance(val, list)]
        # Only 1 args is allowed
        if len(args_idx) > 1:
            return None
        # Check if kwargs is in tuple
        kwargs_idx = [idx for idx, val in enumerate(value) if isinstance(val, dict)]
        # Only one kwargs is allowed
        if len(kwargs_idx) > 1:
            return None
        if len(description_idx) == 1:
            # `description` must be first element
            if description_idx[0] != 0:
                return None
            # A given description replaces the partially built one
            description = value[description_idx[0]]
            n_identified_vals += 1
        else:
            try:
                # ...finalized description
                description += value[callable_idx[0]].__name__ + '()'
            except AttributeError:
                # Reset partial description if finalizing description fails
                description = ''
        if len(args_idx) == 1:
            # `args` must be behind callable
            if args_idx[0] < callable_idx[0]:
                return None
            args = value[args_idx[0]]
            n_identified_vals += 1
        else:
            args = []
        if len(kwargs_idx) == 1:
            # `kwargs` must be last element
            if kwargs_idx[0] != len(value) - 1:
                return None
            kwargs = value[kwargs_idx[0]]
            n_identified_vals += 1
        else:
            kwargs = {}
        # Any element that could not be identified invalidates the callback
        if n_identified_vals != len(value):
            return None
        return (description, value[callable_idx[0]], args, kwargs)
def get_argument_name_from_frame_inspection(name, obj, default_name='obj'):
    """
    Determine the name under which `obj` is known in the global namespace of the caller.

    Args:
        name: name of the argument, only used for the error message
        obj: the object whose name is looked up
        default_name: currently unused, kept for backward compatibility of the signature

    Returns:
        The key of the caller's globals that refers to `obj` if there is exactly one such key, otherwise
        `obj.__name__`.

    Raises:
        RuntimeError: if `obj` is not referred to by exactly one global name of the caller and has no
            attribute `__name__`.
    """
    f_back_globals = currentframe().f_back.f_globals
    # Identity comparison, equal but distinct objects must not match
    argument_names = [key for key, val in f_back_globals.items() if val is obj]
    if len(argument_names) == 1:
        return argument_names[0]
    try:
        return obj.__name__
    except AttributeError:
        raise RuntimeError('Argument name of `{}` could not be identified correctly. List of identified '
                           'object keys: {}'.format(name, argument_names))
def raise_run_not_implemented_and_appendable_to_callback_error(transition):
    """
    Raise the error for a `run` method that is not implemented but was appended to a callback list.

    Args:
        transition: the instance whose `run` method was appended to a callback list

    Raises:
        FrameworkTypeError: always.
    """
    raise FrameworkTypeError(
        'The method `{inst}.run()` of class `{cls}` is not implemented but was appended to a callback '
        'list. All required methods of `{cls}` have already been appended to callbacks during '
        'initialization. Please delete the code line that appends `{inst}.run` to a callback list.'
        ''.format(inst=transition, cls=transition.__class__))
def raise_or_warn_abstract_method_not_implemented_error(instance, method, error_or_warn=FeatureNotImplementedError):
    """
    Raise an error or emit a warning stating that the abstract method `method` must be overridden.

    Args:
        instance: the instance owning `method`, or `None` if no instance is available (the message then
            refers to `cls`)
        method: the abstract method, `staticmethod` objects are supported
        error_or_warn: either an exception class, which is raised with the message, or any callable taking
            the message as its single argument (e.g. a warning function), which is called instead

    Raises:
        error_or_warn: if `error_or_warn` is an exception class.
    """
    class_name = 'cls' if instance is None else instance.__class__.__name__
    method_name = method.__func__.__name__ if isinstance(method, staticmethod) else method.__name__
    msg = '`{c}.{f}` is an abstract method that must be overriden before ' \
          '`{c}` is initialized.'.format(c=class_name, f=method_name)
    # `error_or_warn` is a class (not an instance), therefore `issubclass` is required to detect exceptions
    if isinstance(error_or_warn, type) and issubclass(error_or_warn, BaseException):
        raise error_or_warn(msg)
    error_or_warn(msg)
# ================================================================================================================== ### STATE COMPONENT INTEGRATION GETTER & SETTER: `instance_attr__get__`, `instance_attr__set__` ### # ================================================================================================================== ### For simple attributes ###
def get_component_attr(self, instance, owner):
    """
    `instance_attr__get__` function for a state object that integrates a simple attribute of a component,
    e.g. the attribute `training` of a PyTorch `Module`, into `state`.

    Args:
        self: the state object (e.g. a `StateVariable()` instance), its `name` is the name of the integrated
            attribute
        instance: the container holding the component as `instance.component`
        owner: the owner class of `self`, not used

    Returns:
        the current value of the attribute `self.name` of `instance.component`
    """
    component = instance.component
    attr_name = self.name
    return getattr(component, attr_name)
def set_component_attr(self, instance, value):
    """
    `instance_attr__set__` function for a state object that integrates a simple attribute of a component
    into `state`.

    Args:
        self: the state object (e.g. a `StateVariable()` instance), its `name` is the name of the integrated
            attribute
        instance: the container holding the component as `instance.component`
        value: the value to assign to the attribute `self.name` of `instance.component`
    """
    component = instance.component
    attr_name = self.name
    setattr(component, attr_name, value)
### For module device ###
def get_module_component_device(self, instance, owner):
    """
    `instance_attr__get__` function for a state object that integrates the parameter `device` of a module
    component into `state`.

    The device is not requested from the component. Instead the value stored under `self.name` in
    `instance.__dict__` is returned, i.e. the value saved by the last call of the corresponding setter.

    Args:
        self: the state object, its `name` is the key of the stored device
        instance: the container whose `__dict__` stores the device
        owner: the owner class of `self`, not used

    Returns:
        the device stored in `instance.__dict__`
    """
    stored_values = instance.__dict__
    return stored_values[self.name]
def set_module_component_device(self, instance, value):
    """
    `instance_attr__set__` function for a state object that integrates the parameter `device` of a module
    component into `state`.

    As there is no attribute `device` to assign to, the method `to(device=...)` of the component is called.

    Args:
        self: the state object, its `name` is the key under which the device is stored
        instance: the container holding the component as `instance.component`
        value: the device to move the component to
    """
    component = instance.component
    component.to(device=value)
    # Save the current device in `instance.__dict__` to be recalled by the getter, as the component is not
    # asked for its device
    instance.__dict__[self.name] = value