linear.py
'''fully connected layers'''
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
from typing import Optional
from ai.model.actv import build_actv
def fc(
    n1: int,
    n2: int,
    actv: Optional[str] = None,
    bias: bool = True,
    bias_init: Optional[float] = None,
    scale_w: bool = False,
    lr_mult: Optional[float] = None,
    dropout: Optional[float] = None,
):
    '''Fully connected layer.

    INPUT
        tensor[b, <n1>]
    OUTPUT
        tensor[b, <n2>]

    operations in order:
        linear
            either:
                torch.nn.Linear
            or:
                a custom implementation that supports learning rate
                multipliers, weight scaling, and bias initialization
        activation function
        dropout

    ARGS
        n1 : int
            input size
        n2 : int
            output size
        actv : str or null
            activation (see model/actv.py)
        bias : bool
            enable bias in the linear op (default true)
        bias_init : float or null
            optional initial value for the bias
        scale_w : bool
            if enabled, scale weights by 1/sqrt(n1)
        lr_mult : float or null
            learning rate multiplier (scales weights and bias)
        dropout : float or null
            dropout probability (applied after the activation)
    '''
    if scale_w or lr_mult is not None or bias_init is not None:
        linear = Linear(n1, n2, bias, bias_init, scale_w, lr_mult)
    else:
        linear = nn.Linear(n1, n2, bias=bias)

    ops = [linear]
    if actv is not None:
        ops.append(build_actv(actv)) # type: ignore
    if dropout is not None and dropout != 0.:
        ops.append(nn.Dropout(dropout)) # type: ignore

    if len(ops) > 1:
        return nn.Sequential(*ops)
    return ops[0]
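

# Example usage (a sketch, not from the repo; the 'relu' name assumes it is a
# valid key for build_actv in ai/model/actv.py):
#
#   fc(512, 256)                              -> nn.Linear(512, 256)
#   fc(512, 256, actv='relu', dropout=0.1)    -> nn.Sequential(Linear, actv, Dropout)
#   fc(512, 256, scale_w=True, lr_mult=0.01)  -> custom Linear (defined below)
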
class Linear(nn.Module):
    '''Linear layer with optional weight scaling, lr multiplier, and bias init.'''

    def __init__(s,
        n1: int,
        n2: int,
        bias: bool = True,
        bias_init: Optional[float] = None,
        scale_w: bool = False,
        lr_mult: Optional[float] = None,
    ):
        super().__init__()

        if lr_mult is None:
            lr_mult = 1.
        s._lr_mult = lr_mult

        # weights are stored divided by the lr multiplier and re-scaled in
        # forward(), which effectively multiplies this layer's learning rate
        # by lr_mult
        s._weight = nn.Parameter(torch.randn([n2, n1]) / lr_mult)
        s._bias = nn.Parameter(torch.randn([n2])) if bias else None
        s._bias_init = bias_init

        # runtime multiplier: lr_mult, optionally scaled by 1/sqrt(fan_in)
        s._weight_mult = lr_mult
        if scale_w:
            s._weight_mult /= np.sqrt(n1)
        s._bias_mult = lr_mult

    def init_params(s):
        # re-initialize, keeping the stored weights consistent with __init__
        # (divided by the lr multiplier)
        nn.init.normal_(s._weight, std=1. / s._lr_mult)
        if s._bias is not None:
            if s._bias_init is None:
                nn.init.normal_(s._bias)
            else:
                nn.init.constant_(s._bias, s._bias_init)

    def forward(s, x):
        # apply the runtime multipliers to the weights (and bias, if present)
        w = s._weight * s._weight_mult
        if s._bias is None:
            return F.linear(x, w)
        return F.linear(x, w, bias=s._bias * s._bias_mult)
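

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch; the shapes and hyperparameter
    # values are arbitrary assumptions, and running it requires the `ai`
    # package to be importable).
    x = torch.randn(4, 16)

    # plain path: falls back to torch.nn.Linear
    plain = fc(16, 8)
    print(plain(x).shape)  # -> torch.Size([4, 8])

    # custom path: Linear above with weight scaling, lr multiplier, and dropout
    custom = fc(16, 8, scale_w=True, lr_mult=0.01, dropout=0.1)
    print(custom(x).shape)  # -> torch.Size([4, 8])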