
Commit

update
dsikka committed Jan 30, 2025
1 parent 3ef7241 commit 026c250
Showing 3 changed files with 96 additions and 18 deletions.
57 changes: 50 additions & 7 deletions examples/quantization_w8a8_fp8/llama3_example.py
@@ -1,5 +1,5 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import oneshot

@@ -23,21 +23,64 @@
"Embedding": {
"output_activations": None
},
"model.layers.21.mlp.down_proj": {
"weight": None,
"input_activations": None,
"output_activations": None
"model.layers.21": {
"output_activationst": None
}
}
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"], transforms=transforms
targets="Linear", scheme="FP8", ignore=["lm_head"], transforms=transforms
)

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples. 512 samples is a good place to start.
# Increasing the number of samples can improve accuracy.
NUM_CALIBRATION_SAMPLES = 1  # far below the suggested 512, presumably for a quick test run
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
return {
"text": tokenizer.apply_chat_template(
example["messages"],
tokenize=False,
)
}


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
return tokenizer(
sample["text"],
padding=False,
max_length=MAX_SEQUENCE_LENGTH,
truncation=True,
add_special_tokens=False,
)


ds = ds.map(tokenize, remove_columns=ds.column_names)


# Apply quantization.
oneshot(model=model, recipe=recipe)
oneshot(
model=model,
recipe=recipe,
dataset=ds,
max_seq_length=MAX_SEQUENCE_LENGTH,
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
)

# Confirm generations of the quantized model look sane.
print("========== SAMPLE GENERATION ==============")

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=20)
print(tokenizer.decode(output[0]))
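
For reference, a minimal sketch of the complete transforms dictionary this example appears to build; the fold above hides its first entries, so the "Linear" entry is an assumption inferred from the keys that QuantizationModifier._apply_modifier_to_model later fills in (see base.py below). The None values are placeholders that the modifier overwrites with callables.

transforms = {
    "Linear": {
        "weight": None,               # assumed entry; overwritten with a weight transform
        "input_activations": None,    # assumed entry; overwritten with an input transform
    },
    "Embedding": {
        "output_activations": None,
    },
    "model.layers.21": {
        "output_activations": None,
    },
}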
35 changes: 32 additions & 3 deletions src/llmcompressor/modifiers/quantization/calibration.py
@@ -121,8 +121,19 @@ def update_weight_zp_scale(module: Module):

if module.quantization_scheme.weights is not None:
# set weight scale and zero_point up front, calibration data doesn't affect it
transforms = getattr(module, "transforms", None)
weight_transform = transforms.transforms.get("weight") if transforms else None

if weight_transform:
# observe the transformed weight; the original values are restored below
untransformed_weight = module.weight.data.clone()
transformed_weight = weight_transform(module.weight)
module.weight.data.copy_(transformed_weight)

call_observer(module=module, base_name="weight")

if weight_transform:
module.weight.data.copy_(untransformed_weight)


def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
"""
@@ -145,6 +156,9 @@ def calibrate_activations(module: Module, value: torch.Tensor, base_name: str):
value=value,
)

def calibrate_weight_hook(module: Module, args: Any):
"""
Hook to calibrate the module's weight scale and zero point
before its forward pass runs.
"""
update_weight_zp_scale(module)

def calibrate_input_hook(module: Module, args: Any):
"""
@@ -153,7 +167,14 @@ def calibrate_input_hook(module: Module, args: Any):
input QDQ in the module's forward pass.
"""
args = args[0] if isinstance(args, tuple) else args
calibrate_activations(module, value=args, base_name="input")
transforms = getattr(module, "transforms", None)
input_ = args
if transforms:
input_transform = transforms.transforms.get("input_activations")
if input_transform:
input_ = input_transform(input_)

calibrate_activations(module, value=input_, base_name="input")


def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
@@ -162,14 +183,22 @@ def calibrate_output_hook(module: Module, _args: Any, output: torch.Tensor):
Will call the observers to update the scales/zp before applying
output QDQ.
"""
transforms = getattr(module, "transforms", None)
output_ = output

if transforms:
output_transform = transforms.transforms.get("output_activations")
if output_transform:
output_ = output_transform(output_)

calibrate_activations(
module,
value=output,
value=output_,
base_name="output",
)
output = forward_quantize(
module=module,
value=output,
value=output_,
base_name="output",
args=module.quantization_scheme.output_activations,
)
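
The hooks above only assume that a calibrated module carries a transforms attribute whose .transforms mapping holds per-tensor callables keyed by "weight", "input_activations", and "output_activations". A minimal, self-contained sketch of that contract and of the transform-observe-restore pattern used for weights; the SimpleNamespace container and the doubling transform are stand-ins, not the object apply_quantization_config actually attaches:

from types import SimpleNamespace

import torch

# Stand-in for whatever apply_quantization_config attaches as `module.transforms`;
# the calibration hooks only read `transforms.transforms.get(<base_name>)`.
module = torch.nn.Linear(8, 8)
module.transforms = SimpleNamespace(
    transforms={
        "weight": lambda w: 2.0 * w,        # placeholder weight transform
        "input_activations": lambda x: x,   # identity, as in the example recipe
        "output_activations": lambda y: y,
    }
)

# Transform -> observe -> restore, mirroring update_weight_zp_scale: the observer
# sees the transformed weight while the module keeps its original parameters.
weight_transform = module.transforms.transforms.get("weight")
untransformed = module.weight.data.clone()
module.weight.data.copy_(weight_transform(module.weight))
# ... call_observer(module=module, base_name="weight") would run here ...
module.weight.data.copy_(untransformed)
assert torch.equal(module.weight.data, untransformed)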
22 changes: 14 additions & 8 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -21,6 +21,7 @@
from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
calibrate_input_hook,
calibrate_kv_cache_input_hook,
calibrate_kv_cache_output_hook,
calibrate_output_hook,
calibrate_weight_hook,
@@ -106,8 +107,13 @@ def on_initialize(self, state: State, **kwargs) -> bool:

# initialize quantization in appropriate modules
config = self._apply_modifier_to_model(module)
module.apply(lambda module: initialize_observer(module, base_name="weight"))

#def apply_weight_hooks(module: Module):
# self.register_hook(module, calibrate_weight_hook, "forward_pre")

#module.apply(lambda model: apply_weight_hooks(model))
module.apply(lambda module: initialize_observer(module, base_name="weight"))

if self.calculate_start() == -1: # one-shot
self._check_calibration_data(config)
module.apply(update_weight_zp_scale)
@@ -119,7 +125,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
module.apply(freeze_module_quantization)

return True

def on_start(self, state: State, event: Event, **kwargs):
module = state.model
module.apply(update_weight_zp_scale)
@@ -214,26 +220,26 @@ def _apply_modifier_to_model(self, model: Module):
modifier_as_config = self.create_init_config()
# Add step to attach kv_cache to the model, if present within the config
R1 = random_hadamard_matrix(2048)
def weight_transform(R1: torch.Tensor, input_tensor: torch.Tensor):
def weight_transform(input_tensor: torch.Tensor):
return input_tensor

"""
R1 = R1.to(input_tensor.dtype).to(input_tensor.device)
# Should have a different callable depending on the layer type
try:
return input_tensor @ R1
except:
return R1.T @ input_tensor
"""

weight_transform_p = partial(weight_transform, R1)
#weight_transform_p = partial(weight_transform, R1)
def input_activation_transform(input_tensor: torch.Tensor):
return input_tensor

self.transforms["Linear"]["weight"] = weight_transform_p
self.transforms["Linear"]["weight"] = weight_transform
self.transforms["Linear"]["input_activations"] = input_activation_transform
self.transforms["Embedding"]["output_activations"] = input_activation_transform
self.transforms["model.layers.21.mlp.down_proj"]["weight"] = weight_transform_p
self.transforms["model.layers.21.mlp.down_proj"]["input_activations"] = input_activation_transform
self.transforms["model.layers.21.mlp.down_proj"]["output_activations"] = input_activation_transform
self.transforms["model.layers.21"]["output_activations"] = input_activation_transform

apply_quantization_config(model, modifier_as_config, transforms=self.transforms)
model.apply(set_unset_kv_cache)
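
The commented-out block above sketches the Hadamard rotation that weight_transform is meant to apply once enabled. A hedged version of what that might look like, with the shape dispatch made explicit instead of a bare except; random_hadamard_matrix and the 2048 size come from the code above, while the dimension checks and the function name are assumptions:

import torch

def hadamard_weight_transform(R1: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Rotate the tensor by the random Hadamard matrix R1, matching it against
    # either the trailing (right-multiply) or leading (left-multiply) dimension.
    R1 = R1.to(dtype=weight.dtype, device=weight.device)
    if weight.shape[-1] == R1.shape[0]:
        return weight @ R1
    if weight.shape[0] == R1.shape[0]:
        return R1.T @ weight
    raise ValueError("R1 does not match either weight dimension")

# Usage mirroring the disabled partial(...) binding above:
#   R1 = random_hadamard_matrix(2048)
#   weight_transform_p = partial(hadamard_weight_transform, R1)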
