# Adam optimizer with Polyak averaging of the parameters
# (a usage sketch is included at the bottom of this module).
#
# Adapted from https://github.com/nicola-decao/BNAF (MIT license)
import math
import torch


class Adam(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, polyak=0.0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= polyak <= 1.0:
            raise ValueError("Invalid polyak decay term: {}".format(polyak))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay,
                        amsgrad=amsgrad, polyak=polyak)
        super(Adam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(Adam, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, '
                                       'please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    # Exponential moving average of the parameters
                    state['exp_avg_param'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])

                # Decay the first and second moment running averages in place,
                # so that the buffers stored in `state` are actually updated.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Update the parameters
                p.data = p.data - step_size * exp_avg / denom

                # Exponential moving (Polyak) average of the parameters themselves,
                # consumed later by swap() and substitute()
                polyak = self.defaults['polyak']
                state['exp_avg_param'] = polyak * state['exp_avg_param'] + (1 - polyak) * p.data

        return loss
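
    # A brief sketch of the optional `closure` argument: as with other PyTorch
    # optimizers, the caller may pass a callable that recomputes the loss and its
    # gradients. The `optimizer`, `model`, `loss_fn`, `x` and `y` names below are
    # assumptions for illustration, not part of this module:
    #
    #     def closure():
    #         optimizer.zero_grad()
    #         loss = loss_fn(model(x), y)
    #         loss.backward()
    #         return loss
    #
    #     loss = optimizer.step(closure)
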
    def swap(self):
        """Swaps the current parameters with their running (Polyak) average,
        e.g. to save or evaluate the averaged parameters; calling swap() a
        second time restores the original parameters.
        """
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                new = p.data
                p.data = state['exp_avg_param']
                state['exp_avg_param'] = new
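
    # A sketch of the save pattern the docstring above describes (the `model` and
    # `path` names are assumptions, not part of this module):
    #
    #     optimizer.swap()                       # parameters <- Polyak averages
    #     torch.save(model.state_dict(), path)   # checkpoint the averaged weights
    #     optimizer.swap()                       # restore the raw training parameters
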
    def substitute(self):
        """Replaces the current parameters with their running (Polyak) average."""
        for group in self.param_groups:
            for p in group['params']:
                p.data = self.state[p]['exp_avg_param']
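

# A minimal, self-contained usage sketch (toy data and hyper-parameters chosen
# only for illustration; not part of the original BNAF module). It fits a small
# linear model with this optimizer, then uses swap() to evaluate the
# Polyak-averaged parameters and swaps back afterwards.
if __name__ == "__main__":
    torch.manual_seed(0)

    model = torch.nn.Linear(3, 1)
    optimizer = Adam(model.parameters(), lr=1e-2, polyak=0.99)

    # Toy regression target: y is the sum of the input features.
    x = torch.randn(256, 3)
    y = x.sum(dim=1, keepdim=True)

    for _ in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()

    # Evaluate with the Polyak-averaged parameters, then restore the raw
    # (non-averaged) parameters so training could continue unchanged.
    optimizer.swap()
    with torch.no_grad():
        avg_loss = torch.nn.functional.mse_loss(model(x), y)
    print("loss with Polyak-averaged parameters: {:.4f}".format(avg_loss.item()))
    optimizer.swap()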