Skip to content

Commit

Permalink
Deprecate InjectMlirDebuginfoPass for odml_torch default migration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 711867888
  • Loading branch information
chunnienc authored and copybara-github committed Jan 3, 2025
1 parent 7407150 commit 7e68dce
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 99 deletions.
40 changes: 22 additions & 18 deletions ai_edge_torch/_convert/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
# ==============================================================================

import logging
import os
from typing import Any, Literal, Optional, Union

import ai_edge_torch
from ai_edge_torch import fx_pass_base
from ai_edge_torch import lowertools
from ai_edge_torch import model
Expand All @@ -26,30 +26,34 @@
from ai_edge_torch.quantize import quant_config as qcfg
import torch

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


def _run_convert_passes(
    exported_program: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
  """Runs the conversion fx-pass pipeline on an exported program.

  First applies the generative-model passes, then the general conversion
  passes (composite building, layout optimization, canonicalization,
  non-user-output removal).

  Args:
    exported_program: The `torch.export.ExportedProgram` to transform.

  Returns:
    The transformed `ExportedProgram` after all conversion passes have run.
  """
  # NOTE(review): the scraped diff contained both the pre- and post-change
  # bodies back to back, which would have executed the pipeline twice and
  # always injected MLIR debuginfo. Only the post-change body is kept here.
  exported_program = generative_fx_passes.run_generative_passes(
      exported_program
  )

  passes = [
      fx_passes.BuildInterpolateCompositePass(),
      fx_passes.CanonicalizePass(),
      fx_passes.OptimizeLayoutTransposesPass(),
      fx_passes.CanonicalizePass(),
      fx_passes.BuildAtenCompositePass(),
      fx_passes.CanonicalizePass(),
      fx_passes.RemoveNonUserOutputsPass(),
      fx_passes.CanonicalizePass(),
  ]

  # Debuginfo is not injected automatically by odml_torch. Only inject
  # debuginfo via fx pass when using torch_xla.
  if ai_edge_torch.config.use_torch_xla:
    passes += [
        fx_passes.InjectMlirDebuginfoPass(),
        fx_passes.CanonicalizePass(),
    ]

  exported_program = fx_pass_base.run_passes(exported_program, passes)
  return exported_program


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def debuginfo_writer(*args, **kwargs):


class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
"""DEPRECATED: Debuginfo is injected automatically by odml_torch."""

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
Expand Down

This file was deleted.

3 changes: 3 additions & 0 deletions ai_edge_torch/lowertools/torch_xla_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
# https://github.com/google-ai-edge/ai-edge-torch/issues/326
os.environ["PJRT_DEVICE"] = "CPU"

os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"


from ai_edge_torch import model
from ai_edge_torch._convert import conversion_utils
from ai_edge_torch._convert import signature as signature_module
Expand Down

0 comments on commit 7e68dce

Please sign in to comment.