Source code for metann.dependentmodule

from collections import OrderedDict
from copy import deepcopy

import torch
from torch.nn import Module
from torch._six import string_classes

from .utils import SubDict

from typing import Dict, List, Optional, Tuple, Union, Any


[docs]class DependentModule(Module): r""" The PyTorch sugggest all parameters of a module to be independent variables, and forbid a parameter to have a grad_fn. This module provides an extension to nn.Module by register a subset of buffers as **dependents**, which indicates the dependent parameters. This enables the parameters of a DependentModule to be the dependent variables, which is useful in meta learning. This module calls DependentModule.to_dependentmodule when it is created. It turns the module and all of its submodules into sub class of DependentModule. Then you might use clear_params to transform all parameters to dependents. Examples: >>> net = Sequential(Linear(10, 5), Linear(5, 2)) >>> DependentModule(net) DependentSequential( (0): DependentLinear(in_features=10, out_features=5, bias=True) (1): DependentLinear(in_features=5, out_features=2, bias=True) ) Note: This class change the origin module when initializing, you might use >>> DependentModule(deepcopy(net)) if you want the origin model stay unchanged. """ def __new__(cls, *args, **kwargs): if len(args) == 1 and isinstance(args[0], Module): module = cls.to_dependentmodule(args[0]) else: module = super(DependentModule, cls).__new__(cls, *args, **kwargs) return module def __init__(self, *args, **kwargs) -> None: self._reinit() def _reinit(self) -> None: self._dependents = SubDict(self._buffers) self._active_dependents = SubDict(self._dependents) self._dependents_shapes = {} def __setattr__(self, name: str, value: Any) -> None: if isinstance(value, Module): value = self.to_dependentmodule(value) super(DependentModule, self).__setattr__(name, value)
[docs] def register_dependent(self, name: str, tensor: torch.Tensor) -> None: r""" register a named tensor to dependents. Args: name: name of dependent tensor tensor (torch.Tensor): dependent tensor Examples: >>> dnet = DependentModule(net) >>> dnet.register_dependent('some_tensor', torch.randn(3, 3)) >>> dnet.some_tensor tensor([[ 0.4434, 0.9949, -0.4385], [-0.5292, 0.2555, 0.7772], [-0.5386, 0.6152, -0.3239]]) """ if '_dependents' not in self.__dict__: raise AttributeError( "cannot assign dependent parameter before MetaModule.__init__() or MetaModule._reinit() call") elif not isinstance(name, torch._six.string_classes): raise TypeError("dependent parameter name should be a string. " "Got {}".format(torch.typename(name))) elif '.' in name: raise KeyError("dependent parameter name can't contain \".\"") elif name == '': raise KeyError("dependent parameter name can't be empty string \"\"") elif hasattr(self, name) and not name in self._dependents: raise KeyError("attribute '{}' already exists".format(name)) elif tensor is not None and not isinstance(tensor, torch.Tensor): raise TypeError("cannot assign '{}' object to dependent parameter '{}' " "(torch Tensor or None required)" .format(torch.typename(tensor), name)) else: if tensor is not None: self._active_dependents[name] = tensor self._dependents_shapes[name] = tensor.shape else: self._dependents[name] = tensor
[docs] def named_dependents(self, prefix: str='', recurse=True): r""" Args: prefix: the prefix of the names recurse: traverse only the direct submodules of self if set to False Returns: Iterative: iterator of name, dependent pairs of self and sub modules. """ memo = set() modules = self.named_modules(prefix=prefix) if recurse else [(prefix, self)] for module_prefix, module in modules: members = (lambda module: module._active_dependents.items())(module) for k, v in members: if v in memo and v is not None: continue memo.add(v) name = module_prefix + ('.' if module_prefix else '') + k yield name, v
[docs] def dependents(self, recurse=True): r""" Args: recurse: traverse only the direct submodules of self if set to False Returns: Iterative: iterator of dependents of self and sub modules. """ for name, param in self.named_dependents(recurse=recurse): yield param
[docs] def update_shapes(self): r""" Update the register shape of dependents. Call this method when a dependent is initialize with None and assign to a tensor. **Do not** call this method when you are using built-in methods only. """ def gen(): for name, value in self._active_dependents.items(): if value is None: if name in self._dependents_shapes: yield name, self._dependents_shapes[name] else: continue else: yield name, value.shape self._dependents_shapes = dict(gen())
def _substitute(self, name, value): if name not in self._dependents: raise KeyError("{} is not in dependent parameters".format(name)) elif name in self._dependents_shapes.keys() and self._dependents_shapes[name] != value.shape: raise ValueError("size mismatch for {}, expect {}, got {}".format( name, self._dependents_shapes[name], value.shape)) self._dependents[name] = value def _substitute_from_params_dict(self, params_dict, prefix, strict=True): for name in self._dependents: key = prefix + name if strict == True: if key in params_dict: self._substitute(name, params_dict[key]) else: raise ValueError("params_dict and interim parameters mismatch, got {}".format(key)) elif strict == 'one way': if key in params_dict: self._substitute(name, params_dict[key]) else: if key in params_dict: try: self._substitute(name, params_dict[key]) except (KeyError, ValueError): pass
[docs] def substitute(self, named_params, strict=True): r""" Substitute self's dependents with the tensors of same name Args: named_params: iterator of name, tensor pairs strict (bool): forbid named_params and self._dependents mismatch if set to True. default: True """ params_dict = dict(named_params) def load(module: DependentModule, prefix='', _strict=True): module._substitute_from_params_dict(params_dict, prefix, strict=_strict) for name, child in module._modules.items(): load(child, prefix + name + '.', _strict=_strict) load(self, _strict=strict)
[docs] def substitute_from_list(self, params): r""" Substitute from tensor list. Args: params: iterator of tensors """ named_params = ((k, v) for (k, _), v in zip(self.named_dependents(), params)) self.substitute(named_params, strict='one way')
[docs] def update_actives(self): keys = set() for key in self._dependents.keys(): if isinstance(self._dependents[key], torch.Tensor): keys.add(key) self._active_dependents = SubDict(self._dependents, keys)
[docs] def clear_params(self, init=False, clear_filter=lambda x: True): r""" Clear all parameters of self and register them as dependents. Args: init (bool): Set the values of dependents to None if set to False, otherwise keep the value of origin parameters. clear_filter: Function that return False when those modules you don't want to clear parameters are input """ def clear_fn(module: DependentModule): if clear_filter(module): for name, value in module._parameters.items(): module._dependents[name] = value.clone().detach().requires_grad_() if value is not None else None module._parameters = OrderedDict() module.update_actives() module.update_shapes() if not init: for key in module._dependents: module._dependents[key] = None self.apply(clear_fn) return self
@classmethod def _sub_class(cls, module: Module): if not isinstance(module, DependentModule): return type("Dependent" + type(module).__name__, (DependentModule, type(module)), {}) else: return type(module) @classmethod def _make_subclass(cls, module: Module): if not isinstance(module, cls): module.__class__ = type("Dependent" + type(module).__name__, (cls, type(module)), {}) module._reinit() return module
[docs] @classmethod def to_dependentmodule(cls, module: Module, recurse=True): r""" Transform a module and all its submodule into dependent module. Args: module: recurse: if set to be True all submodules will be transformed into dependent module recursively. Returns: DependentModule: a dependent module """ if not recurse: module = cls._make_subclass(module) else: module.apply(lambda x: cls.to_dependentmodule(x, recurse=False)) return module
[docs] @classmethod def stateless(cls, module: Module, clear_filter=lambda x: True): r""" transform input module into a DependentModule whose parameters are cleared. Args: module: clear_filter: Function that return False when those modules you don't want to clear parameters are input """ return cls.to_dependentmodule(deepcopy(module)).clear_params(clear_filter)