import torch
import numpy as np


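# Flatten an iterable of parameter tensors into a single 1-D vector so that all
# model weights can be manipulated as one flat parameter vector theta.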
def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])


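# Gradient clipping by global L2 norm (in place, mirroring
# torch.nn.utils.clip_grad_norm_) that also returns the clipping coefficient;
# the architect reuses this coefficient to rescale eta in the unrolled gradient.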
def _clip(grads, max_norm):
    total_norm = 0
    for g in grads:
        param_norm = g.data.norm(2).item()
        total_norm += param_norm ** 2
    total_norm = total_norm ** 0.5
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.data.mul_(clip_coef)
    return clip_coef


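# Computes gradients of the validation loss with respect to the architecture
# parameters (alpha), optionally differentiating through a one-step "unrolled"
# update of the network weights, and applies them with a dedicated Adam optimizer.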
class Architect(object):

    def __init__(self, model, args):
        self.network_weight_decay = args.wdecay
        self.network_clip = args.clip
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
                                          lr=args.arch_lr, weight_decay=args.arch_wdecay)

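    # Build a copy of the model whose weights have taken one SGD step on the
    # training loss: theta' = theta - eta * (clip(grad_theta L_train) + wdecay * theta).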
    def _compute_unrolled_model(self, hidden, input, target, eta):
        loss, hidden_next = self.model._loss(hidden, input, target)
        theta = _concat(self.model.parameters()).data
        grads = torch.autograd.grad(loss, self.model.parameters())
        clip_coef = _clip(grads, self.network_clip)
        dtheta = _concat(grads).data + self.network_weight_decay * theta
        unrolled_model = self._construct_model_from_theta(theta.sub(dtheta, alpha=eta))
        return unrolled_model, clip_coef

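    # One architecture update: compute d(alpha) from a validation batch, either
    # through the unrolled weights (second-order DARTS) or directly (first-order),
    # then take an Adam step on the architecture parameters.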
    def step(self,
             hidden_train, input_train, target_train,
             hidden_valid, input_valid, target_valid,
             network_optimizer, unrolled):
        eta = network_optimizer.param_groups[0]['lr']
        self.optimizer.zero_grad()
        if unrolled:
            hidden = self._backward_step_unrolled(hidden_train, input_train, target_train,
                                                  hidden_valid, input_valid, target_valid, eta)
        else:
            hidden = self._backward_step(hidden_valid, input_valid, target_valid)
        self.optimizer.step()
        return hidden, None

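    # First-order approximation: backpropagate the validation loss straight into
    # the architecture parameters, ignoring how the weights depend on alpha.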
    def _backward_step(self, hidden, input, target):
        loss, hidden_next = self.model._loss(hidden, input, target)
        loss.backward()
        return hidden_next

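    # Second-order approximation:
    #   d(alpha) = grad_alpha L_val(theta', alpha)
    #              - eta * clip_coef * grad^2_{alpha,theta} L_train(theta, alpha) . clip(grad_theta' L_val),
    # where theta' are the unrolled weights and the mixed second-derivative term is
    # estimated by finite differences in _hessian_vector_product.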
    def _backward_step_unrolled(self,
                                hidden_train, input_train, target_train,
                                hidden_valid, input_valid, target_valid, eta):
        unrolled_model, clip_coef = self._compute_unrolled_model(hidden_train, input_train, target_train, eta)
        unrolled_loss, hidden_next = unrolled_model._loss(hidden_valid, input_valid, target_valid)

        unrolled_loss.backward()
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        dtheta = [v.grad for v in unrolled_model.parameters()]
        _clip(dtheta, self.network_clip)
        vector = [dt.data for dt in dtheta]
        implicit_grads = self._hessian_vector_product(vector, hidden_train, input_train, target_train, r=1e-2)

        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(ig.data, alpha=eta * clip_coef)

        # Copy the resulting architecture gradient onto the original model's
        # arch parameters so that self.optimizer can apply the update.
        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.data.clone()
            else:
                v.grad.data.copy_(g.data)
        return hidden_next

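    # Instantiate a new copy of the search model and load the flattened weight
    # vector theta back into it, reshaping each slice to its parameter's shape.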
    def _construct_model_from_theta(self, theta):
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset: offset + v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

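    # Estimate grad^2_{alpha,theta} L_train . vector by central finite differences:
    #   [grad_alpha L_train(theta + R*vector) - grad_alpha L_train(theta - R*vector)] / (2R),
    # perturbing the weights in place and restoring them afterwards.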
    def _hessian_vector_product(self, vector, hidden, input, target, r=1e-2):
        R = r / _concat(vector).norm().item()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)
        loss, _ = self.model._loss(hidden, input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(v, alpha=2 * R)
        loss, _ = self.model._loss(hidden, input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        # Restore the original weights.
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
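

# Sketch of how this class is typically driven from a search loop. The names below
# (model, optimizer, the batch/hidden variables, args.unrolled) are illustrative
# assumptions, not part of this module; the model is expected to expose _loss,
# arch_parameters, and new, as used above.
#
#   architect = Architect(model, args)
#   for (input_train, target_train), (input_valid, target_valid) in zip(train_batches, valid_batches):
#       # Architecture step on the validation batch (carries the valid hidden state forward).
#       hidden_valid, _ = architect.step(hidden_train, input_train, target_train,
#                                        hidden_valid, input_valid, target_valid,
#                                        optimizer, unrolled=args.unrolled)
#       # Ordinary weight step on the training batch.
#       optimizer.zero_grad()
#       loss, hidden_train = model._loss(hidden_train, input_train, target_train)
#       loss.backward()
#       torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
#       optimizer.step()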