diff --git a/examples/quantization_w8a8_fp8/llama3_example.py b/examples/quantization_w8a8_fp8/llama3_example.py
index 9de05bea9..8506f2a85 100644
--- a/examples/quantization_w8a8_fp8/llama3_example.py
+++ b/examples/quantization_w8a8_fp8/llama3_example.py
@@ -1,5 +1,5 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
+from datasets import load_dataset
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.transformers import oneshot
@@ -23,21 +23,64 @@
     "Embedding": {
         "output_activations": None
     },
-    "model.layers.21.mlp.down_proj": {
-        "weight": None,
-        "input_activations": None,
-        "output_activations": None
+    "model.layers.21": {
+        "output_activations": None
     }
 }
 
 recipe = QuantizationModifier(
-    targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"], transforms=transforms
+    targets="Linear", scheme="FP8", ignore=["lm_head"], transforms=transforms
 )
 
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+
+# Select number of samples. 512 samples is a good place to start.
+# Increasing the number of samples can improve accuracy.
+NUM_CALIBRATION_SAMPLES = 1
+MAX_SEQUENCE_LENGTH = 2048
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
+ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+
 # Apply quantization.
-oneshot(model=model, recipe=recipe)
+oneshot(
+    model=model,
+    recipe=recipe,
+    dataset=ds,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+)
 
 # Confirm generations of the quantized model look sane.
 print("========== SAMPLE GENERATION ==============")
+
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
 output = model.generate(input_ids, max_new_tokens=20)
 print(tokenizer.decode(output[0]))
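For context on the recipe above: the `transforms` dict maps a target (a module class name such as `"Linear"`, or a layer path such as `"model.layers.21"`) to callables keyed by `"weight"`, `"input_activations"`, and `"output_activations"`. The example keeps `None` placeholders, which the modifier fills in internally. Below is a minimal sketch of what a non-identity entry could look like, assuming a plain-PyTorch orthonormal Hadamard-style rotation; the `sylvester_hadamard` helper and the 2048-wide rotation (matching the `random_hadamard_matrix(2048)` used later in the modifier) are illustrative assumptions, not part of llm-compressor:

```python
import torch


def sylvester_hadamard(n: int) -> torch.Tensor:
    """Orthonormal Hadamard matrix for power-of-two n (illustrative helper)."""
    h = torch.ones(1, 1)
    while h.shape[0] < n:
        h = torch.cat([torch.cat([h, h], dim=1), torch.cat([h, -h], dim=1)], dim=0)
    return h / (n ** 0.5)


R = sylvester_hadamard(2048)  # assumed width of the targeted Linear layers


def rotate_weight(weight: torch.Tensor) -> torch.Tensor:
    # W' = W @ R rotates the input-feature dimension of the weight
    return weight @ R.to(dtype=weight.dtype, device=weight.device)


def rotate_activations(x: torch.Tensor) -> torch.Tensor:
    # x' = x @ R applies the matching rotation, so x' @ W'.T == x @ W.T
    return x @ R.to(dtype=x.dtype, device=x.device)


transforms = {
    "Linear": {"weight": rotate_weight, "input_activations": rotate_activations},
    "Embedding": {"output_activations": rotate_activations},
    "model.layers.21": {"output_activations": rotate_activations},
}
```

Because `R` is orthonormal, rotating both the weight and the matching activations leaves the layer output unchanged while reshaping the distributions the observers see during calibration.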
print("========== SAMPLE GENERATION ==============") + input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda") output = model.generate(input_ids, max_new_tokens=20) print(tokenizer.decode(output[0])) diff --git a/src/llmcompressor/modifiers/quantization/calibration.py b/src/llmcompressor/modifiers/quantization/calibration.py index ee4ce171e..14541ac22 100644 --- a/src/llmcompressor/modifiers/quantization/calibration.py +++ b/src/llmcompressor/modifiers/quantization/calibration.py @@ -121,8 +121,19 @@ def update_weight_zp_scale(module: Module): if module.quantization_scheme.weights is not None: # set weight scale and zero_point up front, calibration data doesn't affect it + transforms = getattr(module, "transforms", None) + + if transforms: + weight_transform = transforms.transforms.get("weight") + untransformed_weight = module.weight.data.clone() + transformed_weight = weight_transform(module.weight) + module.weight.data.copy_(transformed_weight) + call_observer(module=module, base_name="weight") + if transforms: + module.weight.data.copy_(untransformed_weight) + def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): """ @@ -145,6 +156,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str): value=value, ) +def calibrate_weight_hook(module: Module, args: Any): + #print("updating weights") + update_weight_zp_scale(module) def calibrate_input_hook(module: Module, args: Any): """ @@ -153,7 +167,14 @@ def calibrate_input_hook(module: Module, args: Any): input QDQ in the module's forward pass. """ args = args[0] if isinstance(args, tuple) else args - calibrate_activations(module, value=args, base_name="input") + transforms = getattr(module, "transforms", None) + input_ = args + if transforms: + input_transform = transforms.transforms.get("input_activations") + if input_transform: + input_ = input_transform(input_) + + calibrate_activations(module, value=input_, base_name="input") def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): @@ -162,14 +183,22 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor): Will call the observers to update the scales/zp before applying output QDQ. 
""" + transforms = getattr(module, "transforms", None) + output_ = output + + if transforms: + output_transform = transforms.transforms.get("output_activations") + if output_transform: + output_ = output_transform(output_) + calibrate_activations( module, - value=output, + value=output_, base_name="output", ) output = forward_quantize( module=module, - value=output, + value=output_, base_name="output", args=module.quantization_scheme.output_activations, ) diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py index a10866694..c90371b99 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/base.py +++ b/src/llmcompressor/modifiers/quantization/quantization/base.py @@ -21,6 +21,7 @@ from llmcompressor.modifiers.quantization.calibration import ( apply_calibration_status, calibrate_input_hook, + calibrate_weight_hook, calibrate_kv_cache_input_hook, calibrate_kv_cache_output_hook, calibrate_output_hook, @@ -106,8 +107,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: # initialize quantization in appropriate modules config = self._apply_modifier_to_model(module) - module.apply(lambda module: initialize_observer(module, base_name="weight")) + #def apply_weight_hooks(module: Module): + # self.register_hook(module, calibrate_weight_hook, "forward_pre") + + #module.apply(lambda model: apply_weight_hooks(model)) + module.apply(lambda module: initialize_observer(module, base_name="weight")) + if self.calculate_start() == -1: # one-shot self._check_calibration_data(config) module.apply(update_weight_zp_scale) @@ -119,7 +125,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: module.apply(freeze_module_quantization) return True - + def on_start(self, state: State, event: Event, **kwargs): module = state.model module.apply(update_weight_zp_scale) @@ -214,26 +220,26 @@ def _apply_modifier_to_model(self, model: Module): modifier_as_config = self.create_init_config() # Add step to attach kv_cache to the model, if present within the config R1 = random_hadamard_matrix(2048) - def weight_transform(R1: torch.Tensor, input_tensor: torch.Tensor): + def weight_transform(input_tensor: torch.Tensor): return input_tensor + """ R1 = R1.to(input_tensor.dtype).to(input_tensor.device) # Should have a different callable depending on the layer type try: return input_tensor @ R1 except: return R1.T @ input_tensor + """ - weight_transform_p = partial(weight_transform, R1) + #weight_transform_p = partial(weight_transform, R1) def input_activation_transform(input_tensor: torch.Tensor): return input_tensor - self.transforms["Linear"]["weight"] = weight_transform_p + self.transforms["Linear"]["weight"] = weight_transform self.transforms["Linear"]["input_activations"] = input_activation_transform self.transforms["Embedding"]["output_activations"] = input_activation_transform - self.transforms["model.layers.21.mlp.down_proj"]["weight"] = weight_transform_p - self.transforms["model.layers.21.mlp.down_proj"]["input_activations"] = input_activation_transform - self.transforms["model.layers.21.mlp.down_proj"]["output_activations"] = input_activation_transform + self.transforms["model.layers.21"]["output_activations"] = input_activation_transform apply_quantization_config(model, modifier_as_config, transforms=self.transforms) model.apply(set_unset_kv_cache)