Skip to content

build: manually update PyTorch version #4102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
@@ -520,6 +520,7 @@
"ReflectionPad3dModuleBack_basic",
# RuntimeError: Unknown function SliceOutOfLowerBoundEndIndexModule
"SliceOutOfLowerBoundEndIndexModule_basic",
"NativeGroupNormModule_basic",
}

FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -954,6 +955,7 @@
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
"NativeGroupNormModule_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Original file line number Diff line number Diff line change
@@ -134,7 +134,9 @@ def invoke_func(*torch_inputs):
def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
result: Trace = []
for item in trace:
prog: ExportedProgram = torch.export.export(artifact, tuple(item.inputs))
prog: ExportedProgram = torch.export.export(
artifact, tuple(item.inputs), strict=True
)
module = fx.export_and_import(
prog,
output_type=self._output_type,
5 changes: 4 additions & 1 deletion python/torch_mlir/fx.py
Original file line number Diff line number Diff line change
@@ -73,6 +73,7 @@ def export_and_import(
output_type: Union[str, OutputType] = OutputType.RAW,
fx_importer: Optional[FxImporter] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
strict: bool = False,
experimental_support_mutation: bool = False,
import_symbolic_shape_expressions: bool = False,
hooks: Optional[FxImporterHooks] = None,
@@ -94,7 +95,9 @@ def export_and_import(
else:
# pytorch 2.1 or lower doesn't have `dynamic_shapes` keyword argument in torch.export
if version.Version(torch.__version__) >= version.Version("2.2.0"):
prog = torch.export.export(f, args, kwargs, dynamic_shapes=dynamic_shapes)
prog = torch.export.export(
f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=strict
)
else:
prog = torch.export.export(f, args, kwargs)
if decomposition_table is None:
2 changes: 1 addition & 1 deletion pytorch-hash.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
cdb42bd8cc05bef0ec9b682b274c2acb273f2d62
3794824ceb12a9d4396eaa17795bf2147fd9e1c3
2 changes: 1 addition & 1 deletion pytorch-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch/
--pre
torch==2.7.0.dev20250310
torch==2.8.0.dev20250325
9 changes: 2 additions & 7 deletions test/python/fx_importer/basic_test.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ def run(f):
@run
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.aten.randn
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
@@ -42,7 +42,6 @@ def run(f):
#
# Validate dialect resources exist.
# CHECK: dialect_resources:
# CHECK-DAG: torch_tensor_1_4_torch.float32
# CHECK-DAG: torch_tensor_3_1_torch.float32
def test_import_frozen_exported_program():
# Tests the basic structural premises of import_frozen_exported_program,
@@ -210,11 +209,7 @@ def forward(self):
@run
# CHECK-LABEL: test_stack_trace
# CHECK: #loc[[LOC1:.+]] = loc(
# CHECK: #loc[[LOC2:.+]] = loc(
# CHECK: #loc[[LOC3:.+]] = loc(
# CHECK: #loc[[LOC4:.+]] = loc(callsite(#loc[[LOC2]] at #loc[[LOC3]]))
# CHECK: #loc[[LOC5:.+]] = loc(callsite(#loc[[LOC1]] at #loc[[LOC4]]))
# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC4]])
# CHECK: %{{.+}} = torch.aten.add.Tensor {{.+}} loc(#loc[[LOC1]])
def test_stack_trace():
class Basic(nn.Module):
def __init__(self):
1 change: 1 addition & 0 deletions test/python/fx_importer/symbolic_shape_expr_test.py
Original file line number Diff line number Diff line change
@@ -222,6 +222,7 @@ def forward(self, x):
SliceTensorDynamicOutput(),
x,
dynamic_shapes=dynamic_shapes,
strict=True,
import_symbolic_shape_expressions=True,
)
print(m)
2 changes: 1 addition & 1 deletion test/python/fx_importer/v2.3/mutation_import.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ def run(f):
# This doesn't do mutation but ensures that the basics remain functional.
# CHECK-LABEL: test_import_frozen_exported_program
# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_1_4_torch.float32> : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32>
# CHECK-DAG: %[[a:.+]] = torch.aten.randn
# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource<torch_tensor_3_1_torch.float32> : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32>
# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32>
# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]]
2 changes: 1 addition & 1 deletion torchvision-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/nightly/cpu/torchvision/
--pre
torchvision==0.22.0.dev20250310
torchvision==0.22.0.dev20250325