import torch
from torch import nn
from .proto import tensor_copy
from .utils.containers import MultipleList
from metann import ProtoModule
from metann.utils.containers import DefaultList
import numpy as np
# Element-wise ``isinstance(x, torch.Tensor)`` predicate wrapped with
# ``np.vectorize`` so it can be applied across array-like inputs.
# NOTE(review): no caller is visible in this file -- confirm it is still used.
is_tensor = np.vectorize(lambda x: isinstance(x, torch.Tensor))
def active_indices(lst):
    """Return the positions in ``lst`` whose entries are ``torch.Tensor`` objects.

    Args:
        lst: Iterable of arbitrary objects.

    Returns:
        list: Indices (in iteration order) of the entries that are tensors.
    """
    return [
        position
        for position, entry in enumerate(lst)
        if isinstance(entry, torch.Tensor)
    ]
def default_evaluator_classification(model, data, criterion=nn.CrossEntropyLoss()):
    """Evaluate ``model`` on one ``(inputs, targets)`` batch and return the loss.

    Args:
        model: Callable applied to the inputs to produce predictions.
        data: Two-element iterable ``(inputs, targets)``.
        criterion: Callable ``criterion(predictions, targets)``; defaults to
            a ``nn.CrossEntropyLoss`` instance.

    Returns:
        Whatever ``criterion`` returns for the model's predictions.
    """
    inputs, targets = data
    return criterion(model(inputs), targets)
# def mamlpp_evaluator(mimo, data, steps, evaluator, gamma=0.6):
# weights = [1*gamma**i for i in range(steps+1)]
# weights = list(reversed(weights))
# evaluators = [evaluator] * (steps+1)
# data = [data] * (steps+1)
# loss = mimo(data, evaluators)
# return sum(i[0] * i[1] for i in zip(loss, weights)), loss[-1]
class Learner(nn.Module):
    """Base class for learners with a pure and an in-place adaptation path.

    Subclasses implement ``forward_pure`` and/or ``forward_inplace``;
    ``forward`` picks between them based on the ``inplace`` flag.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *args, inplace=False, **kwargs):
        """Dispatch to ``forward_inplace`` when ``inplace`` is truthy, else ``forward_pure``.

        All other positional and keyword arguments are forwarded unchanged.
        """
        handler = self.forward_inplace if inplace else self.forward_pure
        return handler(*args, **kwargs)

    def forward_pure(self, model, data):
        """Adaptation path used when ``inplace`` is false; must be overridden."""
        raise NotImplementedError

    def forward_inplace(self, model, data):
        """Adaptation path used when ``inplace`` is true; must be overridden."""
        raise NotImplementedError
class GDLearner(Learner):
    """Learner that repeats plain gradient descent on a single batch.

    The work is delegated to a ``SequentialGDLearner`` built with
    ``momentum=0``; the same ``data`` is fed to it ``steps`` times.

    Args:
        steps: Number of times the batch is repeated.
        lr: Learning rate passed to the inner learner.
        create_graph: Passed to the inner learner.
        evaluator: Loss callable passed to the inner learner.
    """

    def __init__(self, steps, lr, create_graph=True, evaluator=default_evaluator_classification):
        super().__init__()
        self.steps = steps
        self.sgd = SequentialGDLearner(
            lr,
            momentum=0,
            create_graph=create_graph,
            evaluator=evaluator,
        )

    def forward(self, model, data, inplace=False, **kwargs):
        """Run the inner learner on ``data`` repeated ``self.steps`` times.

        Extra keyword arguments are forwarded to the inner learner.
        """
        repeated = [data] * self.steps
        return self.sgd(model=model, data=repeated, inplace=inplace, **kwargs)
class SequentialGDLearner(Learner):
    """Gradient descent with momentum over a sequence of batches.

    Args:
        lr: Learning rate.
        momentum: Factor applied to the previous velocity at each step.
        create_graph: Passed to ``torch.autograd.grad`` in the pure path.
        evaluator: Default callable ``evaluator(model, batch)`` returning a loss.
    """

    def __init__(self, lr, momentum=0.5, create_graph=True, evaluator=default_evaluator_classification):
        super().__init__()
        self.lr = lr
        self.momentum = momentum
        self.create_graph = create_graph
        self.evaluator = evaluator

    def _descend(self, proto, weights, actives, velocities, batch, evaluator):
        """Apply one momentum step to ``weights`` in place; return the new velocities."""
        loss = evaluator(proto.functional(weights), batch)
        grads = torch.autograd.grad(loss, weights[actives],
                                    create_graph=self.create_graph)
        # velocity <- grad + momentum * previous velocity
        velocities = [g + v * self.momentum for (g, v) in zip(grads, velocities)]
        weights[actives] = [p - self.lr * v for (p, v) in zip(weights[actives], velocities)]
        return velocities

    def forward_pure(self, model, data, evaluator=None, mimo=False):
        """Adapt functionally, one update per batch in ``data``.

        Args:
            model: Model to adapt; wrapped in ``ProtoModule`` and put in train mode.
            data: Iterable of batches.
            evaluator: Optional override for ``self.evaluator``.
            mimo: When true, keep the weights after every step and return a
                ``MultiModel`` (initial weights first); otherwise return only
                the functional model built from the final weights.
        """
        loss_fn = evaluator if evaluator is not None else self.evaluator
        proto = ProtoModule(model)
        proto.train()

        if not mimo:
            weights = MultipleList(list(proto.parameters()))
            velocities = DefaultList(lambda: 0)
            actives = active_indices(weights)
            for batch in data:
                velocities = self._descend(proto, weights, actives, velocities, batch, loss_fn)
            return proto.functional(weights)

        history = [MultipleList(list(proto.parameters()))]
        velocities = DefaultList(lambda: 0)
        actives = active_indices(history[-1])
        for batch in data:
            # Each step works on a copy of the latest weights so earlier
            # entries of ``history`` are not overwritten by the update.
            # NOTE(review): relies on ``tensor_copy`` semantics -- confirm.
            weights = tensor_copy(history[-1])
            velocities = self._descend(proto, weights, actives, velocities, batch, loss_fn)
            history.append(weights)
        return MultiModel(proto, history)

    def forward_inplace(self, model, data, evaluator=None):
        """Train ``model`` in place with ``torch.optim.SGD`` and return it.

        One optimizer step is taken per batch in ``data``.
        """
        loss_fn = evaluator if evaluator is not None else self.evaluator
        optimizer = torch.optim.SGD(model.parameters(), lr=self.lr, momentum=self.momentum)
        for batch in data:
            optimizer.zero_grad()
            loss_fn(model, batch).backward()
            optimizer.step()
        return model
def _rms_prop(data, grad, state):
if state['r'] is None:
state['r'] = torch.zeros_like(data)
if state['centered'] and state['grad_avg'] is None:
state['grad_avg'] = torch.zeros_like(data)
alpha, eps = state['alpha'], state['eps']
r = state['r'] = state['r'].mul(alpha).addcmul(grad, grad, value=1-alpha)
if state['centered']:
grad_avg = state['grad_avg'] = state['grad_avg'].mul(alpha).add(grad, alpha=1 - alpha)
avg = r.addcmul(grad_avg, grad_avg, value=-1).sqrt().add(eps)
else:
avg = r.sqrt().add(eps)
return data.addcdiv(grad, avg, value=-state['lr']), state
class RMSPropLearner(Learner):
    """Learner that adapts a model with RMSProp updates.

    Args:
        lr: Learning rate.
        alpha: Smoothing constant for the running averages.
        eps: Term added to the denominator.
        centered: Whether to use the centered variant.
        create_graph: Passed to ``torch.autograd.grad`` in the pure path.
        evaluator: Default callable ``evaluator(model, batch)`` returning a loss.
        steps: When not ``None``, ``data`` is treated as a single batch and
            repeated this many times; otherwise ``data`` is iterated as given.
    """

    def __init__(self, lr=1e-2, alpha=0.99, eps=1e-8, centered=False, create_graph=True,
                 evaluator=default_evaluator_classification, steps=None):
        super().__init__()
        self.lr = lr
        self.alpha = alpha
        self.eps = eps
        self.centered = centered
        self.evaluator = evaluator
        self.steps = steps
        self.create_graph = create_graph

    def _batches(self, data):
        """Return ``data`` repeated ``self.steps`` times, or ``data`` itself when steps is None."""
        return data if self.steps is None else [data] * self.steps

    def forward_pure(self, model, data, evaluator=None):
        """Adapt functionally and return the functional model with the final weights.

        ``model`` is wrapped in ``ProtoModule`` and put in train mode; one
        RMSProp update is applied per batch.
        """
        data = self._batches(data)
        loss_fn = evaluator if evaluator is not None else self.evaluator
        proto = ProtoModule(model)
        proto.train()

        weights = MultipleList(list(proto.parameters()))
        actives = active_indices(weights)
        # One optimizer state per active parameter, created on demand.
        states = DefaultList(lambda: {'centered': self.centered, 'alpha': self.alpha,
                                      'lr': self.lr, 'eps': self.eps,
                                      'r': None, 'grad_avg': None})
        for batch in data:
            loss = loss_fn(proto.functional(weights), batch)
            grads = torch.autograd.grad(loss, weights[actives],
                                        create_graph=self.create_graph)
            updated = []
            for idx, (param, grad, state) in enumerate(zip(weights[actives], grads, states)):
                new_param, states[idx] = _rms_prop(param, grad, state)
                updated.append(new_param)
            weights[actives] = updated
        return proto.functional(weights)

    def forward_inplace(self, model, data, evaluator=None):
        """Train ``model`` in place with ``torch.optim.RMSprop`` and return it."""
        data = self._batches(data)
        loss_fn = evaluator if evaluator is not None else self.evaluator
        optimizer = torch.optim.RMSprop(model.parameters(), lr=self.lr,
                                        alpha=self.alpha, eps=self.eps, centered=self.centered)
        for batch in data:
            optimizer.zero_grad()
            loss_fn(model, batch).backward()
            optimizer.step()
        return model
class MAML(nn.Module):
    """Wrap a model so that calling the module adapts it with gradient descent.

    Args:
        model: Model to adapt.
        steps_train: Number of adaptation steps used in training mode.
        steps_eval: Number of adaptation steps used in evaluation mode.
        lr: Learning rate of the adaptation steps.
        evaluator: Callable ``evaluator(model, batch)`` returning a loss.
        first_order: When true, the learner is built with
            ``create_graph=False``.
    """

    def __init__(self, model, steps_train, steps_eval, lr,
                 evaluator=default_evaluator_classification, first_order=False):
        super(MAML, self).__init__()
        self.model = model
        self.steps_train = steps_train
        self.steps_eval = steps_eval
        self.lr = lr
        self.evaluator = evaluator
        self.first_order = first_order

    def forward(self, data):
        """Adapt ``self.model`` on ``data`` and return the learner's result.

        The number of steps is ``steps_train`` while ``self.training`` is
        true and ``steps_eval`` otherwise.
        """
        steps = self.steps_train if self.training else self.steps_eval
        # Use the mode-dependent step count; previously ``steps_train`` was
        # always passed, so ``steps_eval`` had no effect.
        learner = GDLearner(steps, self.lr, create_graph=not self.first_order)
        return learner(self.model, data, evaluator=self.evaluator)
class MAMLpp(nn.Module):
    """MAML variant that exposes every intermediate adapted model during training.

    Args:
        model: Model to adapt.
        steps_train: Number of adaptation steps used in training mode.
        steps_eval: Number of adaptation steps used in evaluation mode.
        lr: Learning rate of the adaptation steps.
        evaluator: Callable ``evaluator(model, batch)`` returning a loss.
        first_order: When true, the learner is built with
            ``create_graph=False``.
    """

    def __init__(self, model, steps_train, steps_eval, lr,
                 evaluator=default_evaluator_classification, first_order=False):
        super(MAMLpp, self).__init__()
        self.model = model
        self.steps_train = steps_train
        self.steps_eval = steps_eval
        self.lr = lr
        self.evaluator = evaluator
        self.first_order = first_order

    def forward(self, data):
        """Adapt ``self.model`` on ``data``.

        In training mode the learner is called with ``mimo=True`` and
        ``steps_train`` steps; in evaluation mode it is called without
        ``mimo`` and with ``steps_eval`` steps.
        """
        steps = self.steps_train if self.training else self.steps_eval
        # Use the mode-dependent step count; previously ``steps_train`` was
        # always passed, so ``steps_eval`` had no effect.
        learner = GDLearner(steps, self.lr, create_graph=not self.first_order)
        if self.training:
            return learner(self.model, data, evaluator=self.evaluator, mimo=True)
        return learner(self.model, data, evaluator=self.evaluator)
class MultiModel(nn.Module):
    """Sequence-like view over one prototype model and several weight sets.

    Indexing returns ``proto.functional(params)`` for the selected entry of
    the stored weight list; calling the module directly is not supported.

    Args:
        model: Prototype whose ``functional`` method builds a model from weights.
        fast_weight_lst: List of weight collections, one per model.
    """

    def __init__(self, model: ProtoModule, fast_weight_lst):
        super().__init__()
        self.proto = model
        self.params_lst = fast_weight_lst

    def __getitem__(self, item):
        selected = self.params_lst[item]
        return self.proto.functional(selected)

    def __len__(self):
        return len(self.params_lst)

    def __iter__(self):
        # Go through ``__getitem__`` so each item is built the same way as indexing.
        yield from (self[position] for position in range(len(self)))

    def forward(self, x):
        """Always raises: a ``MultiModel`` has no single forward pass."""
        raise NotImplementedError('Forward is not implemented for MultiModel.')
def mamlpp_evaluator(mimo: MultiModel, data, steps: int, evaluator, gamma=0.6):
    """Evaluate every model in ``mimo`` on ``data`` and combine the losses.

    The loss of model ``i`` (for ``i`` in ``0..steps``) is weighted by
    ``gamma ** (steps - i)``, so the last model has weight 1.

    Args:
        mimo: Indexable collection of models; indices ``0..steps`` are used.
        data: Batch passed to ``evaluator`` for every model.
        steps: Index of the last model to evaluate.
        evaluator: Callable ``evaluator(model, data)`` returning a loss.
        gamma: Base of the per-step weights.

    Returns:
        tuple: ``(weighted_sum_of_losses, loss_of_last_model)``.
    """
    # Largest exponent first, so the final step carries weight gamma ** 0.
    weights = [1 * gamma ** power for power in range(steps, -1, -1)]
    losses = [evaluator(mimo[step], data) for step in range(steps + 1)]
    total = sum(loss * weight for loss, weight in zip(losses, weights))
    return total, losses[-1]