[DRAFT] Export llama uses to_edge_lower_and_transform #7524

Draft · wants to merge 1 commit into base: main
19 changes: 13 additions & 6 deletions examples/models/llama/export_llama_lib.py
@@ -659,11 +659,12 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
if args.export_only:
exit()

builder_exported_to_edge = builder_exported.pt2e_quantize(
quantizers
).export_to_edge()
# builder_exported_to_edge = builder_exported.pt2e_quantize(
# quantizers
# ).export_to_edge()

modelname = builder_exported_to_edge.modelname
# modelname = builder_exported_to_edge.modelname
modelname = builder_exported.modelname

# to_backend
partitioners = []
@@ -768,6 +769,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
for partitioner in partitioners:
logging.info(f"--> {partitioner.__class__.__name__}")

if args.generate_etrecord:
if not builder_exported_to_edge.edge_manager:
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -793,14 +795,19 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
)
logging.info("Generated etrecord.bin")
else:
builder = builder_exported_to_edge.to_backend(partitioners)
builder_lowered = builder_exported.pt2e_quantize(
quantizers
).to_edge_transform_and_lower(
partitioners
)
# builder = builder_exported_to_edge.to_backend(partitioners)
if args.num_sharding > 0 and args.qnn:
from executorch.backends.qualcomm.utils.utils import canonicalize_program

# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
canonicalize_program(builder_lowered.edge_manager.exported_program())

builder = builder.to_executorch()
builder = builder_lowered.to_executorch()

if args.profile_memory:
generate_memory_trace(builder.export_program, "memory_profile.json")
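
Note: taken together, the changes in this file replace the old export_to_edge + to_backend sequence with the combined lowering call. A minimal sketch of the intended flow, assuming `builder_exported`, `quantizers`, and `partitioners` as set up earlier in `_export_llama` (illustrative, not the literal diff):

# Old path: quantize, convert to edge dialect, then delegate per backend.
# builder = builder_exported.pt2e_quantize(quantizers).export_to_edge().to_backend(partitioners)

# New path: quantize, then transform and lower in a single step.
builder_lowered = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
    partitioners
)
builder = builder_lowered.to_executorch()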
95 changes: 54 additions & 41 deletions extension/llm/export/builder.py
@@ -21,7 +21,7 @@
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner

from executorch.exir.backend.utils import format_delegated_graph
@@ -216,6 +216,7 @@ def export(self) -> "LLMEdgeManager":
)
# pyre-fixme[8]: Attribute has type `Optional[GraphModule]`; used as
# `Module`.
self.pre_autograd_exported_program = exported_module
self.pre_autograd_graph_module = exported_module.module()
if hasattr(self.args, "export_only") and self.args.export_only:
torch.export.save(exported_module, self.args.output_name)
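
Note: the newly stored `pre_autograd_exported_program` keeps the full torch.export result around, because exir's `to_edge_transform_and_lower` consumes an `ExportedProgram`, while the pt2e quantization passes operate on its `.module()`. A self-contained illustration of that distinction (hypothetical `TinyModel`, not part of this change):

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


# torch.export produces an ExportedProgram; .module() yields the GraphModule
# that prepare_pt2e/convert_pt2e operate on, while the ExportedProgram itself
# is what exir.to_edge_transform_and_lower expects as input.
exported_program = torch.export.export(TinyModel(), (torch.randn(2, 4),))
graph_module = exported_program.module()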
@@ -305,51 +306,51 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
logging.info(f"Using pt2e {quantizers} to quantizing the model...")

if not quantizers:
logging.info("No quantizer provided, passing...")
return self

# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
if quantizers:
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
assert (
self.pre_autograd_graph_module is not None
), "Please run export() first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
# Calibrate
if (
self.calibration_tasks is not None
and self.calibration_limit is not None
and self.calibration_seq_length is not None
and self.calibration_data is not None
and self.tokenizer_path is not None
):
logging.info(
f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, calibration_data: {self.calibration_data}, tokenizer_path: {self.tokenizer_path}, seq_length: {self.calibration_seq_length}"
)
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self
else:
logging.info("No quantizer provided, passing...")
return self
self.pt2e_calibrate(
prepared_module=m,
calibration_tasks=self.calibration_tasks,
calibration_limit=self.calibration_limit,
calibration_seq_length=self.calibration_seq_length,
calibration_data=self.calibration_data,
tokenizer_path=self.tokenizer_path,
)
else:
logging.info(
"No calibration provided, using dummy input to calibrate..."
)
m(*self.example_inputs, **self.example_kwarg_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.pre_autograd_graph_module = m
return self
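
Note: independent of the restructuring above, the pt2e path still follows prepare → calibrate → convert. A standalone sketch of that sequence, assuming a recent PyTorch where `torch.export.export_for_training` and the torch.ao pt2e APIs are available, and using `XNNPACKQuantizer` as a stand-in for the quantizers passed in here (import paths vary across versions):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
graph_module = torch.export.export_for_training(TinyModel(), example_inputs).module()

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph_module, quantizer)  # insert observers
prepared(*example_inputs)                         # calibrate (dummy input here)
quantized = convert_pt2e(prepared)                # fold observers into quant/dequant ops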

def export_to_edge(self) -> "LLMEdgeManager":
"""
@@ -415,6 +416,18 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":

return self

def to_edge_transform_and_lower(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
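"""
Combine export-to-edge and backend delegation in one step: take the
ExportedProgram captured by export(), apply the edge transformations, and
lower any subgraphs supported by the given partitioners via
exir.to_edge_transform_and_lower.
"""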
if partitioners is None:
logging.info("No partitioner provided, skipping backend lowering...")
edge_config = self._get_edge_config()
self.edge_manager = to_edge_transform_and_lower(
self.pre_autograd_exported_program,
partitioner=partitioners,
compile_config=edge_config,
)
return self
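
Note: the underlying exir.to_edge_transform_and_lower API takes the ExportedProgram directly and returns an EdgeProgramManager, which is why export() now stores pre_autograd_exported_program. A standalone usage sketch, assuming the XNNPACK partitioner as an example backend (hypothetical model, not part of this diff):

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import to_edge_transform_and_lower


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


exported = torch.export.export(TinyModel(), (torch.randn(1, 8),))

# Transform to edge dialect and delegate supported subgraphs in a single call,
# mirroring what LLMEdgeManager.to_edge_transform_and_lower does with the
# saved pre_autograd_exported_program.
edge_manager = to_edge_transform_and_lower(exported, partitioner=[XnnpackPartitioner()])
executorch_program = edge_manager.to_executorch()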

def to_executorch(self) -> "LLMEdgeManager":
"""
Lower the model to executorch and get an ExecutorchProgram.