Commit f72770a

[torch-mlir][sparse] replace ad-hoc mechanism with proper FX export (llvm#3648)
Now that the PyDev feature request pytorch/pytorch#117188 has been completed, we can remove all the ad-hoc code that propagates sparsity metadata and replace it with the built-in PyDev metadata for sparse tensors. This removes a lot of code and also ensures sparsity is consistent with the torch.sparse package for all cases.
1 parent 0a86deb commit f72770a
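
For context, a minimal sketch of the behavior this change relies on (assuming torch >= 2.5, per the lit config added below; the module and tensor here are illustrative, not part of the commit): with pytorch/pytorch#117188 in place, torch.export keeps the sparse layout on the fake tensor stored in node.meta["val"], which is what the importer now reads instead of the removed ad-hoc node.meta["sparsity"] annotation.

import torch


class Net(torch.nn.Module):
    def forward(self, x):
        return x.sum()


# Any layout from torch.sparse should carry through the export.
a = torch.tensor([[0.0, 1.0], [2.0, 0.0]]).to_sparse_csr()
prog = torch.export.export(Net(), (a,))

for node in prog.graph.nodes:
    if node.op == "placeholder":
        val = node.meta["val"]  # fake tensor that retains the sparse layout
        print(val.layout, val.sparse_dim(), val.dense_dim(), val.dtype)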

File tree

3 files changed: +58 -190 lines

python/torch_mlir/extras/fx_importer.py (+40 -56)
@@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
     )


-@dataclass(frozen=True)
-class SparsityMeta:
-    """
-    Class for keeping track of sparsity meta data.
-
-    NOTE: this will be fully replaced by
-    torch.fx.passes.shape_prop.SparseTensorMetadata
-    """
-
-    layout: torch.layout
-    batch_dim: int
-    sparse_dim: int
-    dense_dim: int
-    blocksize: Optional[Tuple[int, int]]
-    pos_dtype: torch.dtype
-    crd_dtype: torch.dtype
-
-
-def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
-    """Returns sparse tensor encoding for the given sparse layout as string."""
-    assert sparsity is not None
+def sparsity_encoding(t: torch.Tensor) -> str:
+    """Returns sparse tensor encoding for the given tensor as string."""

     # Sparse tensors have the form
     #   [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
     # which map directly to MLIR types.
-    batch_dim, sparse_dim, dense_dim = (
-        sparsity.batch_dim,
-        sparsity.sparse_dim,
-        sparsity.dense_dim,
+    dim, batch_dim, sparse_dim, dense_dim = (
+        t.ndim,
+        t.ndim - t.sparse_dim() - t.dense_dim(),
+        t.sparse_dim(),
+        t.dense_dim(),
     )
-    dim = batch_dim + sparse_dim + dense_dim
-    assert dim == len(shape)
-    blocksize = sparsity.blocksize
-
     dims = ",".join(f"d{d}" for d in range(dim))

-    if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim >= 2 and blocksize is None
+    if t.layout is torch.sparse_coo:
+        assert sparse_dim >= 2
         trail_dim = batch_dim + sparse_dim - 1
         coords = ",".join(
             f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
         )
         sep = "," if sparse_dim > 2 else ""
         lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
-    elif sparsity.layout is torch.sparse_csr:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t._indices().dtype  # supports uncoalesced COO tensors
+    elif t.layout is torch.sparse_csr:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
-    elif sparsity.layout is torch.sparse_csc:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t.col_indices().dtype
+    elif t.layout is torch.sparse_csc:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
+        idx_dtype = t.row_indices().dtype
     else:
-        assert sparse_dim == 2 and blocksize is not None
-        if sparsity.layout is torch.sparse_bsr:
+        assert sparse_dim == 2
+        blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
+        if t.layout is torch.sparse_bsr:
             i, j = batch_dim, batch_dim + 1
+            idx_dtype = t.col_indices().dtype
         else:
-            assert sparsity.layout is torch.sparse_bsc
+            assert t.layout is torch.sparse_bsc
             j, i = batch_dim, batch_dim + 1
+            idx_dtype = t.row_indices().dtype
         m, n = blocksize
         lvls = (
             f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
@@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
         dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
         lvls = f"{lvls},{dense}"

-    posw = torch.iinfo(sparsity.pos_dtype).bits
-    crdw = torch.iinfo(sparsity.crd_dtype).bits
+    posw = crdw = torch.iinfo(idx_dtype).bits
     return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"


@@ -1043,20 +1026,27 @@ def get_vtensor_type(
         shape: torch.Size,
         dtype: torch.dtype,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ):
        """Return IrType for !torch.vtensor with the given shape and dtype"""
         stem = "torch.tensor" if mutable else "torch.vtensor"
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
-        if sparsity is not None:
-            encoding = sparsity_encoding(shape, sparsity)
-            assert encoding is not None
+        if val is not None and val.layout in [
+            torch.sparse_coo,
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        ]:
+            # This is a sparse tensor.
+            encoding = sparsity_encoding(val)
             return IrType.parse(
                 f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
                 context=self._c,
             )
+        # This is a dense tensor.
         return IrType.parse(
             f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
         )
@@ -1065,21 +1055,17 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrT
         try:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
-            sparsity = node.meta.get("sparsity", None)
         except KeyError as e:
             raise RuntimeError(
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
-        return self.value_info_to_type(
-            val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
-        )
+        return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)

     def value_info_to_type(
         self,
         val,
         *,
         tensor_meta: Optional[TensorMetadata] = None,
-        sparsity=None,
         mutable: bool = False,
     ):
         if tensor_meta is not None:
@@ -1097,14 +1083,14 @@ def value_info_to_type(
                 )
             else:
                 return self.tensor_metadata_to_type(
-                    tensor_meta, sparsity=sparsity, mutable=mutable
+                    tensor_meta, val=val, mutable=mutable
                 )
         elif val is not None:
             # some nodes with symbolic inputs pass a 'val' attribute rather than
             # tensor_meta
             if isinstance(val, TorchFakeTensor):
                 return self.get_vtensor_type(
-                    val.size(), val.dtype, sparsity=sparsity, mutable=mutable
+                    val.size(), val.dtype, val=val, mutable=mutable
                 )
             elif isinstance(val, list) and all(
                 isinstance(x, TorchFakeTensor) for x in val
@@ -1126,19 +1112,17 @@ def tensor_metadata_to_type(
         self,
         tm: TensorMetadata,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )

-        key = (tm_shape, tm.dtype, sparsity, mutable)
+        key = (tm_shape, tm.dtype, val, mutable)
         t = self._tensor_metadata_cache.get(key)
         if t is None:
-            t = self.get_vtensor_type(
-                tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
-            )
+            t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
             self._tensor_metadata_cache[key] = t
         return t
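
For reference, a small sketch (not part of the diff) of the encoding string the new sparsity_encoding helper would produce for a plain CSR tensor, assuming the helper is imported directly from the module above and that the tensor uses the default int64 indices:

import torch

from torch_mlir.extras.fx_importer import sparsity_encoding

# 4x4 CSR tensor: dim=2, batch_dim=0, sparse_dim=2, dense_dim=0.
a = torch.eye(4).to_sparse_csr()
print(sparsity_encoding(a))
# Expected form, following the branches above:
#   #sparse_tensor.encoding<{map=(d0,d1)->(d0:dense,d1:compressed),posWidth=64,crdWidth=64}>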

New file (+10 -0)

@@ -0,0 +1,10 @@
+config.unsupported = True
+
+try:
+    import torch
+    if "2.5.0" <= str(torch.__version__):
+        print("Enabling sparsity propagation tests")
+        config.unsupported = False
+
+except ModuleNotFoundError:
+    ...

test/python/fx_importer/sparse_test.py → test/python/fx_importer/sparsity/sparse_test.py (+8 -134)
@@ -8,13 +8,11 @@
 from typing import Any, Callable, Optional, Tuple, Dict

 import torch
-import torch.export
 import torch.nn as nn
 import numpy as np

 from torch_mlir.extras.fx_decomp_util import get_decomposition_table
 from torch_mlir.extras.fx_importer import FxImporter
-from torch_mlir.extras.fx_importer import SparsityMeta
 from torch_mlir import ir
 from torch_mlir.dialects import torch as torch_d
 from torch_mlir.compiler_utils import run_pipeline_with_repro_report
@@ -23,139 +21,15 @@
 )


-# All sparse layouts currently supported in torch.sparse.
-SPARSE_LAYOUTS = [
-    torch.sparse_coo,
-    torch.sparse_csr,
-    torch.sparse_csc,
-    torch.sparse_bsr,
-    torch.sparse_bsc,
-]
-
-
-def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
-    """
-    Returns a meta data tuple for the given sparse tensor.
-
-    NOTE: this will be fully replaced by fx graph SparseTensorMetadata
-    """
-    sparse_dim = a.sparse_dim()
-    dense_dim = a.dense_dim()
-    batch_dim = a.ndim - dense_dim - sparse_dim
-    blocksize = None
-    if a.layout is torch.sparse_coo:
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a._indices().dtype,
-            a._indices().dtype,
-        )
-    elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr:
-        if a.layout is torch.sparse_bsr:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.crow_indices().dtype,
-            a.col_indices().dtype,
-        )
-    elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc:
-        if a.layout is torch.sparse_bsc:
-            blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3]
-        return SparsityMeta(
-            a.layout,
-            batch_dim,
-            sparse_dim,
-            dense_dim,
-            blocksize,
-            a.ccol_indices().dtype,
-            a.row_indices().dtype,
-        )
-    else:
-        raise RuntimeError(f"Unsupported sparse layout for {a}")
-
-
-def sparse_export(
-    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> torch.export.ExportedProgram:
-    """
-    This is a ***temporary*** wrapper around `torch.export.export`
-    that eventually should be removed and simply replaced by the
-    standard API for exporting traced graphs.
-
-    But until issue
-
-      https://github.com/pytorch/pytorch/pull/117907
-
-    is addressed, this wrapper provides support for the sparse
-    tensor types by first converting all operands to dense tensors,
-    building the traced graph as for the dense case, then annotating
-    sparse parameters with their actual sparse layout attributes,
-    followed by some simple propagation rules. This temporary solution
-    accelerates testing torch-mlir with PyTorch sparse tensors until
-    the issue is resolved upstream.
-    """
-    # Convert all arguments to dense.
-    dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args)
-    mask = [a.layout in SPARSE_LAYOUTS for a in args]
-    # Build the regular FX traced graph with only dense arguments
-    # (the current version would crash otherwise, see issue above).
-    prog = torch.export.export(f, dargs, kwargs)
-    decomposition_table = get_decomposition_table()
-    if decomposition_table:
-        prog = prog.run_decompositions(decomposition_table)
-    # Annotate sparse arguments in the graph and apply some very
-    # basic propagation rules for sparsity.
-    specs = prog.graph_signature.input_specs
-    alen = len(specs)
-    k = 0
-    for i, node in enumerate(prog.graph.nodes):
-        if node.op == "placeholder":
-            # Argument.
-            spec = specs[i]
-            if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
-                if mask[k]:
-                    node.meta["sparsity"] = sparse_metadata(args[k])
-                k = k + 1
-        elif node.op == "call_function":
-            opname = node.target._schema.name.split("::")[1]
-            # Zero preserving elt-wise unary op.
-            if opname in {"abs", "neg", "relu", "sin"}:
-                node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "_to_sparse" or opname == "to_sparse":
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            # TODO: Uncomment this to hack sparsity into the network.
-            # elif opname == "_to_dense" or opname == "to_dense":
-            #     # hack (assumes we never really want the to_dense for now)
-            #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif opname == "select" and node.args[0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
-                )
-            elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
-                dim = len(node.meta.get("val").shape)
-                node.meta["sparsity"] = SparsityMeta(
-                    torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
-                )
-    return prog
-
-
 def export_and_import(f, *args, **kwargs):
-    """This method implements Stella's importer, stripped down to essentials."""
+    """A FX graph importer, stripped down to essentials."""
     context = ir.Context()
     torch_d.register_dialect(context)
     fx_importer = FxImporter(context=context)
-    prog = sparse_export(f, args, kwargs)
+    prog = torch.export.export(f, args, kwargs)
+    decomposition_table = get_decomposition_table()
+    if decomposition_table:
+        prog = prog.run_decompositions(decomposition_table)
     fx_importer.import_frozen_program(prog)
     return fx_importer.module
16135

@@ -175,8 +49,7 @@ def sparse_jit(f, *args, **kwargs):
         enable_ir_printing=False,
     )
     # Compile with reference Linalg backend.
-    # TODO: runtime verification currently fails with 'rank mismatch' on
-    # memref.cast. Need to fix the IR first.
+    # TODO: runtime verification fails with 'rank mismatch' on memref.cast
     backend = RefBackendLinalgOnTensorsBackend(generate_runtime_verification=False)
     compiled = backend.compile(module)
     invoker = backend.load(compiled)
@@ -218,7 +91,8 @@ def sparse_jit(f, *args, **kwargs):


 def run(f):
-    print(f"{f.__name__}")
+    # Print test name and torch version (for debugging).
+    print(f"{f.__name__} ({torch.__version__})")
     print("-" * len(f.__name__))
     f()
     print()