Commit a7e8a77 ("fixes")
1 parent ec336a9

10 files changed: +16 -24 lines

sharktank/sharktank/examples/export_paged_llm_v1.py (+1, -3)

@@ -60,21 +60,19 @@ def main():
     dataset_type = cli.get_input_data_files(args)
     dataset_type = "irpa" if "irpa" in dataset_type else "gguf"
     dataset = cli.get_input_dataset(args)
-
-    kv_cache_dtype = getattr(torch, args.kv_cache_dtype)
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
     tensor_parallelism_size = (
         dataset.properties["tensor_parallelism_size"]
         if "tensor_parallelism_size" in dataset.properties
         else 1
     )
+
     llama_config = LlamaModelConfig(
         hp,
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
         kv_cache_type="direct" if args.bs == [1] else "paged",
-        kv_cache_dtype=kv_cache_dtype,
         attention_kernel=args.attention_kernel,
     )
     llama_config.fake_quant = args.fake_quant

sharktank/sharktank/examples/paged_llm_v1.py (-4)

@@ -253,15 +253,12 @@ def main():
     cli.add_quantization_options(parser)
     cli.add_model_options(parser)
     args = cli.parse(parser)
-
     device = torch.device(args.device) if args.device else None
     activation_dtype = getattr(torch, args.activation_dtype)
     assert isinstance(activation_dtype, torch.dtype)
-    kv_cache_dtype = getattr(torch, args.kv_cache_dtype)
     dataset = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)
     prompts = args.prompt
-
     config = LlamaModelConfig(
         hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
         block_seq_stride=16,
@@ -270,7 +267,6 @@ def main():
         activation_dtype=activation_dtype,
         attention_dtype=activation_dtype,
         attention_kernel=args.attention_kernel,
-        kv_cache_dtype=kv_cache_dtype,
         use_hf=args.use_hf,
         tensor_parallelism_size=args.tensor_parallelism_size,
         fake_quant=args.fake_quant,

sharktank/sharktank/layers/configs/llm_configs.py (+1, -3)

@@ -153,8 +153,6 @@ class LlamaModelConfig:
     # Dtype to use for general FP activations not otherwise configured.
     activation_dtype: torch.dtype = torch.float16
 
-    kv_cache_dtype: torch.dtype = torch.float16
-
     # Dtype to use for attention.
     attention_dtype: torch.dtype = torch.float16
 
@@ -180,4 +178,4 @@ class LlamaModelConfig:
     # the compiler to transform it to an initialization time step. This can
     # be the difference of many gigabytes of static data being embedded in
     # the program and not.
-    static_tables: bool = True
+    static_tables: bool = True

sharktank/sharktank/layers/kv_cache.py (+2, -3)

@@ -191,8 +191,8 @@ def write_timestep(
         update_count = len(cache_partitions)
 
         for b in range(bs):
-            row_index = torch.tensor(b, dtype=torch.int64)
-            row_start_pos = seq_positions[row_index]
+            row_index = torch.tensor([b], dtype=torch.int64)
+            row_start_pos = seq_positions[row_index].unsqueeze(0)
 
             for i, update in enumerate(cache_partitions):
                 cache = state[transformer_block_index * update_count + i]
@@ -477,7 +477,6 @@ def write(
 
         page_ids = page_ids.flatten(0, 1)
         part_block_view = part_block_view.flatten(0, 1)
-        part_block_view = ops.to(part_block_view, page_table.dtype)
 
         page_table.index_put_(
             (
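A note on the write_timestep change above: indexing seq_positions with a 1-D [b] tensor (and unsqueezing the result) keeps explicit batch/time dimensions instead of collapsing them to 0-d scalars, which appears to be what the subsequent indexed cache write expects. A minimal shape sketch (illustrative values only, not the sharktank cache layout):

import torch

seq_positions = torch.tensor([5, 9])  # one start position per batch row
b = 0
scalar_index = torch.tensor(b, dtype=torch.int64)    # previous style: 0-d index
vector_index = torch.tensor([b], dtype=torch.int64)  # new style: 1-d index

print(seq_positions[scalar_index].shape)               # torch.Size([])  (0-d)
print(seq_positions[vector_index].unsqueeze(0).shape)  # torch.Size([1, 1])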

sharktank/sharktank/layers/linear.py (+9, -5)

@@ -30,6 +30,10 @@ class LinearLayer(ThetaLayer):
     if premul_input is not None:
       x = x * premul_input
     matmul(x, weight.T) + bias
+
+    fake_quant exists to allow export without adding dequant ops.
+    when fake_quant is True, the op will in quant dequant fashion.
+    When false, it will keep quantized types.
     ```
     """
 
@@ -70,21 +74,21 @@ def forward(self, x):
             x = q_input.quantize(x)
             if self.fake_quant:
                 x = x.unpack().dequant()
+        elif qdq_input is not None and self.fake_quant:
+            x = qdq_input.quantize(x).unpack().dequant()
 
         y = ops.linear(x, weight, bias)
 
         # Unconditionally dequantize.
-        # TODO: Support a q_output specifier that signals the layer to let
-        # the QuantizedTensor escape.
         if isinstance(y, QuantizedTensor) and not self.fake_quant:
             y = y.unpack().dequant()
         # Note that f8_e4m3fnuz types on AMD GPUs accumulate to fp32.
         # We can truncate to fp16 in iree, so we do a cast here
-        # to account for this in the IR.
+        # to account for this in the IR. This is may not be the right
+        # level to do this, but for now its here.
         if not self.fake_quant and y.dtype == torch.float8_e4m3fnuz:
             y = ops.to(y, torch.float16)
             return y
-        if qdq_output is not None:
-            # TODO: same as above.
+        if qdq_output is not None and self.fake_quant:
             y = qdq_output.quantize(y).unpack().dequant()
         return y
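To make the new docstring concrete, here is a minimal, self-contained sketch of the two fake_quant modes. This is illustrative only: it is not the sharktank LinearLayer, and the int8 quantize/dequantize helpers and per-tensor scales are invented for the example.

import torch

def quantize(t, scale):
    # Symmetric int8 quantization with a fixed per-tensor scale.
    return torch.clamp((t / scale).round(), -128, 127).to(torch.int8)

def dequantize(q, scale):
    return q.to(torch.float32) * scale

def linear_sketch(x, weight, bias, x_scale, w_scale, fake_quant):
    if fake_quant:
        # fake_quant=True: quant/dequant round trip stays in float, so no
        # dequant ops survive into the exported program; only the rounding
        # error of quantization is baked in.
        x = dequantize(quantize(x, x_scale), x_scale)
        w = dequantize(quantize(weight, w_scale), w_scale)
        return x @ w.T + bias
    # fake_quant=False: run the matmul on the quantized values (with a wide
    # accumulator) and dequantize once at the end.
    xq = quantize(x, x_scale).to(torch.int64)
    wq = quantize(weight, w_scale).to(torch.int64)
    acc = xq @ wq.T
    return dequantize(acc, x_scale * w_scale) + bias

# Both paths approximate the same float matmul; they differ only in where the
# dequantization happens, which is what the fake_quant flag controls.
x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
y_fake = linear_sketch(x, w, b, 0.05, 0.05, fake_quant=True)
y_quant = linear_sketch(x, w, b, 0.05, 0.05, fake_quant=False)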

sharktank/sharktank/layers/paged_llama_attention_block.py (-3)

@@ -151,9 +151,6 @@ def forward(
             xk = self.cache_quantizer.quantize(xk).unpack().qs
             xv = self.cache_quantizer.quantize(xv).unpack().qs
 
-            print(xk.dtype)
-            print(xv.dtype)
-            print(self.cache.dtype)
         xk, xv = self.transact_cache(
             xk_cache_update=xk,
             xv_cache_update=xv,

sharktank/sharktank/ops/qlinear_impls.py (+2)

@@ -55,6 +55,8 @@ def qlinear_tensor_scaled(
             return matmul(x_layout.qs, weight_layout.qs, transpose_rhs=True).to(
                 torch.float16
             )
+        else:
+            return NotImplemented
 
     # Bias.
     quantized_bias_accum = False
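The added return NotImplemented presumably follows the usual override convention: an implementation that cannot handle the given operands signals the dispatcher to try the next registered one instead of raising. A generic sketch of that pattern (hypothetical functions, not sharktank's dispatcher):

def fast_path(x, w):
    # Specialized implementation: only handles integer operands.
    if isinstance(x, int) and isinstance(w, int):
        return x * w
    return NotImplemented  # defer to the next registered implementation

def generic_path(x, w):
    return float(x) * float(w)

def dispatch(x, w, impls=(fast_path, generic_path)):
    for impl in impls:
        result = impl(x, w)
        if result is not NotImplemented:
            return result
    raise TypeError("no implementation matched the given operands")

print(dispatch(3, 4))      # handled by fast_path
print(dispatch(2.5, 4.0))  # fast_path defers, generic_path handles it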

sharktank/sharktank/types/quantizers.py (-1)

@@ -83,7 +83,6 @@ def _quantize_raw_tensor(self, t: torch.Tensor, *, name: str) -> QuantizedTensor
         ...
 
 
-
 @register_inference_tensor
 class StaticScaledQuantizer(QuantizerTensor):
     """Quantizes to a `TensorScaledLayout` (per-tensor) or (TBD) for per-axis.

sharktank/tests/layers/linear_test.py (+1, -1)

@@ -84,7 +84,7 @@ def testNativeQuant_SymPerTensor_AsymPerAxis0_Dynamic(self):
                 bias_quant,
             ]
         )
-        linear = LinearLayer(theta)
+        linear = LinearLayer(theta, fake_quant=False)
 
         output = linear(lhs)
         output_ref = torch.matmul(lhs, rhs.T) + bias

sharktank/tests/layers/paged_llama_attention_block_test.py (-1)

@@ -115,7 +115,6 @@ def forward(self, h, seq_block_ids, cache_state):
         output = aot.export(ep)
         output.verify()
         asm = str(output.mlir_module)
-        output.save_mlir("temp.mlir")
         self.assertNotIn("scaled_dot_product_attention", asm)
 
     def testExportNondecomposed(self):
