diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 6904b4acb3c6..6a86636bd454 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -520,6 +520,7 @@
     "ReflectionPad3dModuleBack_basic",
     # RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
     "SliceOutOfLowerBoundEndIndexModule_basic",
+    "NativeGroupNormModule_basic",
 }
 
 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -954,6 +955,7 @@
     "AtenSymConstrainRange_basic",
     "AtenSymConstrainRangeForSize_basic",
     "Aten_AssertScalar_basic",
+    "NativeGroupNormModule_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 10e06411db31..396d43638a42 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -134,7 +134,9 @@ def invoke_func(*torch_inputs):
     def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
         result: Trace = []
         for item in trace:
-            prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
+            prog: ExportedProgram = torch.export.export(
+                artifact, tuple(item.inputs), strict=True
+            )
             module = fx.export_and_import(
                 prog,
                 output_type=self._output_type,
diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py
index f9aca6bec6a4..f68e230990f3 100644
--- a/python/torch_mlir/fx.py
+++ b/python/torch_mlir/fx.py
@@ -73,6 +73,7 @@ def export_and_import(
     output_type: Union[str, OutputType] = OutputType.RAW,
     fx_importer: Optional[FxImporter] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
+    strict: bool = False,
     experimental_support_mutation: bool = False,
     import_symbolic_shape_expressions: bool = False,
     hooks: Optional[FxImporterHooks] = None,
@@ -94,7 +95,9 @@
     else:
         # pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export
         if version.Version(torch.__version__) >= version.Version("2.2.0"):
-            prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
+            prog = torch.export.export(
+                f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
+            )
         else:
             prog = torch.export.export(f, args, kwargs)
     if decomposition_table is None:
diff --git a/pytorch-hash.txt b/pytorch-hash.txt
index 15cdd0e379c1..5c00b88fd067 100644
--- a/pytorch-hash.txt
+++ b/pytorch-hash.txt
@@ -1 +1 @@
-cdb42bd8cc05bef0ec9b682b274c2acb273f2d62
+3794824ceb12a9d4396eaa17795bf2147fd9e1c3
diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt
index bef7533d621a..2307f6b8f3ff 100644
--- a/pytorch-requirements.txt
+++ b/pytorch-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torch/
 --pre
-torch==2.7.0.dev20250310
+torch==2.8.0.dev20250325
diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py
index 9b1fe163d625..0377d57dc777 100644
--- a/test/python/fx_importer/basic_test.py
+++ b/test/python/fx_importer/basic_test.py
@@ -31,7 +31,7 @@ def run(f):
 @run
 # CHECK-LABEL: test_import_frozen_exported_program
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
-# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
+# CHECK-DAG: %[[a:.+]] = torch.aten.randn
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
 # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
@@ -42,7 +42,6 @@ def run(f):
 #
 # Validate dialect resources exist.
 # CHECK: dialect_resources:
-# CHECK-DAG: torch_tensor_1_4_torch.float32
 # CHECK-DAG: torch_tensor_3_1_torch.float32
 def test_import_frozen_exported_program():
     # Tests the basic structural premises of import_frozen_exported_program,
@@ -210,11 +209,7 @@ def forward(self):
 @run
 # CHECK-LABEL: test_stack_trace
 # CHECK: #loc[[LOC1:.+]] = loc(
-# CHECK: #loc[[LOC2:.+]] = loc(
-# CHECK: #loc[[LOC3:.+]] = loc(
-# CHECK: #loc[[LOC4:.+]] = loc(callsite(#loc[[LOC2]] at #loc[[LOC3]]))
-# CHECK: #loc[[LOC5:.+]] = loc(callsite(#loc[[LOC1]] at #loc[[LOC4]]))
-# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC4]])
+# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC1]])
 def test_stack_trace():
     class Basic(nn.Module):
         def __init__(self):
diff --git a/test/python/fx_importer/symbolic_shape_expr_test.py b/test/python/fx_importer/symbolic_shape_expr_test.py
index 3b8274ccae46..6d1f6dd45150 100644
--- a/test/python/fx_importer/symbolic_shape_expr_test.py
+++ b/test/python/fx_importer/symbolic_shape_expr_test.py
@@ -222,6 +222,7 @@ def forward(self, x):
         SliceTensorDynamicOutput(),
         x,
         dynamic_shapes=dynamic_shapes,
+        strict=True,
         import_symbolic_shape_expressions=True,
     )
     print(m)
diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py
index c0214b761467..9530e9524126 100644
--- a/test/python/fx_importer/v2.3/mutation_import.py
+++ b/test/python/fx_importer/v2.3/mutation_import.py
@@ -31,7 +31,7 @@ def run(f):
 # This doesn't do mutation but ensures that the basics remain functional.
 # CHECK-LABEL: test_import_frozen_exported_program
 # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
-# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
+# CHECK-DAG: %[[a:.+]] = torch.aten.randn
 # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
 # CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
 # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt
index 913fe77a8d0e..1be85215b92c 100644
--- a/torchvision-requirements.txt
+++ b/torchvision-requirements.txt
@@ -1,3 +1,3 @@
 -f https://download.pytorch.org/whl/nightly/cpu/torchvision/
 --pre
-torchvision==0.22.0.dev20250310
+torchvision==0.22.0.dev20250325
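The fx.py hunk above adds a `strict` keyword to torch_mlir.fx.export_and_import and forwards it to torch.export.export. Below is a minimal usage sketch; it is not part of the patch, and the AddOne module and the input shape are illustrative assumptions only.

import torch
import torch.nn as nn
from torch_mlir import fx


# Hypothetical toy module used only to demonstrate the new keyword.
class AddOne(nn.Module):
    def forward(self, x):
        return x + 1.0


# strict=True asks torch.export.export to trace under TorchDynamo's strict
# mode; the updated signature defaults to strict=False when the caller omits it.
m = fx.export_and_import(AddOne(), torch.ones(3, 4), strict=True)
print(m)

Callers that omit `strict` get the False default from the new signature, while the pt1 e2e fx_importer config and the symbolic-shape test in this diff opt into strict=True explicitly.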