-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
156 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
108 changes: 108 additions & 0 deletions
108
src/llmcompressor/modifiers/quantization/quantization/had.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import torch | ||
|
||
__all__ = ["random_hadamard_matrix"] | ||
# Configure the quantization algorithm and scheme. | ||
# In this case, we: | ||
# * quantize the weights to fp8 with per channel via ptq | ||
# * quantize the activations to fp8 with dynamic per token | ||
def random_hadamard_matrix(size, device="cuda"): | ||
# See https://cornell-relaxml.github.io/quip-sharp/ , Section "Randomized Hadamard Transformation" | ||
Q = torch.randint(low=0, high=2, size=(size,)).to(torch.float64) | ||
Q = Q * 2 - 1 | ||
Q = torch.diag(Q) | ||
return matmul_hadU(Q).to(device) | ||
|
||
def matmul_hadU(X, transpose=False): | ||
n = X.shape[-1] | ||
hadK, K = get_hadK(n, transpose) | ||
input = X.clone().view(-1, n, 1) | ||
output = input.clone() | ||
while input.shape[1] > K: | ||
input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2]) | ||
output = output.view(input.shape) | ||
output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :] | ||
output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :] | ||
output = output.view(input.shape[0], input.shape[1], -1) | ||
(input, output) = (output, input) | ||
del output | ||
|
||
if K > 1: | ||
# Do not explicitly repeat - OOM | ||
# input = torch.bmm( | ||
# hadK.repeat(len(input), 1, 1).to(input.device).to(input.dtype), input) | ||
# Use bcast instead | ||
input = hadK.view(1, K, K).to(input) @ input | ||
|
||
return input.view(X.shape) / torch.tensor(n).sqrt() | ||
|
||
def is_pow2(n): | ||
return (n & (n - 1) == 0) and (n > 0) | ||
|
||
def get_hadK(n, transpose=False): | ||
hadK, K = None, None | ||
if n % 172 == 0: # llama-2-7b up | ||
assert is_pow2(n // 172) | ||
|
||
K = 172 | ||
hadK = get_had172().T if transpose else get_had172() | ||
elif n % 156 == 0: # llama-1-30b 3x hidden | ||
assert is_pow2(n // 156) | ||
|
||
K = 156 | ||
hadK = get_had156().T if transpose else get_had156() | ||
elif n % 140 == 0: # llama-1-30b intermediate | ||
assert is_pow2(n // 140) | ||
|
||
K = 140 | ||
hadK = get_had140().T if transpose else get_had140() | ||
elif n % 108 == 0: # llama-1-13b intermediate | ||
assert is_pow2(n // 108) | ||
|
||
K = 108 | ||
hadK = get_had108().T if transpose else get_had108() | ||
elif n % 60 == 0: # llama-1-13b 3x hidden | ||
assert is_pow2(n // 60) | ||
|
||
K = 60 | ||
hadK = get_had60().T if transpose else get_had60() | ||
elif n % 52 == 0: # llama-1-13b 1x hidden | ||
assert is_pow2(n // 52) | ||
|
||
K = 52 | ||
hadK = get_had52().T if transpose else get_had52() | ||
elif n % 36 == 0: | ||
assert is_pow2(n // 36) | ||
|
||
K = 36 | ||
hadK = get_had36().T if transpose else get_had36() | ||
elif n % 28 == 0: | ||
assert is_pow2(n // 28) | ||
|
||
K = 28 | ||
hadK = get_had28().T if transpose else get_had28() | ||
elif n % 44 == 0: | ||
assert is_pow2(n // 44) | ||
|
||
K = 44 | ||
hadK = get_had44().T if transpose else get_had44() | ||
elif n % 40 == 0: | ||
assert is_pow2(n // 40) | ||
|
||
K = 40 | ||
hadK = get_had40().T if transpose else get_had40() | ||
elif n % 20 == 0: | ||
assert is_pow2(n // 20) | ||
|
||
K = 20 | ||
hadK = get_had20().T if transpose else get_had20() | ||
elif n % 12 == 0: | ||
assert is_pow2(n // 12) | ||
|
||
K = 12 | ||
hadK = get_had12().T if transpose else get_had12() | ||
else: | ||
assert is_pow2(n) | ||
|
||
K = 1 | ||
|
||
return hadK, K |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters