|
8 | 8 | import copy
|
9 | 9 | import logging
|
10 | 10 | from contextlib import contextmanager, nullcontext
|
| 11 | +from dataclasses import dataclass |
11 | 12 | from functools import singledispatch
|
12 |
| -from typing import Generator, List |
| 13 | +from typing import Dict, Generator, List |
13 | 14 |
|
14 | 15 | import torch
|
15 | 16 |
|
@@ -417,3 +418,373 @@ def to_backend(
|
417 | 418 | constants=tagged_exported_program.constants,
|
418 | 419 | verifiers=[tagged_exported_program.verifier],
|
419 | 420 | )
|
| 421 | + |
| 422 | + |
def _create_partitions_in_graph_module(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Fuse the nodes of each partition tag in ``tagged_graph_module`` into one
    ``call_module`` submodule per tag, stashing the lowering inputs
    (submodule program, compile specs, specs to delete) on each created
    node's ``meta``.

    Args:
        tagged_graph_module: graph module whose nodes carry partition tags.
        partition_result: result of running the partitioner; supplies the
            tag -> DelegationSpec mapping.
        owning_program: the ExportedProgram that owns ``tagged_graph_module``.
        is_submodule: True when ``tagged_graph_module`` is a control-flow
            submodule rather than the top-level module.

    Returns:
        Mapping of backend id to the ``call_module`` nodes created for it in
        the final graph module.
    """
    backend_id_to_submodule_name: Dict[str, List[str]] = {}
    for tag, delegation_spec in partition_result.partition_tags.items():
        # Create partition with nodes containing this tag. There should only
        # be one contained submodule per tag.
        node_list = _get_node_list_with_same_tag(
            tagged_graph_module, tag, owning_program
        )

        if len(node_list) == 0:
            # Use lazy %-style args so the message is only built when the
            # debug level is enabled.
            logging.debug("Did not find any nodes for tag %s", tag)
            continue

        logging.debug("For tag %s, found nodes %s", tag, node_list)

        # Only the top-level module has a graph signature to keep in sync
        # while nodes are moved into the submodule; nested control-flow
        # submodules do not, so use a no-op context there.
        replace_ctx = (
            tagged_graph_module._set_replace_hook(
                owning_program.graph_signature.get_replace_hook()
            )
            if not is_submodule
            else nullcontext()
        )
        with replace_ctx:
            submodule, call_module_node = create_submodule_from_nodes(
                tagged_graph_module, node_list, tag
            )

        tagged_graph_module_output_node = next(
            node for node in tagged_graph_module.graph.nodes if node.op == "output"
        )
        submodule_output_node = next(
            node for node in submodule.graph.nodes if node.op == "output"
        )
        # Copy the output node meta from the original output node, because
        # create_submodule_from_nodes doesn't cover the meta field.
        submodule_output_node.meta = tagged_graph_module_output_node.meta
        logging.debug("Partitioned graph module: %s", tagged_graph_module)
        (
            submodule_program,
            toplevel_input_specs_to_delete,
            toplevel_output_specs_to_delete,
        ) = create_exported_program_from_submodule(
            submodule,
            owning_program,
            tag,
            call_module_node,
            is_submodule,
        )
        # Stash everything the later lowering pass needs on the node itself.
        call_module_node.meta["backend_id"] = delegation_spec.backend_id
        call_module_node.meta["compile_spec"] = delegation_spec.compile_specs
        call_module_node.meta["submodule_program"] = submodule_program
        call_module_node.meta["toplevel_input_specs_to_delete"] = (
            toplevel_input_specs_to_delete
        )
        call_module_node.meta["toplevel_output_specs_to_delete"] = (
            toplevel_output_specs_to_delete
        )
        call_module_node.meta["is_submodule"] = is_submodule

        # The call_module_node created here might not be the same node instance
        # as the one in the final graph module, because this node might be
        # replaced by future edits to the graph. As a result, we just keep
        # track of the node's target name and search for that name in the
        # final graph module at the end.
        backend_id_to_submodule_name.setdefault(
            delegation_spec.backend_id, []
        ).append(call_module_node.target)

    created_submodule_nodes: Dict[str, List[torch.fx.Node]] = {
        key: [] for key in backend_id_to_submodule_name
    }
    for backend_id, submodule_names in backend_id_to_submodule_name.items():
        for node in tagged_graph_module.graph.nodes:
            if node.op == "call_module" and node.target in submodule_names:
                created_submodule_nodes[backend_id].append(node)

    # Sanity check: every recorded submodule name was resolved to exactly one
    # node in the final graph module.
    for backend_id in created_submodule_nodes:
        assert len(created_submodule_nodes[backend_id]) == len(
            backend_id_to_submodule_name[backend_id]
        ), (
            f"Expected {len(backend_id_to_submodule_name[backend_id])} submodule "
            f"nodes for backend {backend_id}, found "
            f"{len(created_submodule_nodes[backend_id])}"
        )

    return created_submodule_nodes
| 514 | + |
| 515 | + |
def _create_partitions(
    tagged_graph_module: torch.fx.GraphModule,
    partition_result: PartitionResult,
    owning_program: ExportedProgram,
    is_submodule: bool = False,
) -> Dict[str, List[torch.fx.Node]]:
    """
    Create partition submodules for ``tagged_graph_module`` and, recursively,
    for every control-flow submodule it contains.

    Returns:
        Mapping of backend id to all ``call_module`` nodes created for that
        backend, across this module and its nested control-flow submodules.
    """
    merged_call_submodules = _create_partitions_in_graph_module(
        tagged_graph_module, partition_result, owning_program, is_submodule
    )

    # Recursively partition each control-flow submodule and fold its results
    # into the mapping built for this graph module.
    for _, control_flow_submodule, _ in get_control_flow_submodules(
        tagged_graph_module
    ):
        nested_call_submodules = _create_partitions(
            control_flow_submodule, partition_result, owning_program, is_submodule=True
        )
        for nested_backend_id, nested_nodes in nested_call_submodules.items():
            if nested_backend_id in merged_call_submodules:
                merged_call_submodules[nested_backend_id].extend(nested_nodes)
            else:
                merged_call_submodules[nested_backend_id] = nested_nodes

    return merged_call_submodules
| 541 | + |
| 542 | + |
def lower_all_submodules_to_backend(
    backend_id: str,
    method_to_submodules_nodes: Dict[str, List[torch.fx.Node]],
    method_to_tagged_edge_program: Dict[str, ExportedProgram],
) -> None:
    """
    Lower all submodule nodes given in ``method_to_submodules_nodes`` to
    ``backend_id``, replacing each partition ``call_module`` node with an
    ``executorch_call_delegate`` call in place.

    Args:
        backend_id: class name of the BackendDetails subclass to use.
        method_to_submodules_nodes: mapping of method name to the
            ``call_module`` nodes (created by ``_create_partitions``) to lower.
        method_to_tagged_edge_program: mapping of method name to the tagged
            owning program; may be adjusted in place when top-level
            input/output specs must be deleted.

    Raises:
        NotImplementedError: if no BackendDetails subclass named
            ``backend_id`` exists.
    """
    # The created exported program for each submodule lives in the
    # call_module node's meta data, so the node lists map directly to
    # per-method program / compile-spec lists.
    method_to_partitioned_program = {
        method_name: [node.meta["submodule_program"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }
    method_to_compile_specs = {
        method_name: [node.meta["compile_spec"] for node in call_submodule_nodes]
        for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
    }
    backend_found = False
    for cls in BackendDetails.__subclasses__():
        if backend_id == cls.__name__:
            method_to_preprocess_result: Dict[str, List[PreprocessResult]] = (
                cls.preprocess_multimethod(
                    method_to_partitioned_program, method_to_compile_specs
                )
            )
            backend_found = True

    if not backend_found:
        raise NotImplementedError(f"Backend {backend_id} was not found.")

    def generate_debug_handle(ep: ExportedProgram) -> int:
        """Return a debug handle one greater than any already present in ep."""
        max_handle = 0
        for node in ep.graph_module.graph.nodes:
            max_handle = max(max_handle, node.meta.get("debug_handle", 0))
        return max_handle + 1

    for method_name in method_to_preprocess_result.keys():
        owning_program = method_to_tagged_edge_program[method_name]
        list_of_preprocess_results = method_to_preprocess_result[method_name]
        list_of_call_submodule_nodes = method_to_submodules_nodes[method_name]
        list_of_compile_specs = method_to_compile_specs[method_name]
        # BUGFIX: this assert was previously wrapped in parentheses with a
        # trailing comma, which asserts a non-empty tuple — always true and
        # therefore never fired.
        assert len(list_of_preprocess_results) == len(
            list_of_call_submodule_nodes
        ), f"Expected {len(list_of_call_submodule_nodes)} preprocessed results for method {method_name} but got {len(list_of_preprocess_results)}"
        for preprocess_result, call_submodule_node, compile_spec in zip(
            list_of_preprocess_results,
            list_of_call_submodule_nodes,
            list_of_compile_specs,
        ):
            submodule_program = call_submodule_node.meta["submodule_program"]
            lowered_module = LoweredBackendModule(
                edge_program=submodule_program,
                backend_id=backend_id,
                processed_bytes=preprocess_result.processed_bytes,
                compile_specs=compile_spec,
            )
            owning_graph_module = call_submodule_node.graph.owning_module
            is_submodule = call_submodule_node.meta["is_submodule"]
            toplevel_input_specs_to_delete = call_submodule_node.meta[
                "toplevel_input_specs_to_delete"
            ]
            toplevel_output_specs_to_delete = call_submodule_node.meta[
                "toplevel_output_specs_to_delete"
            ]
            # Call-delegate args should only use user_inputs; preserve the
            # input order declared by the submodule's graph signature.
            call_delegate_args = []
            for inp_name in submodule_program.graph_signature.user_inputs:
                for inp_node in call_submodule_node.all_input_nodes:
                    if inp_node.name == inp_name:
                        call_delegate_args.append(inp_node)
                        break

            # Replace the partitioned submodule with a lowered submodule:
            # add a call_function node invoking executorch_call_delegate.
            with owning_graph_module.graph.inserting_before(call_submodule_node):
                lowered_name = get_lowered_module_name(
                    owning_graph_module, lowered_module
                )
                lowered_node = owning_graph_module.graph.get_attr(lowered_name)
                call_delegate_node = owning_graph_module.graph.call_function(
                    executorch_call_delegate,
                    (lowered_node,) + tuple(call_delegate_args),
                    call_submodule_node.kwargs,
                )
                call_delegate_node.meta["debug_handle"] = generate_debug_handle(
                    owning_program
                )
                call_delegate_node.meta["val"] = call_submodule_node.meta["val"]
                call_submodule_node.replace_all_uses_with(call_delegate_node)
                owning_graph_module.graph.erase_node(call_submodule_node)

            if is_submodule:
                # Nested submodules never carry top-level specs to delete.
                assert len(toplevel_input_specs_to_delete) == 0
                assert len(toplevel_output_specs_to_delete) == 0
            elif (
                len(toplevel_input_specs_to_delete) > 0
                or len(toplevel_output_specs_to_delete) > 0
            ):
                _unsafe_adjust_original_program(
                    owning_program,
                    call_delegate_node,
                    toplevel_input_specs_to_delete,
                    toplevel_output_specs_to_delete,
                )
| 653 | + |
| 654 | + |
@dataclass
class MethodProgramsPartitionerSpec:
    """
    Bundle of per-method inputs for the multi-method ``to_backend`` overload.

    Since single dispatch for ``to_backend`` requires the first argument to be
    a valid class, this dataclass holds the two dictionaries mapping each
    method name to its corresponding program and partitioner, so they can be
    passed as one dispatchable argument.
    """

    # Maps each method name to its program in Edge dialect.
    method_to_edge_program: Dict[str, ExportedProgram]
    # Maps each method name to the partitioner that tags that program's
    # nodes for delegation.
    method_to_partitioner: Dict[str, Partitioner]
| 665 | + |
| 666 | + |
@to_backend.register
def _(
    method_edge_program_partitioners: MethodProgramsPartitionerSpec,
) -> Dict[str, ExportedProgram]:
    """
    Add overloaded implementations for to_backend:

    ::

        def to_backend(
            method_edge_program_partitioners: MethodProgramsPartitionerSpec
        ) -> Dict[str, ExportedProgram]:

    Returns a semantically-equivalent dictionary of programs to the programs
    given as input (represented as a graph module in Edge dialect), but with
    portions of each program targeted for delegation as determined by the
    partitioner.

    Args:
        method_edge_program_partitioners: contains two mappings,
            - method_to_edge_program: mapping of method names to their
              respective programs in Edge dialect.
            - method_to_partitioner: mapping of method names to an instance of
              the partitioner, in charge of tagging portions of the specified
              program for delegation. A valid partitioner must return a
              PartitionerResult including both the tagged exported program and
              partition_tags: Dict[str, DelegationSpec], where each key is a
              tag name and the nodes with the same tag will be fused as one
              subgraph and delegated to the backend specified in the
              delegation spec.

    Returns:
        Dict[str, ExportedProgram]: The input programs, with some portions
        targeted for delegation.
    """
    method_to_edge_program = method_edge_program_partitioners.method_to_edge_program
    method_to_partitioner = method_edge_program_partitioners.method_to_partitioner

    partitioned_and_lowered_exported_programs: Dict[str, ExportedProgram] = {}
    # backend id -> {method name -> call_module nodes to lower}
    backend_id_to_method_submodules_map: Dict[
        str, Dict[str, List[torch.fx.Node]]
    ] = {}
    method_to_tagged_exported_program: Dict[str, ExportedProgram] = {}

    for method_name, partitioner_instance in method_to_partitioner.items():
        assert (
            method_name in method_to_edge_program
        ), f"Partitioner for method {method_name} is not provided"
        edge_program = method_to_edge_program[method_name]
        edge_program._validate()

        # Use fake program, with FakeTensors in the state dict, to avoid
        # copying large constant values. Fall back to deepcopy if no fake mode
        # is found. TODO(T182910699): Remove this fallback.
        try:
            fake_edge_program = get_fake_program(edge_program)
        except Exception as e:
            logging.warning(
                "Error in get_fake_program for graph %s, fallback to deepcopy: %s",
                edge_program.graph_module,
                e,
            )
            fake_edge_program = copy.deepcopy(edge_program)
        partitioner_result = partitioner_instance(fake_edge_program)
        tagged_exported_program = partitioner_result.tagged_exported_program
        method_to_tagged_exported_program[method_name] = tagged_exported_program

        # Check that the partitioner did not modify the original graph.
        if _ENABLE_VALIDATION:
            assert is_identical_graph(
                tagged_exported_program.graph_module,
                edge_program.graph_module,
            ), f"The partitioner {partitioner_instance} should not modify the graph module"
        else:
            logging.warning("Disabled validating the partitioner.")

        assert (
            partitioner_result.partition_tags is not None
        ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec"

        update_to_real_program(tagged_exported_program, edge_program)

        for tag in partitioner_result.partition_tags:
            _maybe_duplicate_constant_nodes(tagged_exported_program, tag)

        backend_id_to_call_submodule_nodes = _create_partitions(
            tagged_exported_program.graph_module,
            partitioner_result,
            tagged_exported_program,
        )
        for (
            backend_id,
            call_submodule_nodes,
        ) in backend_id_to_call_submodule_nodes.items():
            backend_id_to_method_submodules_map.setdefault(backend_id, {})[
                method_name
            ] = call_submodule_nodes

    # Lower all partitions for one backend at a time, across all methods.
    for (
        backend_id,
        method_to_submodule_nodes,
    ) in backend_id_to_method_submodules_map.items():
        lower_all_submodules_to_backend(
            backend_id,
            method_to_submodule_nodes,
            method_to_tagged_exported_program,
        )

    for method_name in method_to_edge_program.keys():
        if method_name in method_to_tagged_exported_program:
            tagged_exported_program = method_to_tagged_exported_program[method_name]
            partitioned_and_lowered_exported_programs[method_name] = ExportedProgram(
                root=tagged_exported_program.graph_module,
                graph=tagged_exported_program.graph_module.graph,
                graph_signature=tagged_exported_program.graph_signature,
                state_dict=tagged_exported_program.state_dict,
                range_constraints=copy.deepcopy(
                    tagged_exported_program.range_constraints
                ),
                module_call_graph=copy.deepcopy(
                    tagged_exported_program.module_call_graph
                ),
                example_inputs=None,
                constants=tagged_exported_program.constants,
                verifiers=[tagged_exported_program.verifier],
            )
        else:
            # This edge program wasn't partitioned, so we can just return it
            # as is.
            partitioned_and_lowered_exported_programs[method_name] = (
                method_to_edge_program[method_name]
            )

    return partitioned_and_lowered_exported_programs
0 commit comments