Merge pull request ReaLLMASIC#287 from mmoffatt2/quantization_lambda_param

Added Lambda "Quantization Level" Parameter
gkielian authored Nov 12, 2024
2 parents 7a04e65 + 2480112 commit 034fa34
Showing 6 changed files with 146 additions and 44 deletions.
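In brief, the new quant_scheduler / start_quant_level / full_quant_iteration settings let fake quantization be phased in over training: a quantization level in [0, 1] controls how much of the quantization error is mixed back into each activation. A minimal sketch of the idea, assuming a simple symmetric quantizer (the repository's fake_quantize_act additionally records scales and zero points per activation):

```python
import torch

def fake_quantize_with_level(x, quant_level, num_bits=8):
    # Simplified symmetric fake quantization; a sketch only, not the exact
    # symmetric_quantize() used in quantization/quantize.py.
    qmax = 2 ** (num_bits - 1) - 1
    scale = x.abs().max().clamp(min=1e-5) / qmax
    dequantized = (x / scale).round().clamp(-qmax, qmax) * scale
    # Mix only a fraction (quant_level) of the quantization error back in;
    # detach() keeps the straight-through gradient path on x.
    return x + quant_level * (dequantized - x).detach()
```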
38 changes: 38 additions & 0 deletions demos/quantization_level_demo.sh
@@ -0,0 +1,38 @@
#!/bin/bash

# Train a fully quantized model
## on the tiny-stories dataset,
## using a linear quantization scheduler that ramps up to full quantization
## over the first 45000 iterations
python3 train.py \
--max_iters 90000 \
--full_quant_iteration 45000 \
--dataset tiny-stories \
--n_head 8 \
--n_embd 512 \
--block_size 256 \
--bias false \
--dtype bfloat16 \
--quantization_warmup_iters 0 \
--quantize_attn_act true \
--quantize_mlp_act true \
--linear_variant_attn quantized_linear \
--linear_variant_mlp quantized_linear \
--quantize_linear_method symmetric_quant \
--activations_quant_method symmetric_quant \
--dropout 0 \
--grad_clip 1.0 \
--beta1 0.95 \
--beta2 0.95 \
--weight_decay 0.05 \
--learning_rate 0.75e-3 \
--quant_scheduler linear \
--max_sample_tokens 100 \
--sample_each_eval true

# Evaluate the trained model's inference while holding the quantization scales and zero points static
python3 sample.py \
--out_dir quantization_tinystories/tiny_stories \
--eval_only \
--eval_dataset="tiny-stories" \
--static_eval_scales
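For reference, with the config default of start_quant_level = 0 (see gpt_conf.py below), the linear scheduler used by this demo reduces to iter_num / full_quant_iteration, capped at 1. A minimal sketch of the resulting ramp for the demo's 45000-iteration target:

```python
# Minimal sketch of the demo's linear ramp, assuming start_quant_level = 0
# (the gpt_conf.py default) and full_quant_iteration = 45000 as passed above.
def demo_quant_level(iter_num, full_quant_iteration=45000):
    return min(iter_num / full_quant_iteration, 1.0)

for it in (0, 22500, 45000, 90000):
    print(it, demo_quant_level(it))   # 0.0, 0.5, 1.0, 1.0
```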
5 changes: 5 additions & 0 deletions gpt_conf.py
@@ -229,6 +229,11 @@ class GPTConfig:
linear_std_init: float= 0.02

# Quantizations
start_quant_level: float = 0
quant_scheduler: str = None
full_quant_iteration: int = None
# Needed for quant_level printing
eval_interval: int = 250

## Embedding Quantizations
quantize_wte: bool = False
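A minimal sketch of how the new fields might be set, assuming GPTConfig accepts them as keyword arguments (names taken from the diff above; values are illustrative only):

```python
from gpt_conf import GPTConfig

config = GPTConfig(
    quant_scheduler="linear",      # "static" or "linear" (see quantization/quantize.py below)
    start_quant_level=0.0,         # fraction of quantization error applied at iteration 0
    full_quant_iteration=45000,    # iteration at which the level reaches 1 under "linear"
    eval_interval=250,             # the current quant level is printed every eval_interval
)
```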
65 changes: 39 additions & 26 deletions model.py
@@ -110,6 +110,12 @@ def create_activation_buffers(obj, arg):
class CausalSelfAttention(nn.Module):
def __init__(self, config, fire_pos_enc=None):
super().__init__()

self.full_quant_iteration = config.full_quant_iteration
self.eval_interval = config.eval_interval
self.start_quant_level = config.start_quant_level
self.quant_scheduler = config.quant_scheduler

if (config.n_kv_group == None):
config.n_kv_group = config.n_head
else:
@@ -288,13 +294,13 @@ def get_block_mask(self, T, device):
return block_mask
# End Flex Attention Related

def forward(self, x):
def forward(self, x, iter_num):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

if self.quantization_attn_dict["quantize_attn_act_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
x = fake_quantize_act(self, "attn_act_input", x, num_bits, quant_method)
x = fake_quantize_act(self, "attn_act_input", x, num_bits, quant_method, iter_num)

q = self.c_attn_q(x)
k = self.c_attn_k(x)
@@ -348,11 +354,11 @@ def forward(self, x):
if self.quantization_attn_dict["quantize_attn_act_qk_mult_q_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_qk_mult_q_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
q = fake_quantize_act(self, "attn_act_qk_mult_q_input", q, num_bits, quant_method)
q = fake_quantize_act(self, "attn_act_qk_mult_q_input", q, num_bits, quant_method, iter_num)
if self.quantization_attn_dict["quantize_attn_act_qk_mult_k_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_qk_mult_k_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
k = fake_quantize_act(self, "attn_act_qk_mult_k_input", k, num_bits, quant_method)
k = fake_quantize_act(self, "attn_act_qk_mult_k_input", k, num_bits, quant_method, iter_num)

att = None
# manual implementation of attention
@@ -378,7 +384,7 @@ def forward(self, x):
if self.quantization_attn_dict["quantize_attn_act_softmax_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_softmax_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
att = fake_quantize_act(self, "attn_act_softmax_input", att, num_bits, quant_method, causal_mask=True)
att = fake_quantize_act(self, "attn_act_softmax_input", att, num_bits, quant_method, iter_num, causal_mask=True)

# softmax variation
if self.softmax_variant_attn != 'softmax':
@@ -391,11 +397,11 @@ def forward(self, x):
if self.quantization_attn_dict["quantize_attn_act_pv_mult_p_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_pv_mult_p_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
att = fake_quantize_act(self, "attn_act_pv_mult_p_input", att, num_bits, quant_method)
att = fake_quantize_act(self, "attn_act_pv_mult_p_input", att, num_bits, quant_method, iter_num)
if self.quantization_attn_dict["quantize_attn_act_pv_mult_v_input"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_pv_mult_v_input_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
v = fake_quantize_act(self, "attn_act_pv_mult_v_input", v, num_bits, quant_method)
v = fake_quantize_act(self, "attn_act_pv_mult_v_input", v, num_bits, quant_method, iter_num)

if self.n_head != self.n_kv_group:
v_repeated = v.repeat_interleave(self.n_head // self.n_kv_group, dim=1)
@@ -406,7 +412,7 @@ def forward(self, x):
if self.quantization_attn_dict["quantize_attn_act_pv_mult_output"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_pv_mult_output_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
y = fake_quantize_act(self, "attn_act_pv_mult_output", y, num_bits, quant_method)
y = fake_quantize_act(self, "attn_act_pv_mult_output", y, num_bits, quant_method, iter_num)

y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

@@ -416,7 +422,7 @@ def forward(self, x):
if self.quantization_attn_dict["quantize_attn_act_output"]:
num_bits = self.quantization_attn_dict["quantize_attn_act_output_bits"]
quant_method = self.quantization_attn_dict["activations_quant_method"]
y = fake_quantize_act(self, "attn_act_output", y, num_bits, quant_method)
y = fake_quantize_act(self, "attn_act_output", y, num_bits, quant_method, iter_num)

return y

@@ -425,9 +431,15 @@ class MLP(nn.Module):
def __init__(self, config):
super().__init__()

self.full_quant_iteration = config.full_quant_iteration
self.eval_interval = config.eval_interval

# Select "mlp variant"
self.mlp_variant = config.mlp_variant

self.start_quant_level = config.start_quant_level
self.quant_scheduler = config.quant_scheduler

# If "MLP Variant" is KAN, then we skip MLP specific items
if self.mlp_variant == "kan":
self.kan = linear_dictionary["kan"](config.n_embd, config.n_embd, config=config)
@@ -468,11 +480,12 @@ def __init__(self, config):

self.dropout = nn.Dropout(config.dropout)

def forward(self, x):
def forward(self, x, iter_num):

if self.quantization_mlp_dict["quantize_mlp_act_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_input", x, num_bits, quant_method)
x = fake_quantize_act(self, "mlp_act_input", x, num_bits, quant_method, iter_num)

if self.mlp_variant == "kan":
x = self.kan(x)
@@ -483,14 +496,14 @@ def forward(self, x):
if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_input", x, num_bits, quant_method)
x = fake_quantize_act(self, "mlp_act_activation_input", x, num_bits, quant_method, iter_num)

x = self.activation_variant(x)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_activation_output", x, num_bits, quant_method)
x = fake_quantize_act(self, "mlp_act_activation_output", x, num_bits, quant_method, iter_num)

x = self.c_proj(x)

@@ -500,14 +513,14 @@ def forward(self, x):
if self.quantization_mlp_dict["quantize_mlp_act_activation_input"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_input_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_input", x_in1, num_bits, quant_method)
x_in1 = fake_quantize_act(self, "mlp_act_activation_input", x_in1, num_bits, quant_method, iter_num)

x_in1 = self.activation_variant(x_in1)

if self.quantization_mlp_dict["quantize_mlp_act_activation_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_activation_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x_in1 = fake_quantize_act(self, "mlp_act_activation_output", x_in1, num_bits, quant_method)
x_in1 = fake_quantize_act(self, "mlp_act_activation_output", x_in1, num_bits, quant_method, iter_num)

x_in2 = self.c_fc_in2(x)
x_out = x_in1 * x_in2
@@ -518,7 +531,7 @@ def forward(self, x):
if self.quantization_mlp_dict["quantize_mlp_act_output"]:
num_bits = self.quantization_mlp_dict["quantize_mlp_act_output_bits"]
quant_method = self.quantization_mlp_dict["activations_quant_method"]
x = fake_quantize_act(self, "mlp_act_output", x, num_bits, quant_method)
x = fake_quantize_act(self, "mlp_act_output", x, num_bits, quant_method, iter_num)
return x

class Block(nn.Module):
@@ -547,22 +560,22 @@ def __init__(self, config, mlp=None, attn=None):
else:
self.mlp = mlp

def forward(self, x):
def forward(self, x, iter_num):
def custom_forward(*inputs):
x = inputs[0]
if self.use_post_ln:
if self.use_parallel_mlp:
x = self.ln_1(x + self.attn(x) + self.mlp(x))
x = self.ln_1(x + self.attn(x, iter_num) + self.mlp(x, iter_num))
else:
x = self.ln_1(x + self.attn(x))
x = self.ln_2(x + self.mlp(x))
x = self.ln_1(x + self.attn(x, iter_num))
x = self.ln_2(x + self.mlp(x, iter_num))
else:
if self.use_parallel_mlp:
ln_1 = self.ln_1(x)
x = x + self.attn(ln_1) + self.mlp(ln_1)
x = x + self.attn(ln_1, iter_num) + self.mlp(ln_1, iter_num)
else:
x = x + self.attn(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
x = x + self.attn(self.ln_1(x), iter_num)
x = x + self.mlp(self.ln_2(x), iter_num)
return x

if self.use_gradient_checkpointing and x.requires_grad:
@@ -774,7 +787,7 @@ def export_scale_matrices(self, file_path):
np.savez(file_path, scale_up=scale_up_matrix, scale_down=scale_down_matrix)
print(f"Scale matrices saved to {file_path}")

def forward(self, idx, targets=None):
def forward(self, idx, targets=None, iter_num=None):
device = idx.device
b, t = idx.size()
# assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
@@ -801,9 +814,9 @@ def forward(self, idx, targets=None):
for block in self.transformer.h:
# Propagate tokens through layers
if self.config.use_gradient_checkpointing:
x = checkpoint.checkpoint(block, x, use_reentrant=self.config.recompute_backward_pass)
x = checkpoint.checkpoint(block, x, iter_num, use_reentrant=self.config.recompute_backward_pass)
else:
x = block(x)
x = block(x, iter_num)

# Intercept for Learned Steering Vectors
if self.use_lsv and layer == self.config.apply_lsv_at_layer_idx:
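The train.py side of this change is not loaded in the view above; a hypothetical sketch of the caller, showing the new iter_num argument being threaded from the training loop into GPT.forward (and from there into every Block, attention, and MLP quantization call):

```python
# Hypothetical training-loop excerpt; model, get_batch, optimizer, and
# max_iters are assumed to exist as in the surrounding nanoGPT-style code.
for iter_num in range(max_iters):
    X, Y = get_batch("train")
    logits, loss = model(X, Y, iter_num=iter_num)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
```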
59 changes: 47 additions & 12 deletions quantization/quantize.py
@@ -16,7 +16,18 @@ def ternary_quantize(tensor, bits, causal_mask=False):
scale = tensor.abs().mean().clamp(min=1e-5)
result = (tensor / scale).round().clamp(-1, 1).to(dtype=torch.int8)
return torch.tensor([0], device=tensor.device), scale, result


def calculate_quant_level(training, quant_scheduler, start_quant_level, full_quant_iter, iter_num):
if full_quant_iter == None:
raise ValueError("Full quant iteration was not specified.")
if iter_num == None:
raise ValueError("Iter_num was not passed to GPT model")
if not training:
return 1
if quant_scheduler == "static":
return start_quant_level
elif quant_scheduler == "linear":
return min(iter_num / full_quant_iter + (full_quant_iter * start_quant_level), 1)

def symmetric_quantize(tensor, bits, causal_mask=False):
"""
@@ -112,20 +123,36 @@ def dequantize(zero_point, scale, tensor, causal_mask=False):
:return: Dequantized weights
"""
dequantized = (tensor - zero_point) * scale
if causal_mask:
# Create a mask for the upper triangular part
upper_tri_mask = torch.triu(torch.ones_like(dequantized), diagonal=1).bool()

# Set the upper triangular part to -inf
dequantized[upper_tri_mask] = -float('inf')
return dequantized

def fake_quantize_act(obj, activation, tensor, num_bits, quant_method, causal_mask=False):
def fake_quantize_act(obj, activation, tensor, num_bits, quant_method, iter_num, causal_mask=False):
zero_point, scale, act = quantize_dictionary[quant_method](tensor, num_bits, causal_mask=causal_mask)
setattr(obj, activation, act)
setattr(obj, f"{activation}_scale", scale)
setattr(obj, f"{activation}_zero_point", zero_point)
return dequantize(zero_point, scale, act, causal_mask=causal_mask)
dequantized = dequantize(zero_point, scale, act, causal_mask=causal_mask)
if causal_mask:
# Create a mask for the upper triangular part
upper_tri_mask = torch.triu(torch.ones_like(tensor), diagonal=1).bool()

# Set the upper triangular part to -inf
tensor[upper_tri_mask] = 0

# If scheduler is set, then we need to calculate the current quantization level
if obj.quant_scheduler != None:
quant_level = calculate_quant_level(obj.training, obj.quant_scheduler, obj.start_quant_level, obj.full_quant_iteration, iter_num)
# print quantization level for every evaluation interval
if obj.training and iter_num % obj.eval_interval == 0:
print("quant level: ", quant_level)
# adds quantization error to the original tensor
result = tensor + quant_level * (dequantized - tensor).detach()
else:
result = dequantized

if causal_mask:
result[upper_tri_mask] = -float('inf')

return result

class FakeLinearQuantizationFunction(torch.autograd.Function):
"""Simulates error caused by quantization. Uses Straight-Through Estimator for Back prop
@@ -134,7 +161,7 @@ class FakeLinearQuantizationFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, input, bits=7, quantization_method="affine_quant"):
def forward(ctx, input, training, quant_scheduler, start_quant_level, full_quant_iter, eval_interval, steps, bits=7, quantization_method="affine_quant"):
"""
Forward pass
:param ctx: Context object to store information for the backward pass (not used in this case)
@@ -147,14 +174,22 @@ def forward(ctx, input, bits=7, quantization_method="affine_quant"):
# Dequantize the quantized values using the dequantize function.
# Return the dequantized tensor, which approximates the input tensor but includes the quantization error.
zero_point, norm, quantized_weight = quantize_dictionary[quantization_method](input, bits)
return dequantize(zero_point, norm, quantized_weight)
# If scheduler is set, then we need to calculate the current quantization level
dequantized = dequantize(zero_point, norm, quantized_weight)
if quant_scheduler != None:
quant_level = calculate_quant_level(training, quant_scheduler, start_quant_level, full_quant_iter, steps)
if training and steps % eval_interval == 0:
print("quant level: ", quant_level)

return input + quant_level * (dequantized - input).detach()
return dequantized

@staticmethod
def backward(ctx, grad_output):
# Straight-Through Estimator (STE): passes grad_output through as the gradient with respect to the input
# gradient is approximated by simply passing the gradient from the output directly to the input,
# ignoring the quantization operation
return grad_output, None, None
return grad_output, None, None, None, None, None, None, None, None

quantize_dictionary = {
"ternary_quant": ternary_quantize,
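One property worth noting about the blended result tensor + quant_level * (dequantized - tensor).detach(): because the error term is detached, the gradient with respect to the input is unchanged regardless of the current quantization level. A small, self-contained check (a sketch, not part of the repository):

```python
import torch

x = torch.randn(4, requires_grad=True)
scale = x.abs().max().detach() / 127
dequantized = (x / scale).round().clamp(-127, 127) * scale
quant_level = 0.5
result = x + quant_level * (dequantized - x).detach()
result.sum().backward()
print(x.grad)  # tensor of ones: the straight-through estimator passes gradients unchanged
```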
