Adding support for Llama 3.1 and Llama 3.2 models #59

Open · wants to merge 10 commits into main
33 changes: 33 additions & 0 deletions fake_quant/hadamard_utils.py
@@ -44,6 +44,10 @@ def get_hadK(n, transpose=False):
assert (is_pow2(n // 20))
K = 20
hadK = get_had20().T if transpose else get_had20()
elif n % 24 == 0: # llama-3.2-3B
assert (is_pow2(n // 24))
K = 24
hadK = get_had24().T if transpose else get_had24()
elif n % 12 == 0:
assert (is_pow2(n // 12))
K = 12
@@ -165,6 +169,35 @@ def get_had12():
[+1, -1, +1, -1, -1, -1, +1, +1, +1, -1, +1, +1],
])

# Hadamard matrix for had.24.pal; rows formatted from the raw +/- text (held in a string a) via:
# print("\n".join(["[" + ", ".join([f"{c}1" for c in l]) + "]," for l in a.split("\n")[:-1]]))
def get_had24():
return torch.FloatTensor([
[+1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
[+1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1],
[+1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1],
[+1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1],
[+1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1],
[+1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1],
[+1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1],
[+1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1],
[+1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1],
[+1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1],
[+1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1],
[+1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1],
[+1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1],
[+1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1],
[+1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1, -1],
[+1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1, -1],
[+1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1, +1],
[+1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1, -1],
[+1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1, +1],
[+1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1, -1],
[+1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1, -1],
[+1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1, -1],
[+1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1, -1],
[+1, -1, -1, -1, -1, +1, -1, +1, -1, -1, +1, +1, -1, -1, +1, +1, -1, +1, -1, +1, +1, +1, +1, +1],
])

def get_had40():
return torch.FloatTensor([
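For reference, a quick sanity check (not part of the diff) that the new 24×24 matrix really is a Hadamard matrix and that get_hadK picks it up for Llama-3.2-3B's hidden size of 3072 = 24 · 128. This assumes the functions above are importable as fake_quant.hadamard_utils; adjust the import path to how the repo is run.

import torch
from fake_quant.hadamard_utils import get_had24, get_hadK

# A valid Hadamard matrix satisfies H @ H.T == n * I.
H = get_had24()
assert torch.allclose(H @ H.T, 24 * torch.eye(24))

# Llama-3.2-3B has hidden_size = 3072 = 24 * 128, so the new branch factors it
# into the 24x24 base matrix times a power-of-two Walsh-Hadamard block.
hadK, K = get_hadK(3072)
assert K == 24 and hadK.shape == (24, 24)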
8 changes: 6 additions & 2 deletions fake_quant/rotation_utils.py
@@ -238,7 +238,11 @@ def rotate_model(model, args):

model_type = model_utils.model_type_extractor(model)
rotate_embeddings(model, Q)
rotate_head(model, Q)

# If the input embeddings (embed_tokens) and output embeddings (lm_head) are tied, skip rotating the head: both modules reference the same weight tensor, so rotating it here would apply Q twice.
if not model.config.tie_word_embeddings:
rotate_head(model, Q)

utils.cleanup_memory()
layers = model_utils.get_transformer_layers(model,
model_type=model_type)
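The guard above matters for the new Llama 3.2 checkpoints (1B/3B), which tie lm_head to embed_tokens. As an illustration only (the model variable below is an assumed, already-loaded Hugging Face model, not part of the diff), the tie can be confirmed like this:

# True when both modules point at the same storage, i.e. rotating the
# embeddings has already rotated the head.
tied = (model.get_input_embeddings().weight.data_ptr()
        == model.get_output_embeddings().weight.data_ptr())
print(model.config.tie_word_embeddings, tied)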
@@ -294,7 +298,7 @@ def forward(self, *args, **kwargs):


if self.k_groupsize == -1: #token-wise quantization
token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size)
token_wise_k = k.transpose(1, 2).reshape(-1, self.config.hidden_size * self.config.num_key_value_heads // self.config.num_attention_heads)
self.k_quantizer.find_params(token_wise_k)
k = self.k_quantizer(token_wise_k).reshape((bsz, seq_len, num_heads, head_dim)).transpose(1, 2).to(q)
else: #head-wise quantization
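The token-wise K-cache change above accounts for grouped-query attention: the key states carry num_key_value_heads heads rather than num_attention_heads, so the flattened per-token width is num_key_value_heads * head_dim instead of hidden_size (integer division is needed, since reshape rejects float dimensions). A small sketch of the arithmetic, using Llama-3.1-8B's published config values:

hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8
head_dim = hidden_size // num_attention_heads                       # 128
k_width = hidden_size * num_key_value_heads // num_attention_heads  # what the new reshape uses
assert k_width == num_key_value_heads * head_dim == 1024            # not hidden_size (4096)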
8 changes: 6 additions & 2 deletions fake_quant/utils.py
@@ -7,7 +7,6 @@
from datetime import datetime
import logging


from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory

@@ -17,7 +16,12 @@
'meta-llama/Llama-2-70b-hf',
'meta-llama/Meta-Llama-3-8B',
'meta-llama/Meta-Llama-3-70B',
'facebook/opt-125m'
'meta-llama/Llama-3.1-8B',
'meta-llama/Llama-3.1-70B',
'meta-llama/Llama-3.1-405B',
'meta-llama/Llama-3.2-1B',
'meta-llama/Llama-3.2-3B',
'facebook/opt-125m',
]
supported_datasets = ['wikitext2', 'ptb', 'c4']

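As a trivial smoke test (assuming the list above is the supported_models constant in fake_quant/utils.py, mirroring supported_datasets below it), one can confirm a new checkpoint name is recognized before launching a run:

from fake_quant.utils import supported_models  # adjust the import path to how the repo is run

assert 'meta-llama/Llama-3.2-3B' in supported_models
assert 'meta-llama/Llama-3.1-405B' in supported_models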