diff --git a/sparsebit/quantization/quantizers/__init__.py b/sparsebit/quantization/quantizers/__init__.py
index d070308..7f0ff55 100644
--- a/sparsebit/quantization/quantizers/__init__.py
+++ b/sparsebit/quantization/quantizers/__init__.py
@@ -13,6 +13,7 @@ def register_quantizer(quantizer):
 from . import lsq_plus
 from . import pact
 from . import adaround
+from . import dq
 
 
 def build_quantizer(cfg):
diff --git a/sparsebit/quantization/quantizers/dq.py b/sparsebit/quantization/quantizers/dq.py
new file mode 100644
index 0000000..7c3ef60
--- /dev/null
+++ b/sparsebit/quantization/quantizers/dq.py
@@ -0,0 +1,105 @@
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import warnings
+
+from sparsebit.quantization.quantizers import Quantizer as BaseQuantizer
+from sparsebit.quantization.quantizers import register_quantizer
+from sparsebit.quantization.common import Granularity
+
+
+class gs_scaling(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, ratio):
+        ctx.ratio = ratio
+        return x
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad * ctx.ratio, None
+
+
+class STE(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        x_int = x.round()
+        return x_int
+
+    @staticmethod
+    def backward(ctx, grad):
+        return grad
+
+
+@register_quantizer
+class Quantizer(BaseQuantizer):
+    TYPE = "DQ"
+
+    def __init__(self, config):
+        super(Quantizer, self).__init__(config)
+        self.init_params = False  # like LSQ, DQ is initialized from calibration data
+
+    def calc_qparams(self):
+        if self.fake_fused:
+            return self.scale, self.zero_point
+        if not self.init_params:
+            x_oc = self.observer.data_cache.get_data_for_calibration(
+                Granularity.CHANNELWISE
+            )
+            if x_oc.min() < 0 and not self.qdesc.is_symmetric:
+                warnings.warn(
+                    "Found data less than 0, resetting quantizer scheme to symmetric"
+                )
+                self.qdesc.set_symmetric(True)
+            if self.is_perchannel:
+                scale = 2 * x_oc.abs().mean(axis=1) / math.sqrt(self.qdesc.qmax)
+            else:
+                scale = 2 * x_oc.abs().mean() / math.sqrt(self.qdesc.qmax)
+            self.scale = nn.Parameter(self._broadcast_qparams(scale.to(self.device)))
+            self.zero_point = self._broadcast_qparams(torch.zeros_like(self.scale))
+            self.qmax = nn.Parameter(torch.tensor(float(self.qdesc.qmax)).to(self.device))
+            if not self.qdesc.is_symmetric:
+                self.qmin = torch.tensor(0).to(self.device)
+            self.init_params = True
+        return self.scale, self.zero_point
+
+    def _qparams_preprocess(self, x):
+        if self.export_onnx:
+            return torch.tensor(
+                self.scale.abs().detach().cpu().numpy(), device=self.device
+            ), torch.tensor(
+                torch.clamp(self.zero_point, self.qdesc.qmin, self.qmax)
+                .detach()
+                .cpu()
+                .numpy(),
+                device=self.device,
+            )
+        scale = self.scale.abs()
+        zero_point = torch.clamp(self.zero_point, self.qdesc.qmin, self.qdesc.qmax)
+        return scale, zero_point
+
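+    # The learnable qmax encodes the bit width b as qmax = 2 ** (b - 1) - 1
+    # (symmetric) or qmax = 2 ** b - 1 (asymmetric), i.e. b = log2(2 * qmax + 2)
+    # or b = log2(qmax + 1), consistent with the bops/memory regularizers.
+    # fix_bit() rounds b to the nearest integer, maps it back to the
+    # corresponding qmax and freezes the parameter (requires_grad = False).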
+    def fix_bit(self):
+        with torch.no_grad():
+            if self.qdesc.is_symmetric:
+                cur_bit = torch.log2(2 * self.qmax + 2)
+                new_bit = cur_bit.round()
+                new_qmax = 2 ** (new_bit - 1) - 1
+            else:
+                cur_bit = torch.log2(self.qmax + 1)
+                new_bit = cur_bit.round()
+                new_qmax = 2 ** new_bit - 1
+
+            self.qmax.data.copy_(new_qmax)
+        self.qmax.requires_grad = False
+
+    def _forward(self, x, scale, zero_point):
+        if self.is_perchannel:
+            num_perchannel = x.numel() / x.shape[self.qdesc.ch_axis]
+            gs_ratio = 1.0 / math.sqrt(num_perchannel * self.qmax.item())
+        else:
+            gs_ratio = 1.0 / math.sqrt(x.numel() * self.qmax.item())
+        scale = gs_scaling.apply(scale, gs_ratio)
+        if self.qdesc.is_symmetric:
+            x_dq = STE.apply((x / scale).clamp(-self.qmax - 1, self.qmax)) * scale
+        else:
+            x_dq = STE.apply((x / scale).clamp(self.qmin, self.qmax)) * scale
+        return x_dq
diff --git a/sparsebit/quantization/regularizers/__init__.py b/sparsebit/quantization/regularizers/__init__.py
new file mode 100644
index 0000000..ff7e3d5
--- /dev/null
+++ b/sparsebit/quantization/regularizers/__init__.py
@@ -0,0 +1,15 @@
+REGULARIZERS_MAP = {}
+
+
+def register_regularizer(regularizer):
+    REGULARIZERS_MAP[regularizer.TYPE.lower()] = regularizer
+    return regularizer
+
+
+from .base import Regularizer
+from . import pact, bops, memory, bit_round
+
+
+def build_regularizer(type, config, *args, **kwargs):
+    regularizer = REGULARIZERS_MAP[type](config, *args, **kwargs)
+    return regularizer
\ No newline at end of file
diff --git a/sparsebit/quantization/regularizers/base.py b/sparsebit/quantization/regularizers/base.py
new file mode 100644
index 0000000..acc8579
--- /dev/null
+++ b/sparsebit/quantization/regularizers/base.py
@@ -0,0 +1,6 @@
+class Regularizer(object):
+    def __init__(self, config):
+        self.config = config
+
+    def __call__(self):
+        pass
\ No newline at end of file
diff --git a/sparsebit/quantization/regularizers/bit_round.py b/sparsebit/quantization/regularizers/bit_round.py
new file mode 100644
index 0000000..a763055
--- /dev/null
+++ b/sparsebit/quantization/regularizers/bit_round.py
@@ -0,0 +1,39 @@
+import torch
+import math
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+from sparsebit.quantization.quantizers.base import Quantizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "bit_round"
+
+    def __init__(self,
+        config,
+        qmodel,
+        max_coeff=1e6,
+        epochs_wo_regularization=20,
+        total_epochs=90,
+    ):
+        super(Regularizer, self).__init__(config)
+        self.epochs_wo_regularization = epochs_wo_regularization
+        self.total_epochs = total_epochs - epochs_wo_regularization
+        self.config = config
+        self.max_coeff = max_coeff
+        self.quantizers = []
+        for m in qmodel.modules():
+            if isinstance(m, Quantizer) and not m.fake_fused:
+                self.quantizers.append(m)
+
+    def __call__(self, epoch):
+        loss = torch.tensor(0.0)
+        if epoch < self.epochs_wo_regularization:
+            return loss
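+        # For each quantizer, recover the (continuous) bit width from the learned
+        # qmax and penalize its fractional part: bit_bias * (1 - bit_bias) is zero
+        # at integer bit widths and largest halfway between them, so the loss
+        # pushes the learned bit widths toward integers. The penalty is ramped up
+        # linearly via coeff over the remaining epochs.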
+        for quantizer in self.quantizers:
+            bit = torch.log2(2 * quantizer.qmax + 2) if quantizer.qdesc.is_symmetric else torch.log2(quantizer.qmax + 1)
+            bit_floor = math.floor(bit.item())
+            bit_bias = bit - bit_floor
+            loss += bit_bias * (1 - bit_bias) / bit_floor
+        coeff = self.max_coeff * (epoch - self.epochs_wo_regularization + 1) / self.total_epochs
+        return coeff * loss / len(self.quantizers)
\ No newline at end of file
diff --git a/sparsebit/quantization/regularizers/bops.py b/sparsebit/quantization/regularizers/bops.py
new file mode 100644
index 0000000..57035c7
--- /dev/null
+++ b/sparsebit/quantization/regularizers/bops.py
@@ -0,0 +1,82 @@
+import torch
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+from sparsebit.quantization.modules import QConv2d, QLinear, MatMul
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "bops"
+
+    def __init__(self,
+        config,
+        qmodel,
+        coeff=1e2,
+    ):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+        self.coeff = coeff
+        self.module_dict = {}
+        self.bops_limitation = 0
+        for node in qmodel.model.graph.nodes:
+            if node.op in ["placeholder", "output"]:
+                continue
+            module = getattr(qmodel.model, node.target)
+            if (
+                isinstance(module, (QConv2d, QLinear))
+                and getattr(module, "input_quantizer", None)
+                and getattr(module, "weight_quantizer", None)
+            ):
+                flops = module.weight.numel()
+                if isinstance(module, QConv2d):
+                    flops *= module.output_shape[-1] * module.output_shape[-2]
+                elif (
+                    isinstance(module, QLinear)
+                    and module.input_quantizer.observer.qdesc._ch_axis == 2  # NLC
+                ):
+                    flops *= module.output_shape[1]
+
+                self.module_dict[node.target] = {
+                    "flops": flops,
+                    "is_symmetric1": module.weight_quantizer.qdesc.is_symmetric,
+                    "is_symmetric2": module.input_quantizer.qdesc.is_symmetric,
+                    "qmax1": module.weight_quantizer.qmax,
+                    "qmax2": module.input_quantizer.qmax,
+                }
+                self.bops_limitation += (
+                    flops
+                    * module.input_quantizer.bit
+                    * module.weight_quantizer.bit
+                )
+            elif isinstance(module, MatMul) and getattr(
+                module, "input_quantizer_generated", None
+            ):
+                input0_quantizer = getattr(qmodel.model, node.all_input_nodes[0].target)
+                input1_quantizer = getattr(qmodel.model, node.all_input_nodes[1].target)
+                input0_shape = input0_quantizer.output_shape
+                input1_shape = input1_quantizer.output_shape
+                flops = (
+                    torch.prod(torch.tensor(input0_shape[1:])) * input1_shape[-1]
+                ).item()
+                self.module_dict[node.target] = {
+                    "flops": flops,
+                    "is_symmetric1": input0_quantizer.qdesc.is_symmetric,
+                    "is_symmetric2": input1_quantizer.qdesc.is_symmetric,
+                    "qmax1": input0_quantizer.qmax,
+                    "qmax2": input1_quantizer.qmax,
+                }
+                self.bops_limitation += flops * (input0_quantizer.bit * input1_quantizer.bit)
+
+        self.bops_limitation /= 1e9
+        print("BOPS limitation of the model:", str(self.bops_limitation), "GBOPS")
+
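+    # The regularization term recomputes the model's BOPs (bit-operations:
+    # sum over layers of MACs * weight_bit * activation_bit) from the current
+    # learned qmax values and penalizes only the amount by which it exceeds
+    # the budget captured at construction time from the initial bit widths.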
+    def __call__(self):
+        current_bops = 0
+        for n, info in self.module_dict.items():
+            bit1 = torch.log2(2 * info["qmax1"] + 2) if info["is_symmetric1"] else torch.log2(info["qmax1"] + 1)
+            bit2 = torch.log2(2 * info["qmax2"] + 2) if info["is_symmetric2"] else torch.log2(info["qmax2"] + 1)
+            current_bops += info["flops"] * bit1 * bit2
+        current_bops /= 1e9
+        if current_bops.item() <= self.bops_limitation:
+            return torch.zeros(1, device=current_bops.device)
+        return self.coeff * (current_bops - self.bops_limitation)
\ No newline at end of file
diff --git a/sparsebit/quantization/regularizers/memory.py b/sparsebit/quantization/regularizers/memory.py
new file mode 100644
index 0000000..d02d402
--- /dev/null
+++ b/sparsebit/quantization/regularizers/memory.py
@@ -0,0 +1,45 @@
+import torch
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+from sparsebit.quantization.modules import QConv2d, QLinear, MatMul
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "memory"
+
+    def __init__(self,
+        config,
+        qmodel,
+        coeff=1e2,
+    ):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+        self.coeff = coeff
+        self.module_dict = {}
+        self.memory_limitation = 0
+        for node in qmodel.model.graph.nodes:
+            if node.op in ["placeholder", "output"]:
+                continue
+            module = getattr(qmodel.model, node.target)
+            if (
+                isinstance(module, (QConv2d, QLinear))
+                and getattr(module, "weight_quantizer", None)
+            ):
+                self.memory_limitation += module.weight.numel() * module.weight_quantizer.bit / (2**23)
+                self.module_dict[node.target] = {
+                    "weight_numel": module.weight.numel(),
+                    "qmax": module.weight_quantizer.qmax,
+                }
+
+        print("Memory limitation of the model:", str(self.memory_limitation), "MB")
+
+    def __call__(self):
+        current_memory = 0
+        for n, info in self.module_dict.items():
+            bit = torch.log2(2 * info["qmax"] + 2)
+            current_memory += info["weight_numel"] * bit / (2**23)
+        if current_memory.item() <= self.memory_limitation:
+            return torch.zeros(1, device=current_memory.device)
+        return self.coeff * (current_memory - self.memory_limitation)
\ No newline at end of file
diff --git a/sparsebit/quantization/regularizers/pact.py b/sparsebit/quantization/regularizers/pact.py
new file mode 100644
index 0000000..21c90a5
--- /dev/null
+++ b/sparsebit/quantization/regularizers/pact.py
@@ -0,0 +1,20 @@
+import torch
+
+from sparsebit.quantization.regularizers import Regularizer as BaseRegularizer
+from sparsebit.quantization.regularizers import register_regularizer
+
+
+@register_regularizer
+class Regularizer(BaseRegularizer):
+    TYPE = "Pact"
+
+    def __init__(self, config):
+        super(Regularizer, self).__init__(config)
+        self.config = config
+
+    def __call__(self, model):
+        loss = 0.0
+        for n, p in model.named_parameters():
+            if "alpha" in n:
+                loss += (p**2).sum()
+        return loss
\ No newline at end of file
diff --git a/sparsebit/quantization/tools/calibration.py b/sparsebit/quantization/tools/calibration.py
index 6c2f108..c268ac8 100644
--- a/sparsebit/quantization/tools/calibration.py
+++ b/sparsebit/quantization/tools/calibration.py
@@ -157,4 +157,5 @@ def module_forward(
         outputs.append(to_cpu(module(*args, **kwargs)))
     if isinstance(module, QuantOpr):
         module.set_quant(w_quant=False, a_quant=False)
+    module.output_shape = outputs[0].shape
     return outputs