MetaNN for PyTorch Meta Learning

1. Introduction

In meta-learning scenarios, it is common to use dependent variables as parameters and to back-propagate the gradient through them. However, the parameters of a PyTorch Module are designed to be leaf nodes, and a parameter is not allowed to have a grad_fn. Meta-learning practitioners are therefore forced to rewrite the basic layers to meet the requirements of meta learning.

This package provides DependentModule, an extension of torch.nn.Module that supports dependent parameters, i.e. parameters that are differentiable functions of other tensors. It also provides methods to transform an nn.Module into a DependentModule and to turn all the parameters of an nn.Module into dependent parameters.
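The restriction described above can be reproduced in plain PyTorch, without MetaNN. In the sketch below (the names `layer` and `fast_weight` are illustrative), assigning a tensor that has a grad_fn to a registered parameter name raises a TypeError, whereas passing the same tensor to `torch.nn.functional.linear` lets the gradient flow back to the original parameter. Rewriting every layer in this functional style is the burden that DependentModule is meant to remove.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(3, 2)
x = torch.randn(4, 3)

# A "fast weight" computed from the parameter: it is not a leaf and has a grad_fn.
fast_weight = layer.weight - 0.1 * torch.ones_like(layer.weight)

# nn.Module only accepts an nn.Parameter (or None) under a registered parameter name.
try:
    layer.weight = fast_weight
    assigned = True
except TypeError:
    assigned = False

# Functional workaround: feed the fast weight to the functional form of the layer.
out = F.linear(x, fast_weight, layer.bias)
out.sum().backward()

# The gradient reaches the original leaf parameter through fast_weight.
has_grad = layer.weight.grad is not None
```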

2. Installation

pip install MetaNN

3. Example

PyTorch expects all parameters of a module to be independent variables (leaf tensors). With DependentModule, an arbitrary torch.nn.Module can be transformed into a dependent module.

from metann import DependentModule
from torch import nn
net = nn.Sequential(
    nn.Linear(10, 100),
    nn.Linear(100, 5))
net = DependentModule(net)
print(net)

Higher-level APIs such as the MAML class are recommended for most uses.

from metann.meta import MAML, default_evaluator_classification as evaluator
from torch import nn
net = nn.Sequential(
    nn.Linear(10, 100),
    nn.Linear(100, 5))
maml = MAML(net, steps_train=5, steps_eval=10, lr=0.01)
# data_train and data_test must be defined beforehand (not shown here).
output = maml(data_train)
loss = evaluator(output, data_test)
loss.backward()
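MAML hides the inner loop. As a rough picture of what one meta-update involves, here is a plain-PyTorch sketch of generic MAML on a single task. It does not call MetaNN; the helper `inner_adapt`, the linear model and the MSE loss are assumptions chosen for brevity. The `first_order` flag mirrors the argument of the same name in the MAML signature, but how MetaNN uses it internally is not shown here.

```python
import torch
from torch import nn
import torch.nn.functional as F


def inner_adapt(params, x, y, steps, lr, first_order=False):
    """Hypothetical helper: `steps` gradient-descent steps on (x, y) from `params`."""
    fast = list(params)
    for _ in range(steps):
        loss = F.mse_loss(F.linear(x, fast[0], fast[1]), y)
        # create_graph=True keeps the second-order terms of full MAML.
        grads = torch.autograd.grad(loss, fast, create_graph=not first_order)
        fast = [p - lr * g for p, g in zip(fast, grads)]
    return fast


torch.manual_seed(0)
net = nn.Linear(3, 1)
meta_opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Training (support) and test (query) data of one task; random for illustration.
x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
x_test, y_test = torch.randn(8, 3), torch.randn(8, 1)

fast = inner_adapt(list(net.parameters()), x_train, y_train, steps=5, lr=0.01)
meta_loss = F.mse_loss(F.linear(x_test, fast[0], fast[1]), y_test)

meta_opt.zero_grad()
meta_loss.backward()  # gradients flow back through all five inner steps
meta_opt.step()
```

Because the adapted weights are ordinary tensors with a grad_fn, the outer loss is differentiable with respect to the original parameters, which is exactly what nn.Parameter slots cannot hold.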

4. Documentation

The MetaNN documentation is available on ReadTheDocs.

5. License

MIT

Copyright (c) 2019-present, Hanqiao Yu

Contents

metann package

Subpackages

metann.utils package

Submodules
metann.utils.containers module
class metann.utils.containers.DefaultList(factory=<function _none_fun>, fill=None)

Bases: object

fill(data: collections.abc.Iterable)
class metann.utils.containers.MultipleList(lst)

Bases: object

class metann.utils.containers.SubDict(super_dict: collections.abc.Mapping, keys=[], keep_order=True)

Bases: collections.abc.MutableMapping

Provides sub-dict access to a super dict.

Parameters:
  • super_dict (Mapping) – the super dictionary from which to take a sub dict
  • keys (iterable) – an iterable of the keys that make up the sub dict
  • keep_order (bool) – if True, iterating over the sub dict follows the iteration order of the super dict. Default: True

Examples

>>> import collections
>>> super_dict = collections.OrderedDict({'a': 1, 'b': 2, 'c': 3})
>>> sub_dict = SubDict(super_dict, keys=['a', 'b'])
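The note on update_keys suggests that the sub dict caches its list of keys. The pure-Python sketch below shows one way such a view could behave, assuming that writes pass through to the super dict and that the cached keys are re-synchronised only on demand. `SketchSubDict` is a hypothetical class, not MetaNN's implementation.

```python
from collections import OrderedDict
from collections.abc import MutableMapping


class SketchSubDict(MutableMapping):
    """Hypothetical sub-dict view over a super dict (illustration only)."""

    def __init__(self, super_dict, keys=(), keep_order=True):
        self.super_dict = super_dict
        self.keep_order = keep_order
        self._keys = list(keys)
        self.update_keys()

    def update_keys(self):
        # Drop keys that vanished from the super dict and, if requested,
        # restore the iteration order of the super dict.
        present = [k for k in self._keys if k in self.super_dict]
        if self.keep_order:
            wanted = set(present)
            present = [k for k in self.super_dict if k in wanted]
        self._keys = present

    def __getitem__(self, key):
        if key not in self._keys:
            raise KeyError(key)
        return self.super_dict[key]

    def __setitem__(self, key, value):
        self.super_dict[key] = value  # writes go through to the super dict
        if key not in self._keys:
            self._keys.append(key)
            self.update_keys()

    def __delitem__(self, key):
        if key not in self._keys:
            raise KeyError(key)
        del self.super_dict[key]  # design choice: deletion also reaches the super dict
        self._keys.remove(key)

    def __iter__(self):
        return iter(list(self._keys))

    def __len__(self):
        return len(self._keys)


super_dict = OrderedDict([('a', 1), ('b', 2), ('c', 3)])
sub_dict = SketchSubDict(super_dict, keys=['b', 'a'])
initial_order = list(sub_dict)  # follows the super dict's order
sub_dict['a'] = 10              # visible in the super dict
del super_dict['b']             # modified behind the view's back
sub_dict.update_keys()          # re-sync the cached key list
```

Without the call to `update_keys`, the view would still list `'b'`, and reading it would raise a KeyError from the super dict.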
update_keys()

Updates the keys of the sub dict after the super dict has been modified.

Note

Do not call this method if you only use the built-in methods.

Returns:
Module contents
class metann.utils.SubDict

Re-export of metann.utils.containers.SubDict, documented above.

Submodules

metann.dependentmodule module

class metann.dependentmodule.DependentModule(*args, **kwargs)

Bases: torch.nn.modules.module.Module

PyTorch expects all parameters of a module to be independent variables and forbids a parameter from having a grad_fn. This class extends nn.Module by registering a subset of its buffers as dependents, which act as dependent parameters. The parameters of a DependentModule can therefore be dependent variables, which is useful in meta learning. On creation, this class calls DependentModule.to_dependentmodule, which turns the module and all of its submodules into subclasses of DependentModule. You can then use clear_params to transform all parameters into dependents.

Examples:

>>> from torch.nn import Sequential, Linear
>>> from metann import DependentModule
>>> net = Sequential(Linear(10, 5), Linear(5, 2))
>>> DependentModule(net)
DependentSequential(
  (0): DependentLinear(in_features=10, out_features=5, bias=True)
  (1): DependentLinear(in_features=5, out_features=2, bias=True)
)

Note

This class modifies the original module in place during initialization. If you want the original model to stay unchanged, pass a copy instead:

>>> from copy import deepcopy
>>> DependentModule(deepcopy(net))
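The class description says dependents are stored as a subset of buffers. The plain-PyTorch sketch below illustrates why that helps: once a parameter is moved into the buffers, a tensor with a grad_fn can be assigned under the same name and the layer's ordinary forward uses it. `make_dependent` is a hypothetical helper, not part of MetaNN, and MetaNN's actual bookkeeping may differ.

```python
import torch
from torch import nn


def make_dependent(module: nn.Module, name: str) -> None:
    """Hypothetical helper: move parameter `name` into the module's buffers."""
    param = getattr(module, name)
    delattr(module, name)  # removes the entry from the module's parameters
    # A buffer may hold any tensor, including one that has a grad_fn.
    module.register_buffer(name, param.detach().clone())


torch.manual_seed(0)
layer = nn.Linear(3, 2)
meta_weight = layer.weight.detach().clone().requires_grad_()  # the "slow" weight
make_dependent(layer, 'weight')

# Assigning a non-leaf tensor is now allowed, and nn.Linear.forward uses it as usual.
layer.weight = meta_weight - 0.1 * torch.ones_like(meta_weight)
out = layer(torch.randn(4, 3))
out.sum().backward()  # the gradient reaches meta_weight
```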

clear_params(init=False, clear_filter=<function DependentModule.<lambda>>)

Clears all parameters of self and registers them as dependents.

Parameters:
  • init (bool) – if False, the dependents are set to None; otherwise they keep the values of the original parameters.
  • clear_filter – a function that receives a module and returns False for modules whose parameters should not be cleared
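To make the init and clear_filter arguments concrete, the following plain-PyTorch sketch clears the parameters of every module that passes the filter by moving them into buffers, and leaves BatchNorm layers untouched. `clear_params_sketch` is hypothetical; it only imitates the documented behaviour under the assumption that dependents are stored as buffers.

```python
import torch
from torch import nn


def clear_params_sketch(root: nn.Module, init: bool = False,
                        clear_filter=lambda m: True) -> list:
    """Hypothetical stand-in for clear_params; returns the names it cleared.

    Parameters of modules for which clear_filter(module) is False are kept.
    """
    cleared = []
    for mod_name, mod in root.named_modules():
        if not clear_filter(mod):
            continue
        for p_name, p in list(mod.named_parameters(recurse=False)):
            delattr(mod, p_name)
            # init=True keeps the old values; init=False leaves an empty (None) slot.
            mod.register_buffer(p_name, p.detach().clone() if init else None)
            cleared.append(f'{mod_name}.{p_name}' if mod_name else p_name)
    return cleared


net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
# Keep the BatchNorm affine parameters as ordinary parameters.
cleared = clear_params_sketch(
    net, init=True, clear_filter=lambda m: not isinstance(m, nn.BatchNorm1d))
remaining = [name for name, _ in net.named_parameters()]
```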
dependents(recurse=True)
Parameters: recurse – if False, traverse only the direct submodules of self
Returns: an iterator over the dependents of self and its submodules.
Return type: Iterator
named_dependents(prefix: str = '', recurse=True)
Parameters:
  • prefix – the prefix prepended to the names
  • recurse – if False, traverse only the direct submodules of self
Returns:

an iterator over (name, dependent) pairs of self and its submodules.

Return type:

Iterator

register_dependent(name: str, tensor: torch.Tensor) → None

Registers a named tensor as a dependent.

Parameters:
  • name – name of the dependent tensor
  • tensor (torch.Tensor) – the dependent tensor

Examples

>>> dnet = DependentModule(net)
>>> dnet.register_dependent('some_tensor', torch.randn(3, 3))
>>> dnet.some_tensor
tensor([[ 0.4434,  0.9949, -0.4385],
        [-0.5292,  0.2555,  0.7772],
        [-0.5386,  0.6152, -0.3239]])
classmethod stateless(module: torch.nn.modules.module.Module, clear_filter=<function DependentModule.<lambda>>)

Transforms the input module into a DependentModule whose parameters are cleared.

Parameters:
  • module (torch.nn.Module) – the module to transform
  • clear_filter – a function that receives a module and returns False for modules whose parameters should not be cleared
substitute(named_params, strict=True)

Substitutes self's dependents with the tensors of the same name.

Parameters:
  • named_params – an iterator of (name, tensor) pairs
  • strict (bool) – if True, the names in named_params must match self._dependents; a mismatch is not allowed. Default: True
substitute_from_list(params)

Substitutes self's dependents from a list of tensors.

Parameters: params – an iterator of tensors
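The following sketch imitates substitute and substitute_from_list in plain PyTorch, assuming (as in the class description) that the dependents are buffers addressed by dotted names such as `0.weight`. The helper names are hypothetical, and raising KeyError on a strict mismatch is an illustrative choice; the error MetaNN raises may differ.

```python
import torch
from torch import nn


def substitute_sketch(root: nn.Module, named_tensors, strict: bool = True) -> None:
    """Hypothetical stand-in for substitute(): assign tensors by dotted name."""
    named_tensors = dict(named_tensors)
    expected = {name for name, _ in root.named_buffers()}
    if strict and set(named_tensors) != expected:
        raise KeyError(f'mismatch: {sorted(set(named_tensors) ^ expected)}')
    for name, tensor in named_tensors.items():
        *path, attr = name.split('.')
        owner = root.get_submodule('.'.join(path)) if path else root
        setattr(owner, attr, tensor)  # allowed because the target is a buffer


def substitute_from_list_sketch(root: nn.Module, tensors) -> None:
    # Pair the tensors with the dependent names by position, in registration order.
    names = [name for name, _ in root.named_buffers()]
    substitute_sketch(root, zip(names, tensors))


# Turn the parameters of a small network into buffers ("dependents").
net = nn.Sequential(nn.Linear(3, 2))
for mod in net.modules():
    for p_name, p in list(mod.named_parameters(recurse=False)):
        delattr(mod, p_name)
        mod.register_buffer(p_name, p.detach().clone())

w = torch.randn(2, 3, requires_grad=True)
b = torch.zeros(2, requires_grad=True)
substitute_from_list_sketch(net, [w * 1.0, b + 0.0])

try:
    substitute_sketch(net, {'0.weight': w}, strict=True)  # '0.bias' is missing
    strict_error = False
except KeyError:
    strict_error = True

out = net(torch.randn(4, 3))
out.sum().backward()  # gradients reach w and b through the substituted tensors
```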
classmethod to_dependentmodule(module: torch.nn.modules.module.Module, recurse=True)

Transforms a module and all of its submodules into dependent modules.

Parameters:
  • module (torch.nn.Module) – the module to transform
  • recurse – if True, all submodules are transformed into dependent modules recursively.
Returns:

a dependent module

Return type:

DependentModule

update_actives()
update_shapes()

Updates the registered shapes of the dependents. Call this method when a dependent was initialized with None and later assigned a tensor. Do not call this method if you only use the built-in methods.

metann.meta module

class metann.meta.GDLearner(steps, lr, create_graph=True, evaluator=<function default_evaluator_classification>)

Bases: metann.meta.Learner

forward(model, data, inplace=False, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
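The create_graph argument of GDLearner presumably decides whether the inner gradient step is itself differentiated (second order) or treated as a constant (first order). That reading is an assumption based on the argument name and on `torch.autograd.grad`. The plain-PyTorch sketch below shows the effect of such a flag on the meta-gradient; `gd_adapt` and `meta_gradient` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def gd_adapt(w, x, y, steps, lr, create_graph=True):
    """Hypothetical inner loop: `steps` plain gradient-descent steps on a linear model."""
    for _ in range(steps):
        loss = F.mse_loss(x @ w, y)
        # With create_graph=False the gradient is treated as a constant,
        # so the meta-gradient ignores second-order terms.
        (g,) = torch.autograd.grad(loss, w, create_graph=create_graph)
        w = w - lr * g
    return w


def meta_gradient(create_graph):
    torch.manual_seed(0)  # same data and initial weight for both runs
    w0 = torch.randn(3, 1, requires_grad=True)
    x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
    x_test, y_test = torch.randn(8, 3), torch.randn(8, 1)
    w = gd_adapt(w0, x_train, y_train, steps=3, lr=0.1, create_graph=create_graph)
    F.mse_loss(x_test @ w, y_test).backward()
    return w0.grad


grad_second_order = meta_gradient(create_graph=True)
grad_first_order = meta_gradient(create_graph=False)
```

Both runs produce a meta-gradient, but only the first one includes the curvature of the inner loss; the second is cheaper in memory.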

class metann.meta.Learner

Bases: torch.nn.modules.module.Module

forward(*args, inplace=False, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_inplace(model, data)
forward_pure(model, data)
class metann.meta.MAML(model, steps_train, steps_eval, lr, evaluator=<function default_evaluator_classification>, first_order=False)

Bases: torch.nn.modules.module.Module

forward(data)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metann.meta.MAMLpp(model, steps_train, steps_eval, lr, evaluator=<function default_evaluator_classification>, first_order=False)

Bases: torch.nn.modules.module.Module

forward(data)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metann.meta.MultiModel(model: metann.proto.ProtoModule, fast_weight_lst)

Bases: torch.nn.modules.module.Module

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class metann.meta.RMSPropLearner(lr=0.01, alpha=0.99, eps=1e-08, centered=False, create_graph=True, evaluator=<function default_evaluator_classification>, steps=None)

Bases: metann.meta.Learner

forward_inplace(model, data, evaluator=None)
forward_pure(model, data, evaluator=None)
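The arguments lr, alpha, eps and centered of RMSPropLearner share their names and defaults with `torch.optim.RMSprop`, which suggests (this is an assumption) that the learner applies the RMSprop rule inside a differentiable inner loop. A plain-PyTorch sketch of that rule, written with out-of-place operations so the meta-gradient can flow through it, might look like this; `rmsprop_step` is hypothetical and omits weight decay and momentum.

```python
import torch


def rmsprop_step(params, grads, state, lr=0.01, alpha=0.99, eps=1e-8, centered=False):
    """Hypothetical differentiable RMSprop step (no weight decay, no momentum).

    Uses out-of-place ops so the update stays in the autograd graph.
    `state` stores the running averages between calls.
    """
    new_params = []
    for i, (p, g) in enumerate(zip(params, grads)):
        square_avg = state.get(('square_avg', i), torch.zeros_like(p))
        square_avg = alpha * square_avg + (1 - alpha) * g * g
        state[('square_avg', i)] = square_avg
        if centered:
            grad_avg = state.get(('grad_avg', i), torch.zeros_like(p))
            grad_avg = alpha * grad_avg + (1 - alpha) * g
            state[('grad_avg', i)] = grad_avg
            denom = (square_avg - grad_avg * grad_avg).sqrt() + eps
        else:
            denom = square_avg.sqrt() + eps
        new_params.append(p - lr * g / denom)
    return new_params


torch.manual_seed(0)
w0 = torch.randn(3, 1, requires_grad=True)   # meta-learned initial weight
x, y = torch.randn(8, 3), torch.randn(8, 1)  # one task's training data
state = {}
w = w0
for _ in range(3):
    loss = ((x @ w - y) ** 2).mean()
    (g,) = torch.autograd.grad(loss, w, create_graph=True)
    (w,) = rmsprop_step([w], [g], state, lr=0.01)

final_loss = ((x @ w - y) ** 2).mean()
final_loss.backward()  # meta-gradient through all three RMSprop steps
```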
class metann.meta.SequentialGDLearner(lr, momentum=0.5, create_graph=True, evaluator=<function default_evaluator_classification>)

Bases: metann.meta.Learner

forward_inplace(model, data, evaluator=None)
forward_pure(model, data, evaluator=None, mimo=False)
metann.meta.active_indices(lst)
metann.meta.default_evaluator_classification(model, data, criterion=CrossEntropyLoss())
metann.meta.mamlpp_evaluator(mimo: metann.meta.MultiModel, data, steps: int, evaluator, gamma=0.6)
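default_evaluator_classification takes a model, some data and a CrossEntropyLoss criterion. Assuming that `data` is an (inputs, labels) pair (the layout MetaNN actually expects is not documented here), an equivalent evaluator could look like the hypothetical sketch below.

```python
import torch
from torch import nn


def classification_evaluator_sketch(model, data, criterion=nn.CrossEntropyLoss()):
    """Hypothetical equivalent of default_evaluator_classification.

    Assumes `data` is an (inputs, labels) pair; MetaNN may use another layout.
    The default criterion mirrors the documented signature; it holds no state.
    """
    inputs, labels = data
    return criterion(model(inputs), labels)


torch.manual_seed(0)
model = nn.Linear(10, 5)
batch = (torch.randn(16, 10), torch.randint(0, 5, (16,)))
loss = classification_evaluator_sketch(model, batch)
loss.backward()
```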

metann.proto module

class metann.proto.ProtoModule(module: torch.nn.modules.module.Module)

Bases: torch.nn.modules.module.Module

This class extends nn.Module by providing a functional method. It is a stateful module, but it also lets you call a stateless functional version of it.

Parameters: module – the nn.Module to wrap
forward(*args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

functional(params, training=None)
Parameters:
  • params (iterable) – the model parameters to use in the functional
  • training – whether the functional runs in training mode (training=True)
Returns:

a callable that returns the output of the model computed with params

Examples

>>> learner = ProtoModule(net)
>>> outputs = learner.functional(net.parameters(), training=True)(x)
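As the example shows, functional returns something that is then called on the inputs. The following sketch reproduces that call pattern with `torch.func.functional_call` (available in PyTorch 2.0 and later). `make_functional` is a hypothetical helper, and MetaNN's own implementation may work differently.

```python
import torch
from torch import nn
from torch.func import functional_call  # PyTorch 2.0 and later


def make_functional(module: nn.Module):
    """Hypothetical helper mirroring the pattern functional(params, training)(x)."""
    names = [name for name, _ in module.named_parameters()]

    def functional(params, training=None):
        param_dict = dict(zip(names, params))

        def run(*args, **kwargs):
            previous_mode = module.training
            if training is not None:
                module.train(training)
            try:
                return functional_call(module, param_dict, args, kwargs)
            finally:
                module.train(previous_mode)  # restore the stateful module's mode

        return run

    return functional


torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
x = torch.randn(3, 4)
functional = make_functional(net)

# With the module's own parameters, the result matches the stateful call.
same = functional(list(net.parameters()), training=True)(x)
# Shifted "fast" parameters carry a grad_fn; gradients reach net's parameters.
fast = [p - 0.1 for p in net.parameters()]
out = functional(fast, training=True)(x)
out.sum().backward()
```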

named_parameters(prefix='', recurse=True)

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

Parameters:
  • prefix (str) – prefix to prepend to all parameter names.
  • recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
Yields:

(str, Parameter) – Tuple containing the name and parameter

Example:

>>> for name, param in self.named_parameters():
...     if name in ['bias']:
...         print(param.size())
metann.proto.tensor_copy(tensor_lst)

Module contents

class metann.DependentModule

Re-export of metann.dependentmodule.DependentModule, documented above.

class metann.ProtoModule

Re-export of metann.proto.ProtoModule, documented above.
