"""
@inproceedings{
esser2020learned,
title={LEARNED STEP SIZE QUANTIZATION},
author={Steven K. Esser and Jeffrey L. McKinstry and Deepika Bablani and Rathinakumar Appuswamy and Dharmendra S. Modha},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=rkgO66VKDS}
}
https://quanoview.readthedocs.io/en/latest/_raw/LSQ.html
"""
import math

import torch
import torch.nn.functional as F

from models.modules import _ActQ, _Conv2dQ, _LinearQ, Qmodes

__all__ = ['Conv2dLSQ', 'LinearLSQ', 'ActLSQ']
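
# Quick reference for the quantizers below (summarized from the paper and this file):
# a tensor v is quantized with a learnable step size alpha as
#     v_q = round(clamp(v / alpha, Qn, Qp)) * alpha
# where [Qn, Qp] is the integer range for the configured bit width. The gradient reaching
# alpha is scaled by g = 1 / sqrt(numel * Qp), and alpha is initialized on the first
# training batch to 2 * mean(|v|) / sqrt(Qp).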


class FunLSQ(torch.autograd.Function):
    @staticmethod
    def forward(ctx, weight, alpha, g, Qn, Qp):
        assert alpha > 0, 'alpha = {}'.format(alpha)
        ctx.save_for_backward(weight, alpha)
        ctx.other = g, Qn, Qp
        q_w = (weight / alpha).round().clamp(Qn, Qp)
        w_q = q_w * alpha
        return w_q

    @staticmethod
    def backward(ctx, grad_weight):
        weight, alpha = ctx.saved_tensors
        g, Qn, Qp = ctx.other
        q_w = weight / alpha
        indicate_small = (q_w < Qn).float()
        indicate_big = (q_w > Qp).float()
        # indicate_middle = torch.ones(indicate_small.shape).to(indicate_small.device) - indicate_small - indicate_big
        indicate_middle = 1.0 - indicate_small - indicate_big  # Thanks to @haolibai
        grad_alpha = ((indicate_small * Qn + indicate_big * Qp + indicate_middle * (
                -q_w + q_w.round())) * grad_weight * g).sum().unsqueeze(dim=0)
        grad_weight = indicate_middle * grad_weight
        return grad_weight, grad_alpha, None, None, None
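
# The backward pass above implements the LSQ step-size gradient from the paper:
#     d(w_q)/d(alpha) = -w/alpha + round(w/alpha)   if Qn <= w/alpha <= Qp
#                       Qn                          if w/alpha < Qn
#                       Qp                          if w/alpha > Qp
# together with a straight-through weight gradient (1 inside the clamp range, 0 outside).
# Each element's contribution is scaled by g and summed into a single gradient for alpha.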


def grad_scale(x, scale):
    y = x
    y_grad = x * scale
    return y.detach() - y_grad.detach() + y_grad


def round_pass(x):
    y = x.round()
    y_grad = x
    return y.detach() - y_grad.detach() + y_grad
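
# Both helpers above rely on the same pass-through trick: y.detach() - y_grad.detach() + y_grad
# equals y in the forward pass, but autograd only differentiates through y_grad. grad_scale()
# therefore returns x unchanged while scaling its gradient by `scale`, and round_pass() rounds
# in the forward pass while behaving as an identity (straight-through estimator) backward.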


class Conv2dLSQ(_Conv2dQ):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, nbits=4,
                 mode=Qmodes.layer_wise):
        super(Conv2dLSQ, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
            nbits=nbits, mode=mode)

    def forward(self, x):
        if self.alpha is None:
            return F.conv2d(x, self.weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)
        # Method1: 31GB GPU memory (AlexNet w4a4 bs 2048) 17min/epoch
        alpha = grad_scale(self.alpha, g)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha
        # Method2: 25GB GPU memory (AlexNet w4a4 bs 2048) 32min/epoch
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        return F.conv2d(x, w_q, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
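
# Note on the two variants in forward() above: "Method1" composes grad_scale()/round_pass(),
# keeping the whole quantization in the autograd graph (more GPU memory, faster per the
# AlexNet w4a4 numbers in the comments), while the commented-out "Method2" goes through the
# custom FunLSQ autograd function, trading speed for a smaller memory footprint. Both paths
# produce the same quantized weights.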


class LinearLSQ(_LinearQ):
    def __init__(self, in_features, out_features, bias=True, nbits=4):
        super(LinearLSQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias, nbits=nbits)

    def forward(self, x):
        if self.alpha is None:
            return F.linear(x, self.weight, self.bias)
        Qn = -2 ** (self.nbits - 1)
        Qp = 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(self.weight.numel() * Qp)
        # Method1:
        alpha = grad_scale(self.alpha, g)
        w_q = round_pass((self.weight / alpha).clamp(Qn, Qp)) * alpha
        # Method2:
        # w_q = FunLSQ.apply(self.weight, self.alpha, g, Qn, Qp)
        return F.linear(x, w_q, self.bias)


class ActLSQ(_ActQ):
    def __init__(self, nbits=4, signed=False):
        super(ActLSQ, self).__init__(nbits=nbits, signed=signed)

    def forward(self, x):
        if self.alpha is None:
            return x
        if self.signed:
            Qn = -2 ** (self.nbits - 1)
            Qp = 2 ** (self.nbits - 1) - 1
        else:
            Qn = 0
            Qp = 2 ** self.nbits - 1
        if self.training and self.init_state == 0:
            self.alpha.data.copy_(2 * x.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)
        g = 1.0 / math.sqrt(x.numel() * Qp)
        # Method1:
        alpha = grad_scale(self.alpha, g)
        x = round_pass((x / alpha).clamp(Qn, Qp)) * alpha
        # Method2:
        # x_q = FunLSQ.apply(x, self.alpha, g, Qn, Qp)
        return x
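

# A minimal usage sketch (an addition, not part of the original file). It assumes that
# models.modules is importable and that the _Conv2dQ/_LinearQ/_ActQ base classes register
# `alpha` and `init_state` as used in the forward passes above; shapes and bit widths here
# are arbitrary illustrative values.
if __name__ == '__main__':
    torch.manual_seed(0)
    conv = Conv2dLSQ(3, 16, kernel_size=3, padding=1, nbits=4)
    fc = LinearLSQ(16, 10, nbits=4)
    act = ActLSQ(nbits=4, signed=False)

    x = torch.randn(2, 3, 8, 8)
    y = act(conv(x))                    # quantized conv followed by quantized activations
    logits = fc(y.mean(dim=(2, 3)))     # global average pool, then a quantized linear layer
    logits.sum().backward()             # gradients reach the weights and every alpha
    print(logits.shape)                 # expected: torch.Size([2, 10])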