import torch
import torch.nn as nn


class Quantizer(nn.Module):
    """Uniform affine quantizer that calibrates scale/zero-point parameters
    (per-tensor or per-channel) and fake-quantizes tensors with them."""

    def __init__(self, shape=1):
        super().__init__()
        # Buffers move with .to()/.cuda() and are saved in the state dict,
        # but are not trainable parameters.
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
        # Largest integer code on the grid: codes run over 0 .. 2^bits - 1.
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink
        if trits:
            # Ternary mode: a negative maxq acts as a sentinel elsewhere.
            self.maxq = torch.tensor(-1)
        # Reset the scale so ready() reports False until find_params runs.
        self.scale = torch.zeros_like(self.scale)

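    # Example (illustrative, not from the original source): configure(bits=3)
    # sets maxq = 7, i.e. integer codes 0..7, while trits=True switches to
    # ternary mode via the sentinel maxq = -1.
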
    def _quantize(self, x, scale, zero, maxq):
        if maxq < 0:
            # Ternary (trits) mode: 'scale' carries the positive level and
            # 'zero' the negative level; entries map to {zero, 0, scale}.
            return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
        # Uniform affine quantize-dequantize round trip.
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)
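
    # Worked example (illustrative, not from the original source): with
    # bits=4, maxq = 15. A row spanning [-0.5, 1.0] gets scale = 1.5/15 = 0.1
    # and zero = round(0.5/0.1) = 5, so x = 0.23 maps to
    # q = clamp(round(0.23/0.1) + 5, 0, 15) = 7 and dequantizes to
    # 0.1 * (7 - 5) = 0.2.
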
    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        # Reshape so that rows index quantization groups (channels) and
        # columns hold the values whose range is being calibrated.
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            # Per-tensor: a single row covering all values.
            x = x.flatten().unsqueeze(0)

        # Per-row min/max, clamped to include zero so zero is representable.
        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            # Symmetric range: [-m, m] with m = max(|xmin|, xmax).
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        # Guard against degenerate all-zero rows.
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
            # Ternary mode stores the raw positive/negative levels.
            self.scale = xmax
            self.zero = xmin
        else:
            self.scale = (xmax - xmin) / self.maxq
            if self.sym:
                # Symmetric: the zero-point is the middle of the integer grid.
                self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
            else:
                self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            # Grid search over progressively shrunken ranges, keeping, per
            # row, the (scale, zero) pair that minimizes the p-norm
            # (p = self.norm) reconstruction error.
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]

        if not self.perchannel:
            # Repeat the single (scale, zero) across all channels.
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            # Weights broadcast along the output-channel (first) dimension.
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        # Activations broadcast along the channel dimension of their layout.
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)
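
    # Shape note (illustrative, not from the original source): for a Linear
    # weight of shape [out, in] with perchannel=True and weight=True, the
    # parameters end up with shape [out, 1], so quantize() broadcasts one
    # (scale, zero) pair per output row.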

    def quantize(self, x):
        if self.ready():
            return self._quantize(x, self.scale, self.zero, self.maxq)
        # Pass through unchanged until find_params has been called.
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)
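

# Minimal usage sketch (illustrative, not part of the original file):
# calibrate a 4-bit asymmetric per-channel quantizer on a random
# Linear-style weight matrix and inspect the fake-quantization error.
if __name__ == '__main__':
    torch.manual_seed(0)
    w = torch.randn(8, 16)  # [out_features, in_features]

    quantizer = Quantizer()
    quantizer.configure(bits=4, perchannel=True, sym=False, mse=False)
    quantizer.find_params(w, weight=True)

    w_q = quantizer.quantize(w)
    print('scale shape:', tuple(quantizer.scale.shape))  # (8, 1)
    print('mean abs error:', (w - w_q).abs().mean().item())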