Commit

update
dsikka committed Jan 29, 2025
1 parent a9b8654 commit 3ef7241
Showing 4 changed files with 156 additions and 7 deletions.
24 changes: 19 additions & 5 deletions examples/quantization_w8a8_fp8/llama3_example.py
@@ -4,19 +4,33 @@
from llmcompressor.transformers import oneshot

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp8 with per-channel scales via PTQ
#   * quantize the activations to fp8 with dynamic per-token scales
# Transform notes:
#   * if a value is not a callable, it would have to be a key into a registry of registered callables
#   * the None placeholders below are replaced with callables by the modifier
#   * TODO: add the ability to ignore certain layers when targeting "Linear" or other larger groups
transforms = {
    "Linear": {
        "weight": None,
        "input_activations": None,
    },
    "Embedding": {
        "output_activations": None,
    },
    "model.layers.21.mlp.down_proj": {
        "weight": None,
        "input_activations": None,
        "output_activations": None,
    },
}
recipe = QuantizationModifier(
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"]
targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"], transforms=transforms
)

# Apply quantization.
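
For reference, a minimal sketch of what a user-supplied transform callable could look like once the None placeholders are replaced (the helper name rotate_weight is illustrative and not part of this commit; 2048 is assumed to match the example model's hidden size):

from functools import partial
from llmcompressor.modifiers.quantization.quantization.had import random_hadamard_matrix

def rotate_weight(rotation, weight):
    # Right-multiply the weight by a random Hadamard rotation (illustrative only)
    rotation = rotation.to(weight.dtype).to(weight.device)
    return weight @ rotation

transforms["Linear"]["weight"] = partial(rotate_weight, random_hadamard_matrix(2048))
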
30 changes: 28 additions & 2 deletions src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -13,7 +13,9 @@
from loguru import logger
from pydantic import Field, field_validator
from torch.nn import Module
import torch

from functools import partial
from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.quantization.calibration import (
@@ -32,6 +34,7 @@
run_calibration_forward,
)
from llmcompressor.observers.helpers import get_observer_token_count
from llmcompressor.modifiers.quantization.quantization.had import random_hadamard_matrix

__all__ = ["QuantizationModifier"]

@@ -79,6 +82,7 @@ class QuantizationModifier(Modifier):
    kv_cache_scheme: Optional[QuantizationArgs] = None
    disable_quantization_observer_epoch: Optional[float] = None
    num_calibration_steps: Optional[int] = None
    transforms: Optional[Dict] = None

    calibration_dataloader_: Any = None
    calibration_function_: Any = None
@@ -103,7 +107,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
        # initialize quantization in appropriate modules
        config = self._apply_modifier_to_model(module)
        module.apply(lambda module: initialize_observer(module, base_name="weight"))

        if self.calculate_start() == -1:  # one-shot
            self._check_calibration_data(config)
            module.apply(update_weight_zp_scale)
@@ -209,7 +213,29 @@ def _check_calibration_data(self, config: QuantizationConfig):
    def _apply_modifier_to_model(self, model: Module):
        modifier_as_config = self.create_init_config()
        # Add step to attach kv_cache to the model, if present within the config
        apply_quantization_config(model, modifier_as_config)
        # NOTE: hard-coded transform wiring for the debug example; 2048 matches
        # the hidden size of the TinyLlama model used in llama3_example.py.
        R1 = random_hadamard_matrix(2048)

        def weight_transform(R1: torch.Tensor, input_tensor: torch.Tensor):
            R1 = R1.to(input_tensor.dtype).to(input_tensor.device)
            # Should have a different callable depending on the layer type;
            # fall back to left-multiplication when the trailing dim does not match R1.
            try:
                return input_tensor @ R1
            except RuntimeError:
                return R1.T @ input_tensor

        weight_transform_p = partial(weight_transform, R1)

        def input_activation_transform(input_tensor: torch.Tensor):
            # Identity placeholder until an activation transform is defined
            return input_tensor

        self.transforms["Linear"]["weight"] = weight_transform_p
        self.transforms["Linear"]["input_activations"] = input_activation_transform
        self.transforms["Embedding"]["output_activations"] = input_activation_transform
        self.transforms["model.layers.21.mlp.down_proj"]["weight"] = weight_transform_p
        self.transforms["model.layers.21.mlp.down_proj"]["input_activations"] = input_activation_transform
        self.transforms["model.layers.21.mlp.down_proj"]["output_activations"] = input_activation_transform

        apply_quantization_config(model, modifier_as_config, transforms=self.transforms)
        model.apply(set_unset_kv_cache)
        return modifier_as_config

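
A quick, standalone sketch of how the partial-bound transform behaves (tensor sizes here are arbitrary and only meant to exercise both branches of weight_transform):

import torch
from functools import partial

R1 = torch.eye(8)  # stand-in for random_hadamard_matrix(8)

def weight_transform(R1, input_tensor):
    R1 = R1.to(input_tensor.dtype).to(input_tensor.device)
    try:
        return input_tensor @ R1    # last dim matches R1
    except RuntimeError:
        return R1.T @ input_tensor  # fallback when it does not

weight_transform_p = partial(weight_transform, R1)
print(weight_transform_p(torch.randn(4, 8)).shape)  # torch.Size([4, 8])
print(weight_transform_p(torch.randn(8, 4)).shape)  # torch.Size([8, 4])
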
108 changes: 108 additions & 0 deletions src/llmcompressor/modifiers/quantization/quantization/had.py
@@ -0,0 +1,108 @@
import torch

__all__ = ["random_hadamard_matrix"]

# Build a randomized Hadamard matrix of the given size by composing a Hadamard
# transform with a random +/-1 diagonal.
def random_hadamard_matrix(size, device="cuda"):
    # See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation"
    Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64)
    Q = Q * 2 - 1
    Q = torch.diag(Q)
    return matmul_hadU(Q).to(device)

def matmul_hadU(X, transpose=False):
    n = X.shape[-1]
    hadK, K = get_hadK(n, transpose)
    input = X.clone().view(-1, n, 1)
    output = input.clone()
    while input.shape[1] > K:
        input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
        output = output.view(input.shape)
        output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
        output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
        output = output.view(input.shape[0], input.shape[1], -1)
        (input, output) = (output, input)
    del output

    if K > 1:
        # Do not explicitly repeat - OOM
        # input = torch.bmm(
        #     hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input)
        # Use bcast instead
        input = hadK.view(1, K, K).to(input) @ input

    return input.view(X.shape) / torch.tensor(n).sqrt()

def is_pow2(n):
    return (n & (n - 1) == 0) and (n > 0)

def get_hadK(n, transpose=False):
    # NOTE: the get_had172 ... get_had12 constructors referenced below are not
    # defined in this file; they are assumed to be provided alongside (e.g. the
    # hard-coded Hadamard matrices from the QuIP# hadamard utilities).
    hadK, K = None, None
    if n % 172 == 0:  # llama-2-7b up
        assert is_pow2(n // 172)
        K = 172
        hadK = get_had172().T if transpose else get_had172()
    elif n % 156 == 0:  # llama-1-30b 3x hidden
        assert is_pow2(n // 156)
        K = 156
        hadK = get_had156().T if transpose else get_had156()
    elif n % 140 == 0:  # llama-1-30b intermediate
        assert is_pow2(n // 140)
        K = 140
        hadK = get_had140().T if transpose else get_had140()
    elif n % 108 == 0:  # llama-1-13b intermediate
        assert is_pow2(n // 108)
        K = 108
        hadK = get_had108().T if transpose else get_had108()
    elif n % 60 == 0:  # llama-1-13b 3x hidden
        assert is_pow2(n // 60)
        K = 60
        hadK = get_had60().T if transpose else get_had60()
    elif n % 52 == 0:  # llama-1-13b 1x hidden
        assert is_pow2(n // 52)
        K = 52
        hadK = get_had52().T if transpose else get_had52()
    elif n % 36 == 0:
        assert is_pow2(n // 36)
        K = 36
        hadK = get_had36().T if transpose else get_had36()
    elif n % 28 == 0:
        assert is_pow2(n // 28)
        K = 28
        hadK = get_had28().T if transpose else get_had28()
    elif n % 44 == 0:
        assert is_pow2(n // 44)
        K = 44
        hadK = get_had44().T if transpose else get_had44()
    elif n % 40 == 0:
        assert is_pow2(n // 40)
        K = 40
        hadK = get_had40().T if transpose else get_had40()
    elif n % 20 == 0:
        assert is_pow2(n // 20)
        K = 20
        hadK = get_had20().T if transpose else get_had20()
    elif n % 12 == 0:
        assert is_pow2(n // 12)
        K = 12
        hadK = get_had12().T if transpose else get_had12()
    else:
        assert is_pow2(n)
        K = 1

    return hadK, K
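
A quick sanity check of random_hadamard_matrix (CPU-only sketch; a power-of-two size keeps get_hadK on the K == 1 branch, so none of the get_had* helpers are needed):

import torch

H = random_hadamard_matrix(64, device="cpu")
# A randomized Hadamard matrix should be orthogonal: H @ H.T ~= I
assert torch.allclose(H @ H.T, torch.eye(64, dtype=H.dtype), atol=1e-6)
print(H.shape, H.dtype)  # torch.Size([64, 64]) torch.float64
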
1 change: 1 addition & 0 deletions src/llmcompressor/recipe/recipe.py
@@ -123,6 +123,7 @@ def create_instance(
"attempting to process as a string."
)
logger.debug(f"Input string: {path_or_modifiers}")
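# Strip serialized python object tags (e.g. "!!python/name:__main__.<fn>") that
# non-serializable callables such as transforms leave in the recipe string.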
path_or_modifiers = re.sub("!!python/name:__main__.*", "", path_or_modifiers)
obj = _load_json_or_yaml_string(path_or_modifiers)
return Recipe.model_validate(obj)
else:
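
A small illustration of what the added re.sub does to a recipe string that contains a serialized callable (the recipe text below is fabricated for illustration):

import re

serialized = "transforms:\n  Linear:\n    weight: !!python/name:__main__.weight_transform ''\n"
print(re.sub("!!python/name:__main__.*", "", serialized))
# transforms:
#   Linear:
#     weight:
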
