Skip to content

Commit 20c0285

Browse files
committed
[ExecuTorch][to_backend] Enable to_backend API to leverage preprocess_all
ghstack-source-id: f0f8445e51481241c8ffa9a59346a06d2e2b4a54 ghstack-comment-id: 2770905773 Pull Request resolved: #9824
1 parent d573c1c commit 20c0285

5 files changed

+1169
-16
lines changed

exir/backend/backend_api.py

+372-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import copy
99
import logging
1010
from contextlib import contextmanager, nullcontext
11+
from dataclasses import dataclass
1112
from functools import singledispatch
12-
from typing import Generator, List
13+
from typing import Dict, Generator, List
1314

1415
import torch
1516

@@ -417,3 +418,373 @@ def to_backend(
417418
constants=tagged_exported_program.constants,
418419
verifiers=[tagged_exported_program.verifier],
419420
)
421+
422+
423+
def _create_partitions_in_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Fuse each tagged group of nodes in ``tagged_graph_module`` into its own
    call_module submodule and stash the metadata needed for lowering on each
    created call site.

    Args:
        tagged_graph_module: graph module whose nodes carry partition tags.
        partition_result: partitioner output mapping tags to delegation specs.
        owning_program: the exported program that owns ``tagged_graph_module``.
        is_submodule: True when processing a nested control-flow submodule.

    Returns:
        Mapping of backend id -> list of the created call_module nodes, looked
        up again by target name at the end because graph edits may replace the
        original node instances.
    """
    backend_id_to_submodule_name = {}
    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only be
        # one contained submodule per tag.
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if len(node_list) == 0:
            logging.debug(f"Did not find any nodes for tag {tag}")
            continue

        logging.debug(f"For tag {tag}, found nodes {node_list}")
        # Tag the nodes that are params as buffers, so we can order the
        # submodule as (Params + Buffers) (User Inputs)

        # Only install the graph-signature replace hook for the top-level
        # program; nested control-flow submodules do not own the signature.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = next(
            node for node in tagged_graph_module.graph.nodes if node.op == "output"
        )
        submodule_output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )
        # Copy the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field.
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        logging.debug(f"Partitioned graph module: {tagged_graph_module}")
        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )
        # Stash everything the later lowering step needs on the call site.
        call_module_node.meta["backend_id"] = delegation_spec.backend_id
        call_module_node.meta["compile_spec"] = delegation_spec.compile_specs
        call_module_node.meta["submodule_program"] = submodule_program
        call_module_node.meta["toplevel_input_specs_to_delete"] = (
            toplevel_input_specs_to_delete
        )
        call_module_node.meta["toplevel_output_specs_to_delete"] = (
            toplevel_output_specs_to_delete
        )
        call_module_node.meta["is_submodule"] = is_submodule

        # The call_module_node created here might not be the same node instance
        # as the one in the final graph module, because this node might be
        # replaced in future edits to the graph. As a result, we just keep track
        # of the node's name and at the end we search for this node in our final
        # graph module.
        backend_id_to_submodule_name.setdefault(delegation_spec.backend_id, []).append(
            call_module_node.target
        )

    created_submodule_nodes = {key: [] for key in backend_id_to_submodule_name}
    for backend_id, submodule_names in backend_id_to_submodule_name.items():
        for node in tagged_graph_module.graph.nodes:
            if node.op == "call_module" and node.target in submodule_names:
                created_submodule_nodes[backend_id].append(node)

    # Check that every recorded submodule name was found again in the graph.
    for backend_id in created_submodule_nodes.keys():
        assert len(created_submodule_nodes[backend_id]) == len(
            backend_id_to_submodule_name[backend_id]
        ), f"Not all partitioned submodules for backend {backend_id} were found in the graph"

    return created_submodule_nodes
514+
515+
516+
def _create_partitions(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Partition ``tagged_graph_module`` and, recursively, each of its
    control-flow submodules, merging the created call_module nodes into a
    single backend-id -> nodes mapping.
    """
    backend_id_to_call_submodules = _create_partitions_in_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Recurse into control-flow submodules (e.g. cond/map bodies) and fold
    # their results into the top-level mapping.
    for _, control_flow_submodule, _ in get_control_flow_submodules(
        tagged_graph_module
    ):
        nested_result = _create_partitions(
            control_flow_submodule, partition_result, owning_program, is_submodule=True
        )
        for nested_backend_id, nested_nodes in nested_result.items():
            backend_id_to_call_submodules.setdefault(nested_backend_id, []).extend(
                nested_nodes
            )

    return backend_id_to_call_submodules
541+
542+
543+
def lower_all_submodules_to_backend(
    backend_id: str,
    method_to_submodules_nodes: Dict[str, List[torch.fx.Node]],
    method_to_tagged_edge_program: Dict[str, ExportedProgram],
) -> None:
    """
    Lower all submodules nodes given in the method_to_submodule_nodes map to backend_id.

    Args:
        backend_id: class name of the BackendDetails subclass to lower to.
        method_to_submodules_nodes: method name -> call_module nodes created by
            ``_create_partitions`` (each carries lowering metadata in ``meta``).
        method_to_tagged_edge_program: method name -> owning tagged program.

    Raises:
        NotImplementedError: if no BackendDetails subclass matches backend_id.

    Mutates the owning graph modules in place: each partitioned call_module
    node is replaced by an ``executorch_call_delegate`` call.
    """

    def generate_debug_handle(ep: ExportedProgram) -> int:
        """Return one past the largest debug handle present in *ep*."""
        debug_handle = 0
        for node in ep.graph_module.graph.nodes:
            debug_handle = max(debug_handle, node.meta.get("debug_handle", 0))
        return debug_handle + 1

    # The created exported programs for the submodules are in the call_module
    # nodes' meta data, so method_to_submodules_nodes maps directly to the
    # per-method partitioned programs and compile specs.
    method_to_partitioned_program = {
        method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }
    method_to_compile_specs = {
        method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }

    # Find the backend implementation by class name and preprocess all
    # methods' partitioned programs in a single call.
    backend_found = False
    for cls in BackendDetails.__subclasses__():
        if backend_id == cls.__name__:
            method_to_preprocess_result: Dict[str, List[PreprocessResult]] = (
                cls.preprocess_multimethod(
                    method_to_partitioned_program, method_to_compile_specs
                )
            )
            backend_found = True
            break  # class names are unique; no need to scan remaining subclasses

    if not backend_found:
        raise NotImplementedError(f"Backend {backend_id} was not found.")

    for method_name in method_to_preprocess_result.keys():
        owning_program = method_to_tagged_edge_program[method_name]
        list_of_preprocess_results = method_to_preprocess_result[method_name]
        list_of_call_submodule_nodes = method_to_submodules_nodes[method_name]
        list_of_compile_specs = method_to_compile_specs[method_name]
        # BUGFIX: this was previously `assert (cond, msg)` — asserting a
        # non-empty tuple, which is always truthy, so the check never fired.
        assert len(list_of_preprocess_results) == len(list_of_call_submodule_nodes), (
            f"Expected {len(list_of_call_submodule_nodes)} preprocessed results "
            f"for method {method_name} but got {len(list_of_preprocess_results)}"
        )
        for preprocess_result, call_submodule_node, compile_spec in zip(
            list_of_preprocess_results,
            list_of_call_submodule_nodes,
            list_of_compile_specs,
        ):
            submodule_program = call_submodule_node.meta["submodule_program"]
            lowered_module = LoweredBackendModule(
                edge_program=submodule_program,
                backend_id=backend_id,
                processed_bytes=preprocess_result.processed_bytes,
                compile_specs=compile_spec,
            )
            owning_graph_module = call_submodule_node.graph.owning_module
            is_submodule = call_submodule_node.meta["is_submodule"]
            toplevel_input_specs_to_delete = call_submodule_node.meta[
                "toplevel_input_specs_to_delete"
            ]
            toplevel_output_specs_to_delete = call_submodule_node.meta[
                "toplevel_output_specs_to_delete"
            ]
            # call delegate args should only use user_inputs;
            # preserve input order as given by the submodule's graph signature.
            call_delegate_args = []
            for inp_name in submodule_program.graph_signature.user_inputs:
                for inp_node in call_submodule_node.all_input_nodes:
                    if inp_node.name == inp_name:
                        call_delegate_args.append(inp_node)
                        break

            # Replace the partitioned submodule with a lowered submodule:
            # insert a call_function node invoking executorch_call_delegate.
            with owning_graph_module.graph.inserting_before(call_submodule_node):
                lowered_name = get_lowered_module_name(
                    owning_graph_module, lowered_module
                )
                lowered_node = owning_graph_module.graph.get_attr(lowered_name)
                call_delegate_node = owning_graph_module.graph.call_function(
                    executorch_call_delegate,
                    (lowered_node,) + tuple(call_delegate_args),
                    call_submodule_node.kwargs,
                )
                call_delegate_node.meta["debug_handle"] = generate_debug_handle(
                    owning_program
                )
                call_delegate_node.meta["val"] = call_submodule_node.meta["val"]
                call_submodule_node.replace_all_uses_with(call_delegate_node)
                owning_graph_module.graph.erase_node(call_submodule_node)

            if is_submodule:
                # Nested control-flow submodules never own top-level specs.
                assert len(toplevel_input_specs_to_delete) == 0
                assert len(toplevel_output_specs_to_delete) == 0
            elif (
                len(toplevel_input_specs_to_delete) > 0
                or len(toplevel_output_specs_to_delete) > 0
            ):
                _unsafe_adjust_original_program(
                    owning_program,
                    call_delegate_node,
                    toplevel_input_specs_to_delete,
                    toplevel_output_specs_to_delete,
                )
653+
654+
655+
@dataclass
class MethodProgramsPartitionerSpec:
    """
    Since single dispatch for to_backend requires the first argument to be a
    valid class, we create the following dataclass spec to hold the dictionaries
    mapping each method name to its corresponding edge program and partitioner.
    """

    # Method name -> exported program (Edge dialect) to partition and lower.
    method_to_edge_program: Dict[str, ExportedProgram]
    # Method name -> partitioner instance used to tag that method's program.
    method_to_partitioner: Dict[str, Partitioner]
665+
666+
667+
@to_backend.register
def _(
    method_edge_program_partitioners: MethodProgramsPartitionerSpec,
) -> Dict[str, ExportedProgram]:
    """
    Add overloaded implementations for to_backend:

    ::

        def to_backend(
            method_edge_program_partitioners: MethodProgramsPartitionerSpec
        ) -> Dict[str, ExportedProgram]:

    Returns a semantically-equivalent dictionary of programs to the programs given as input (represented
    as a graph module in Edge dialect), but with portions of each program targeted for
    delegation as determined by the partitioner.

    Args:
        method_edge_program_partitioners: contains two mappings,
            - method_to_edge_program: mapping of method names to their respective programs in Edge dialect.
            - method_to_partitioner: mapping of method names to an instance of the partitioner, in charge of tagging
              portions of the specified program for delegation. A valid partitioner must return PartitionerResult
              including both the tagged exported program and partitioner_tag: Dict[str, DelegationSpec], where each key is a tag name and
              the nodes with the same tag will be fused into one subgraph and delegated to the backend specified in the delegation spec.

    Returns:
        Dict[str, ExportedProgram]: The input programs, with some portions targeted for delegation per method.
    """
    method_to_edge_program = method_edge_program_partitioners.method_to_edge_program
    method_to_partitioner = method_edge_program_partitioners.method_to_partitioner

    partitioned_and_lowered_exported_programs = {}
    # backend id -> {method name -> call_module nodes}: lets each backend see
    # all methods' partitions at once (the preprocess_all path).
    backend_id_to_method_submodules_map = {}
    method_to_tagged_exported_program = {}

    # Phase 1: partition every method's program and collect the created
    # call_module nodes, grouped by backend id.
    for method_name, partitioner_instance in method_to_partitioner.items():
        assert (
            method_name in method_to_edge_program
        ), f"Partitioner for method {method_name} is not provided"
        edge_program = method_to_edge_program[method_name]
        edge_program._validate()

        # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values.
        # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback.
        try:
            fake_edge_program = get_fake_program(edge_program)
        except Exception as e:
            logging.warning(
                f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}"
            )
            fake_edge_program = copy.deepcopy(edge_program)
        partitioner_result = partitioner_instance(fake_edge_program)
        tagged_exported_program = partitioner_result.tagged_exported_program
        method_to_tagged_exported_program[method_name] = tagged_exported_program

        # Check that the partitioner did not modify the original graph
        if _ENABLE_VALIDATION:
            assert is_identical_graph(
                tagged_exported_program.graph_module,
                edge_program.graph_module,
            ), f"The partitioner {partitioner_instance} should not modify the graph module"
        else:
            logging.warning("Disabled validating the partitioner.")

        assert (
            partitioner_result.partition_tags is not None
        ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

        # The partitioner ran on the fake program; copy real constants back in.
        update_to_real_program(tagged_exported_program, edge_program)

        for tag, _ in partitioner_result.partition_tags.items():
            _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

        backend_id_to_call_submodule_nodes = _create_partitions(
            tagged_exported_program.graph_module,
            partitioner_result,
            tagged_exported_program,
        )
        for (
            backend_id,
            call_submodule_nodes,
        ) in backend_id_to_call_submodule_nodes.items():
            if backend_id not in backend_id_to_method_submodules_map:
                backend_id_to_method_submodules_map[backend_id] = {}
            backend_id_to_method_submodules_map[backend_id][
                method_name
            ] = call_submodule_nodes

    # Phase 2: lower all partitions for each backend in one batched call, so
    # the backend can preprocess all methods' submodules together.
    for (
        backend_id,
        method_to_submodule_nodes,
    ) in backend_id_to_method_submodules_map.items():
        lower_all_submodules_to_backend(
            backend_id,
            method_to_submodule_nodes,
            method_to_tagged_exported_program,
        )

    # Phase 3: rebuild an ExportedProgram per method from the mutated graphs.
    for method_name in method_to_edge_program.keys():
        if method_name in method_to_tagged_exported_program:
            tagged_exported_program = method_to_tagged_exported_program[method_name]
            partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                root=tagged_exported_program.graph_module,
                graph=tagged_exported_program.graph_module.graph,
                graph_signature=tagged_exported_program.graph_signature,
                state_dict=tagged_exported_program.state_dict,
                range_constraints=copy.deepcopy(
                    tagged_exported_program.range_constraints
                ),
                module_call_graph=copy.deepcopy(
                    tagged_exported_program.module_call_graph
                ),
                example_inputs=None,
                constants=tagged_exported_program.constants,
                verifiers=[tagged_exported_program.verifier],
            )
        else:
            # this edge program wasn't partitioned, so we can just return it as is
            partitioned_and_lowered_exported_programs[method_name] = (
                method_to_edge_program[method_name]
            )

    return partitioned_and_lowered_exported_programs

0 commit comments

Comments
 (0)