Commit 28cfabb

larryliu0820 authored and facebook-github-bot committed
Fix use_sdpa_with_kv_cache option (#4456)
Summary:
Pull Request resolved: #4456

As titled. In `export_llava.py`, `export_text_model()` needs to respect the `use_sdpa_with_kv_cache_op` option.

Reviewed By: cccclai

Differential Revision: D60431561

fbshipit-source-id: 63d49f39339435fb16f0c1c62288fd31c86b3be8
1 parent 3c25aec commit 28cfabb

1 file changed: +5 -2 lines changed

examples/models/llava/export_llava.py (+5 -2)
@@ -83,11 +83,14 @@ def forward(self, input_pos, embeddings):
     )
     quant_transform = get_quant_weight_transform(args, dtype_override, False)
     pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
+    source_transforms = []
+    if llava.use_sdpa_with_kv_cache_op:
+        source_transforms.append(replace_sdpa_with_custom_op)
+    source_transforms.append(quant_transform)
     manager = (
         text_model_em.set_output_dir("./")
         .to_dtype(dtype_override)
-        .source_transform([replace_sdpa_with_custom_op, quant_transform])
+        .source_transform(source_transforms)
         .capture_pre_autograd_graph()
         .pt2e_quantize(quantizers)
     )
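
The diff boils down to one pattern: build the list of source transforms conditionally instead of hard-coding it, then pass that list to `.source_transform(...)`. The sketch below shows the pattern in isolation, assuming a source transform is any callable that takes and returns an `nn.Module`. `swap_linear_for_identity` and `tag_as_quantized` are hypothetical stand-ins, not the real ExecuTorch `replace_sdpa_with_custom_op` or quantization transforms.

# Minimal sketch of the conditional source-transform pattern from the fix.
# The two transform functions here are hypothetical stand-ins.
from typing import Callable, List

import torch.nn as nn


def swap_linear_for_identity(model: nn.Module) -> nn.Module:
    # Stand-in for replace_sdpa_with_custom_op: swap one module type for another.
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            setattr(model, name, nn.Identity())
    return model


def tag_as_quantized(model: nn.Module) -> nn.Module:
    # Stand-in for the quantization transform: just mark the model.
    model.quantized = True
    return model


def build_source_transforms(
    use_sdpa_with_kv_cache_op: bool,
) -> List[Callable[[nn.Module], nn.Module]]:
    # Mirrors the diff: the SDPA transform is conditional, the quant transform is not.
    transforms: List[Callable[[nn.Module], nn.Module]] = []
    if use_sdpa_with_kv_cache_op:
        transforms.append(swap_linear_for_identity)
    transforms.append(tag_as_quantized)
    return transforms


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    for transform in build_source_transforms(use_sdpa_with_kv_cache_op=True):
        model = transform(model)
    print(model)  # Linear has been replaced by Identity; model.quantized is True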
