Skip to content

build: manually update PyTorch version #4102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@
"ReflectionPad3dModuleBack_basic",
# RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
"SliceOutOfLowerBoundEndIndexModule_basic",
"NativeGroupNormModule_basic",
}

FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
Expand Down Expand Up @@ -954,6 +955,7 @@
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
"NativeGroupNormModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def invoke_func(*torch_inputs):
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
prog: ExportedProgram = torch.export.export(
artifact, tuple(item.inputs), strict=True
)
module = fx.export_and_import(
prog,
output_type=self._output_type,
Expand Down
5 changes: 4 additions & 1 deletion python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def export_and_import(
output_type: Union[str, OutputType] = OutputType.RAW,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
strict: bool = False,
experimental_support_mutation: bool = False,
import_symbolic_shape_expressions: bool = False,
hooks: Optional[FxImporterHooks] = None,
Expand All @@ -94,7 +95,9 @@ def export_and_import(
else:
# pytorch 2.1 or lower doesn't have `dyanmic_shapes` keyword argument in torch.export
if version.Version(torch.__version__) >= version.Version("2.2.0"):
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
prog = torch.export.export(
f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
)
else:
prog = torch.export.export(f, args, kwargs)
if decomposition_table is None:
Expand Down
2 changes: 1 addition & 1 deletion pytorch-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
cdb42bd8cc05bef0ec9b682b274c2acb273f2d62
3794824ceb12a9d4396eaa17795bf2147fd9e1c3
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch/
--pre
torch==2.7.0.dev20250310
torch==2.8.0.dev20250325
9 changes: 2 additions & 7 deletions test/python/fx_importer/basic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def run(f):
@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.aten.randn
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
Expand All @@ -42,7 +42,6 @@ def run(f):
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
def test_import_frozen_exported_program():
# Tests the basic structural premises of import_frozen_exported_program,
Expand Down Expand Up @@ -210,11 +209,7 @@ def forward(self):
@run
# CHECK-LABEL: test_stack_trace
# CHECK: #loc[[LOC1:.+]] = loc(
# CHECK: #loc[[LOC2:.+]] = loc(
# CHECK: #loc[[LOC3:.+]] = loc(
# CHECK: #loc[[LOC4:.+]] = loc(callsite(#loc[[LOC2]] at #loc[[LOC3]]))
# CHECK: #loc[[LOC5:.+]] = loc(callsite(#loc[[LOC1]] at #loc[[LOC4]]))
# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC4]])
# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC1]])
def test_stack_trace():
class Basic(nn.Module):
def __init__(self):
Expand Down
1 change: 1 addition & 0 deletions test/python/fx_importer/symbolic_shape_expr_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def forward(self, x):
SliceTensorDynamicOutput(),
x,
dynamic_shapes=dynamic_shapes,
strict=True,
import_symbolic_shape_expressions=True,
)
print(m)
Expand Down
2 changes: 1 addition & 1 deletion test/python/fx_importer/v2.3/mutation_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def run(f):
# This doesn't do mutation but ensures that the basics remain functional.
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.aten.randn
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
Expand Down
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torchvision/
--pre
torchvision==0.22.0.dev20250310
torchvision==0.22.0.dev20250325
Loading