-
Notifications
You must be signed in to change notification settings - Fork 0
/
optim.py
64 lines (53 loc) · 2.11 KB
/
optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import argparse
import json
import math
import os
import random
import signal
import subprocess
import sys
import time
import numpy as np
from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch
import torchvision
import torchvision.transforms as transforms
class LARS(optim.Optimizer):
    """Layer-wise Adaptive Rate Scaling (LARS) optimizer with momentum.

    For every parameter ``p`` that has a gradient ``g``, one step computes::

        d  = g + weight_decay * p           # skipped if weight-decay-filtered
        d  = d * eta * ||p|| / ||d||        # skipped if adaptation-filtered;
                                            # ratio is 1 if either norm is 0
        mu = momentum * mu + d              # per-parameter buffer, state['mu']
        p  = p - lr * mu

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        lr: Learning rate.
        weight_decay: Coefficient of ``p`` added to the gradient before the
            trust ratio is computed.
        momentum: Decay factor applied to the ``mu`` buffer each step.
        eta: Trust coefficient multiplying ``||p|| / ||d||``.
        weight_decay_filter: If True, skip weight decay for parameters for
            which ``exclude_bias_and_norm`` returns True.
        lars_adaptation_filter: If True, skip the trust-ratio scaling for
            parameters for which ``exclude_bias_and_norm`` returns True.
    """

    def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001,
                 weight_decay_filter=False, lars_adaptation_filter=False):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        eta=eta, weight_decay_filter=weight_decay_filter,
                        lars_adaptation_filter=lars_adaptation_filter)
        super().__init__(params, defaults)

    def exclude_bias_and_norm(self, p):
        """Return True when ``p`` is a 1-D tensor.

        The filters consult this predicate. 1-D tensors are usually biases
        and normalization-layer scale/shift parameters. That is a convention,
        not something this check verifies.
        """
        return p.ndim == 1

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step over all parameter groups.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss, following the ``torch.optim.Optimizer.step``
                convention. It is called once, before any parameter update.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure must be able to call
            # backward(), so re-enable autograd just for it.
            with torch.enable_grad():
                loss = closure()

        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad
                if dp is None:
                    continue

                # Short-circuit keeps the predicate call lazy (only evaluated
                # when the corresponding filter is enabled).
                if not g['weight_decay_filter'] or not self.exclude_bias_and_norm(p):
                    dp = dp.add(p, alpha=g['weight_decay'])

                if not g['lars_adaptation_filter'] or not self.exclude_bias_and_norm(p):
                    param_norm = torch.linalg.norm(p)
                    update_norm = torch.linalg.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Use a ratio of 1 when either norm is zero. Otherwise an
                    # all-zero parameter would be frozen (ratio 0), and an
                    # all-zero update would produce inf/nan.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['eta'] * param_norm / update_norm), one), one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)

                p.add_(mu, alpha=-g['lr'])

        return loss