@@ -369,63 +369,47 @@ def sympy_expr_to_semi_affine_expr(
 )
 
 
-@dataclass(frozen=True)
-class SparsityMeta:
-    """
-    Class for keeping track of sparsity meta data.
-
-    NOTE: this will be fully replaced by
-    torch.fx.passes.shape_prop.SparseTensorMetadata
-    """
-
-    layout: torch.layout
-    batch_dim: int
-    sparse_dim: int
-    dense_dim: int
-    blocksize: Optional[Tuple[int, int]]
-    pos_dtype: torch.dtype
-    crd_dtype: torch.dtype
-
-
-def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
-    """Returns sparse tensor encoding for the given sparse layout as string."""
-    assert sparsity is not None
+def sparsity_encoding(t: torch.Tensor) -> str:
+    """Returns sparse tensor encoding for the given tensor as string."""
 
     # Sparse tensors have the form
     #   [ <batch_dimensions> , <sparse_dimensions>, <dense_dimensions> ]
     # which map directly to MLIR types.
-    batch_dim, sparse_dim, dense_dim = (
-        sparsity.batch_dim,
-        sparsity.sparse_dim,
-        sparsity.dense_dim,
+    dim, batch_dim, sparse_dim, dense_dim = (
+        t.ndim,
+        t.ndim - t.sparse_dim() - t.dense_dim(),
+        t.sparse_dim(),
+        t.dense_dim(),
     )
-    dim = batch_dim + sparse_dim + dense_dim
-    assert dim == len(shape)
-    blocksize = sparsity.blocksize
-
     dims = ",".join(f"d{d}" for d in range(dim))
 
-    if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim >= 2 and blocksize is None
+    if t.layout is torch.sparse_coo:
+        assert sparse_dim >= 2
         trail_dim = batch_dim + sparse_dim - 1
         coords = ",".join(
             f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
         )
         sep = "," if sparse_dim > 2 else ""
         lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
-    elif sparsity.layout is torch.sparse_csr:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t._indices().dtype  # supports uncoalesced COO tensors
+    elif t.layout is torch.sparse_csr:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
-    elif sparsity.layout is torch.sparse_csc:
-        assert sparse_dim == 2 and blocksize is None
+        idx_dtype = t.col_indices().dtype
+    elif t.layout is torch.sparse_csc:
+        assert sparse_dim == 2
         lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed"
+        idx_dtype = t.row_indices().dtype
     else:
-        assert sparse_dim == 2 and blocksize is not None
-        if sparsity.layout is torch.sparse_bsr:
+        assert sparse_dim == 2
+        blocksize = t.values().shape[batch_dim + 1 : batch_dim + 3]
+        if t.layout is torch.sparse_bsr:
             i, j = batch_dim, batch_dim + 1
+            idx_dtype = t.col_indices().dtype
         else:
-            assert sparsity.layout is torch.sparse_bsc
+            assert t.layout is torch.sparse_bsc
             j, i = batch_dim, batch_dim + 1
+            idx_dtype = t.row_indices().dtype
         m, n = blocksize
         lvls = (
             f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed,"
@@ -440,8 +424,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
         dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim))
         lvls = f"{lvls},{dense}"
 
-    posw = torch.iinfo(sparsity.pos_dtype).bits
-    crdw = torch.iinfo(sparsity.crd_dtype).bits
+    posw = crdw = torch.iinfo(idx_dtype).bits
     return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>"
 
 
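For reference, a minimal sketch of what the reworked function derives directly from a tensor, assuming this diff is applied and the module is importable as torch_mlir.extras.fx_importer (the import path is an assumption about the repository layout):

import torch
from torch_mlir.extras.fx_importer import sparsity_encoding  # path is an assumption

# 2-D CSR tensor; PyTorch defaults to int64 indices, so pos/crd widths are 64.
t = torch.eye(4).to_sparse_csr()
print(sparsity_encoding(t))
# Expected, per the logic above:
#   #sparse_tensor.encoding<{map=(d0,d1)->(d0:dense,d1:compressed),posWidth=64,crdWidth=64}>

# The blocksize is now read off values(): for a plain BSR tensor, values()
# has shape [nnzb, m, n], so the two dims after the batch dims are the block shape.
b = torch.ones(4, 4).to_sparse_bsr((2, 2))
print(b.values().shape)  # torch.Size([4, 2, 2]) -> blocksize (2, 2)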
@@ -1043,20 +1026,27 @@ def get_vtensor_type(
         shape: torch.Size,
         dtype: torch.dtype,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ):
         """Return IrType for !torch.vtensor with the given shape and dtype"""
         stem = "torch.tensor" if mutable else "torch.vtensor"
         shape_asm = self.format_asm_shape(shape)
         mlir_dtype = str(self.dtype_to_type(dtype))
-        if sparsity is not None:
-            encoding = sparsity_encoding(shape, sparsity)
-            assert encoding is not None
+        if val is not None and val.layout in [
+            torch.sparse_coo,
+            torch.sparse_csr,
+            torch.sparse_csc,
+            torch.sparse_bsr,
+            torch.sparse_bsc,
+        ]:
+            # This is a sparse tensor.
+            encoding = sparsity_encoding(val)
             return IrType.parse(
                 f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>",
                 context=self._c,
             )
+        # This is a dense tensor.
         return IrType.parse(
             f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c
         )
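The new sparse-vs-dense branch in get_vtensor_type reduces to a layout membership test. A self-contained sketch with hypothetical names (SPARSE_LAYOUTS and needs_encoding are illustrations, not part of the diff):

import torch

SPARSE_LAYOUTS = (
    torch.sparse_coo,
    torch.sparse_csr,
    torch.sparse_csc,
    torch.sparse_bsr,
    torch.sparse_bsc,
)

def needs_encoding(val) -> bool:
    # None and strided (dense) values fall through to the plain vtensor form.
    return val is not None and val.layout in SPARSE_LAYOUTS

assert needs_encoding(torch.eye(3).to_sparse())  # COO -> vtensor with encoding
assert not needs_encoding(torch.eye(3))          # strided -> plain vtensor
assert not needs_encoding(None)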
@@ -1065,21 +1055,17 @@ def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrType:
         try:
             tensor_meta = node.meta.get("tensor_meta")
             val = node.meta.get("val")
-            sparsity = node.meta.get("sparsity", None)
         except KeyError as e:
             raise RuntimeError(
                 f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})"
             )
-        return self.value_info_to_type(
-            val, tensor_meta=tensor_meta, sparsity=sparsity, mutable=mutable
-        )
+        return self.value_info_to_type(val, tensor_meta=tensor_meta, mutable=mutable)
 
     def value_info_to_type(
         self,
         val,
         *,
         tensor_meta: Optional[TensorMetadata] = None,
-        sparsity=None,
         mutable: bool = False,
     ):
         if tensor_meta is not None:
@@ -1097,14 +1083,14 @@ def value_info_to_type(
                 )
             else:
                 return self.tensor_metadata_to_type(
-                    tensor_meta, sparsity=sparsity, mutable=mutable
+                    tensor_meta, val=val, mutable=mutable
                 )
         elif val is not None:
             # some nodes with symbolic inputs pass a 'val' attribute rather than
             # tensor_meta
             if isinstance(val, TorchFakeTensor):
                 return self.get_vtensor_type(
-                    val.size(), val.dtype, sparsity=sparsity, mutable=mutable
+                    val.size(), val.dtype, val=val, mutable=mutable
                 )
             elif isinstance(val, list) and all(
                 isinstance(x, TorchFakeTensor) for x in val
@@ -1126,19 +1112,17 @@ def tensor_metadata_to_type(
         self,
         tm: TensorMetadata,
         *,
-        sparsity: Optional[SparsityMeta] = None,
+        val: Optional[torch.Tensor] = None,
         mutable: bool = False,
     ) -> IrType:
         tm_shape = tuple(
             item.node if is_symbolic(item) else item for item in list(tm.shape)
         )
 
-        key = (tm_shape, tm.dtype, sparsity, mutable)
+        key = (tm_shape, tm.dtype, val, mutable)
         t = self._tensor_metadata_cache.get(key)
         if t is None:
-            t = self.get_vtensor_type(
-                tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable
-            )
+            t = self.get_vtensor_type(tm.shape, tm.dtype, val=val, mutable=mutable)
             self._tensor_metadata_cache[key] = t
         return t
 
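Net effect of the remaining hunks: the tensor already stored under node.meta["val"] is threaded through as val in place of the removed "sparsity" meta entry, so sparsity rides along with the value itself. A toy walkthrough with a plain dict standing in for node.meta (hypothetical, for illustration only):

import torch

# Stand-in for torch.fx node metadata; the real importer sees a FakeTensor
# here, but any tensor with a sparse layout illustrates the flow.
meta = {"val": torch.eye(3).to_sparse_csr()}

val = meta.get("val")
tensor_meta = meta.get("tensor_meta")  # None here, so the `val` path is taken

# Mirrors value_info_to_type: shape and dtype come from val, and val itself
# carries the layout that get_vtensor_type inspects to pick an encoding.
print(val.size(), val.dtype, val.layout)
# torch.Size([3, 3]) torch.float32 torch.sparse_csr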