Add DQ quantizer for mixed-bit quantization #122

Draft · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions sparsebit/quantization/quantizers/__init__.py
@@ -13,6 +13,7 @@ def register_quantizer(quantizer):
from . import lsq_plus
from . import pact
from . import adaround
from . import dq


def build_quantizer(cfg):
105 changes: 105 additions & 0 deletions sparsebit/quantization/quantizers/dq.py
@@ -0,0 +1,105 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import warnings

from sparsebit.quantization.quantizers import Quantizer as BaseQuantizer
from sparsebit.quantization.quantizers import register_quantizer
from sparsebit.quantization.common import Granularity


class gs_scaling(torch.autograd.Function):
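# Identity in the forward pass; the backward pass multiplies the incoming
# gradient by a fixed ratio (the gradient-scaling trick used by LSQ/DQ).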
@staticmethod
def forward(ctx, x, ratio):
ctx.ratio = ratio
return x

@staticmethod
def backward(ctx, grad):
return grad * ctx.ratio, None

class STE(torch.autograd.Function):
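# Straight-through estimator: round to the nearest integer in the forward
# pass, pass the gradient through unchanged in the backward pass.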
@staticmethod
def forward(ctx, x):
x_int = x.round()
return x_int

@staticmethod
def backward(ctx, grad):
return grad


@register_quantizer
class Quantizer(BaseQuantizer):
TYPE = "DQ"

def __init__(self, config):
super(Quantizer, self).__init__(config)
self.init_params = False  # like LSQ, the quantizer params are initialized from calibration data

def calc_qparams(self):
if self.fake_fused:
return self.scale, self.zero_point
if not self.init_params:
x_oc = self.observer.data_cache.get_data_for_calibration(
Granularity.CHANNELWISE
)
if x_oc.min() < 0 and not self.qdesc.is_symmetric:
warnings.warn(
"Found data less than 0; resetting quantizer scheme to symmetric"
)
self.qdesc.set_symmetric(True)
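# LSQ-style initialization: scale = 2 * E[|x|] / sqrt(qmax)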
if self.is_perchannel:
scale = 2 * x_oc.abs().mean(axis=1) / math.sqrt(self.qdesc.qmax)
else:
scale = 2 * x_oc.abs().mean() / math.sqrt(self.qdesc.qmax)
self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
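# the clipping level qmax is itself a trainable parameter; this is what makes the bit-width differentiable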
self.qmax = nn.Parameter(torch.tensor(float(self.qdesc.qmax)).to(self.device))
if not self.qdesc.is_symmetric:
self.qmin = torch.tensor(0).to(self.device)
self.init_params = True
return self.scale, self.zero_point

def _qparams_preprocess(self, x):
if self.export_onnx:
return torch.tensor(
self.scale.abs().detach().cpu().numpy(), device=self.device
), torch.tensor(
torch.clamp(self.zero_point, self.qdesc.qmin, self.qmax)
.detach()
.cpu()
.numpy(),
device=self.device,
)
scale = self.scale.abs()
zero_point = torch.clamp(self.zero_point, self.qdesc.qmin, self.qdesc.qmax)
return scale, zero_point

def fix_bit(self):
with torch.no_grad():
# recover the (possibly fractional) bit-width implied by the learned qmax,
# round it to the nearest integer bit, and freeze qmax there
if self.qdesc.is_symmetric:
cur_bit = torch.log2(2 * self.qmax + 2)  # symmetric: qmax = 2**(b-1) - 1
new_bit = cur_bit.round()
new_qmax = 2 ** (new_bit - 1) - 1
else:
cur_bit = torch.log2(self.qmax + 1)  # asymmetric: qmax = 2**b - 1
new_bit = cur_bit.round()
new_qmax = 2 ** new_bit - 1

self.qmax.data.copy_(new_qmax)
self.qmax.requires_grad = False

def _forward(self, x, scale, zero_point):
if self.is_perchannel:
num_perchannel = x.numel() / x.shape[self.qdesc.ch_axis]
gs_ratio = 1.0 / math.sqrt(num_perchannel * self.qmax.item())
else:
gs_ratio = 1.0 / math.sqrt(x.numel() * self.qmax.item())
scale = gs_scaling.apply(scale, gs_ratio)
if self.qdesc.is_symmetric:
x_dq = STE.apply((x / scale).clamp(-self.qmax-1, self.qmax)) * scale
else:
x_dq = STE.apply((x / scale).clamp(self.qmin, self.qmax)) * scale
return x_dq
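For reference, a minimal sketch of how a learned qmax maps back to an integer bit-width, which is the relation fix_bit above and the bops/memory regularizers below rely on (recover_bit is an illustrative helper, not part of this PR):

import torch

def recover_bit(qmax: torch.Tensor, symmetric: bool = True) -> torch.Tensor:
    # symmetric codes use qmax = 2**(b-1) - 1, so b = log2(2*qmax + 2);
    # asymmetric codes use qmax = 2**b - 1, so b = log2(qmax + 1)
    return torch.log2(2 * qmax + 2) if symmetric else torch.log2(qmax + 1)

# a learned qmax of 127 corresponds exactly to an 8-bit symmetric code
assert recover_bit(torch.tensor(127.0)).round().item() == 8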
15 changes: 15 additions & 0 deletions sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
REGULARIZERS_MAP = {}


def register_regularizer(regularizer):
REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
return regularizer


from .base import Regularizer
from . import pact, bops, memory, bit_round


def build_regularizer(type, config, *args, **kwargs):
regularizer = REGULARIZERS_MAP[type](config, *args, **kwargs)
return regularizer
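As a quick illustration of the registry above, a usage sketch (the surrounding training loop and the names cfg, qmodel, epoch, and task_loss are assumptions, not part of this PR):

from sparsebit.quantization.regularizers import build_regularizer

# "bit_round" is registered by the imports above; qmodel is the quantized model
regularizer = build_regularizer("bit_round", cfg, qmodel)
penalty = regularizer(epoch)  # bit-rounding penalty for the current epoch
total_loss = task_loss + penalty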
6 changes: 6 additions & 0 deletions sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
class Regularizer(object):
def __init__(self, config):
self.config = config

def __call__(self):
pass
39 changes: 39 additions & 0 deletions sparsebit/quantization/regularizers/bit_round.py
@@ -0,0 +1,39 @@
import torch
import math
from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.quantizers.base import Quantizer


@register_regularizer
class Regularizer(BaseRegularizer):
TYPE = "bit_round"

def __init__(self,
config,
qmodel,
max_coeff = 1e6,
epochs_wo_regularization = 20,
total_epochs = 90,
):
super(Regularizer, self).__init__(config)
self.epochs_wo_regularization = epochs_wo_regularization
self.total_epochs = total_epochs - epochs_wo_regularization
self.config = config
self.max_coeff = max_coeff
self.quantizers = []
for m in qmodel.modules():
if isinstance(m, Quantizer) and not m.fake_fused:
self.quantizers.append(m)

def __call__(self, epoch):
loss = torch.tensor(0.0)
if epoch < self.epochs_wo_regularization:
return loss
for quantizer in self.quantizers:
# bit-width implied by the learned qmax (see dq.Quantizer.fix_bit)
bit = torch.log2(2 * quantizer.qmax + 2) if quantizer.qdesc.is_symmetric else torch.log2(quantizer.qmax + 1)
bit_floor = math.floor(bit.item())
bit_bias = bit - bit_floor
# the penalty vanishes at integer bit-widths and is maximal halfway between them
loss = loss + bit_bias * (1 - bit_bias) / bit_floor
coeff = self.max_coeff * (epoch - self.epochs_wo_regularization + 1) / self.total_epochs
return coeff * loss / len(self.quantizers)
82 changes: 82 additions & 0 deletions sparsebit/quantization/regularizers/bops.py
@@ -0,0 +1,82 @@
import torch
from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.modules import QConv2d, QLinear, MatMul


@register_regularizer
class Regularizer(BaseRegularizer):
TYPE = "bops"

def __init__(self,
config,
qmodel,
coeff = 1e2,
):
super(Regularizer, self).__init__(config)
self.config = config
self.coeff = coeff
self.module_dict = {}
self.bops_limitation = 0
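# per-layer BOPs = MACs * weight_bit * activation_bit; the limit below is computed at the initial bit-widths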
for node in qmodel.model.graph.nodes:
if node.op in ["placeholder", "output"]:
continue
module = getattr(qmodel.model, node.target)
if (
isinstance(module, (QConv2d, QLinear))
and getattr(module, "input_quantizer", None)
and getattr(module, "weight_quantizer", None)
):
flops = module.weight.numel()
if isinstance(module, QConv2d):
flops *= module.output_shape[-1] * module.output_shape[-2]
elif (
isinstance(module, QLinear)
and module.input_quantizer.observer.qdesc._ch_axis == 2 # NLC
):
flops *= module.output_shape[1]

self.module_dict[node.target] = {
"flops": flops,
"is_symmetric1":module.weight_quantizer.qdesc.is_symmetric,
"is_symmetric2":module.input_quantizer.qdesc.is_symmetric,
"qmax1": module.weight_quantizer.qmax,
"qmax2": module.input_quantizer.qmax,
}
self.bops_limitation += (
flops
* module.input_quantizer.bit
* module.weight_quantizer.bit
)
elif isinstance(module, MatMul) and getattr(
module, "input_quantizer_generated", None
):
input0_quantizer = getattr(qmodel.model, node.all_input_nodes[0].target)
input1_quantizer = getattr(qmodel.model, node.all_input_nodes[1].target)
input0_shape = input0_quantizer.output_shape
input1_shape = input1_quantizer.output_shape
flops = (
torch.prod(torch.tensor(input0_shape[1:])) * input1_shape[-1]
).item()
self.module_dict[node.target] = {
"flops": flops,
"is_symmetric1":input0_quantizer.qdesc.is_symmetric,
"is_symmetric2":input1_quantizer.qdesc.is_symmetric,
"qmax1": input0_quantizer.qmax,
"qmax2": input1_quantizer.qmax,
}
self.bops_limitation += flops * (input0_quantizer.bit*input1_quantizer.bit)

self.bops_limitation /= 1e9
print("BOPS limitation of the model:", str(self.bops_limitation), "GBOPS")

def __call__(self):
current_bops = 0
for n, info in self.module_dict.items():
bit1 = torch.log2(2 * info["qmax1"] + 2) if info["is_symmetric1"] else torch.log2(info["qmax1"] + 1)
bit2 = torch.log2(2 * info["qmax2"] + 2) if info["is_symmetric2"] else torch.log2(info["qmax2"] + 1)
current_bops += info["flops"] * bit1 * bit2
current_bops /= 1e9
if current_bops.item() <= self.bops_limitation:
return torch.zeros(1, device=current_bops.device)
return self.coeff * (current_bops - self.bops_limitation)
45 changes: 45 additions & 0 deletions sparsebit/quantization/regularizers/memory.py
@@ -0,0 +1,45 @@
import torch
from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer
from sparsebit.quantization.modules import QConv2d, QLinear, MatMul


@register_regularizer
class Regularizer(BaseRegularizer):
TYPE = "memory"

def __init__(self,
config,
qmodel,
coeff = 1e2,
):
super(Regularizer, self).__init__(config)
self.config = config
self.coeff = coeff
self.module_dict = {}
self.memory_limitation = 0
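# weight memory in MB: numel * bit bits / 8 / 2**20 = numel * bit / 2**23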
for node in qmodel.model.graph.nodes:
if node.op in ["placeholder", "output"]:
continue
module = getattr(qmodel.model, node.target)
if (
isinstance(module, (QConv2d, QLinear))
and getattr(module, "weight_quantizer", None)
):
self.memory_limitation += module.weight.numel() * module.weight_quantizer.bit / (2 ** 23)
self.module_dict[node.target] = {
"weight_numel": module.weight.numel(),
"qmax": module.weight_quantizer.qmax,
}

print("Memory limitation of the model:", str(self.memory_limitation), "MB")

def __call__(self):
current_memory = 0
for n, info in self.module_dict.items():
bit = torch.log2(2 * info["qmax"] + 2)  # weights are assumed symmetric: qmax = 2**(b-1) - 1
current_memory += info["weight_numel"] * bit / (2 ** 23)
if current_memory.item() <= self.memory_limitation:
return torch.zeros(1, device=current_memory.device)
return self.coeff * (current_memory - self.memory_limitation)
20 changes: 20 additions & 0 deletions sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,20 @@
import torch

from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
from sparsebit.quantization.regularizers import register_regularizer


@register_regularizer
class Regularizer(BaseRegularizer):
TYPE = "Pact"

def __init__(self, config):
super(Regularizer, self).__init__(config)
self.config = config

def __call__(self, model):
loss = 0.0
for n, p in model.named_parameters():
if "alpha" in n:
loss += (p**2).sum()
return loss
1 change: 1 addition & 0 deletions sparsebit/quantization/tools/calibration.py
@@ -157,4 +157,5 @@ def module_forward(
outputs.append(to_cpu(module(*args, **kwargs)))
if isinstance(module, QuantOpr):
module.set_quant(w_quant=False, a_quant=False)
module.output_shape = outputs[0].shape  # record the output shape so the bops regularizer can compute per-layer FLOPs
return outputs