Skip to content

Commit

Permalink
[inductor] sympy.Integer([01]) -> sympy.S.(Zero|One) (pytorch#139523)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#139523
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#139364, pytorch#139365, pytorch#139370, pytorch#139452
  • Loading branch information
jansel authored and pytorchmergebot committed Nov 4, 2024
1 parent b6fb135 commit ed30fa7
Show file tree
Hide file tree
Showing 19 changed files with 77 additions and 85 deletions.
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class TensorArg:
name: str
buffer: str
dtype: torch.dtype
offset: sympy.Expr = sympy.Integer(0) # c++ only
offset: sympy.Expr = sympy.S.Zero # c++ only
alias_of: Optional[str] = None # halide only


Expand Down
6 changes: 3 additions & 3 deletions torch/_inductor/codegen/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def stride_at(index: sympy.Expr, var: sympy.Symbol):
# see test_torchinductor_dynamic_shapes.py::test_full_boolean_dynamic_shapes_cpu
# which has tmp0 = ops.index_expr(s0 >= 1024, torch.bool) and fails below calculation.
# in this case, there is no dependencies between index and var.
return sympy.Integer(0)
return sympy.S.Zero
replacement = {var: var + 1}
new_index = sympy_subs(index, replacement) # type: ignore[arg-type]
return sympy.simplify(new_index - index)
Expand Down Expand Up @@ -4711,8 +4711,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class LoopLevel:
var: Optional[sympy.Expr] = None
size: Optional[sympy.Expr] = None
offset: sympy.Expr = sympy.Integer(0)
steps: sympy.Expr = sympy.Integer(1)
offset: sympy.Expr = sympy.S.Zero
steps: sympy.Expr = sympy.S.One
parallel: int = 0
simd_omp: bool = False
simd_vec: bool = False
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_template_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def store_pointwise_nodes(
for i, sz in enumerate(var_sizes[0])
}
if not offsets:
offsets = [sympy.Integer(0)] * len(var_sizes[0])
offsets = [sympy.S.Zero] * len(var_sizes[0])
if not reindexers:
reindexers = [None] * len(nodes)
assert len(offsets) == len(var_sizes[0])
Expand Down
18 changes: 9 additions & 9 deletions torch/_inductor/codegen/halide.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def visit_floor_div(base, divisor):
if not nodes:
nodes.append(tree.lookup(1, tree.numel))
handled_count = 0
divisor = sympy.Integer(1)
divisor = sympy.S.One
added_sym_size = []
# decide on a minimal set of symbols and put them in self.halide_vars
while handled_count < len(nodes) and not eq(tree.numel, divisor):
Expand Down Expand Up @@ -846,7 +846,7 @@ def visit_floor_div(base, divisor):
idx += 1
divisor *= size
length = 1
expr = sympy.Integer(0)
expr = sympy.S.Zero
while not eq(node.length, length):
sym, size = added_sym_size[idx]
idx += 1
Expand All @@ -855,8 +855,8 @@ def visit_floor_div(base, divisor):
self.index_replacements[node.symbol()] = expr
except IndexError:
assert had_fallback
full_index = sympy.Integer(0)
stride = sympy.Integer(1)
full_index = sympy.S.Zero
stride = sympy.S.One
for sym, size in added_sym_size:
full_index += stride * sym
stride *= size
Expand Down Expand Up @@ -937,8 +937,8 @@ def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
), sym

# group the expression by variables used
offset = sympy.Integer(0)
split_expr = {s: sympy.Integer(0) for s in symbols}
offset = sympy.S.Zero
split_expr = {s: sympy.S.Zero for s in symbols}
split_failed: List[Tuple[List[sympy.Symbol], sympy.Expr]] = []
index = sympy.expand(self.rename_indexing(index))
for part in index.args if isinstance(index, sympy.Add) else [index]:
Expand Down Expand Up @@ -972,7 +972,7 @@ def expr_to_dimension(expr, syms):
length = sympy.simplify(
sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
)
stride = sympy.Integer(1)
stride = sympy.S.One
if isinstance(expr, sympy.Mul):
for term in expr.args:
if isinstance(term, sympy.Integer):
Expand All @@ -994,11 +994,11 @@ def expr_to_dimension(expr, syms):
if not dims: # scalar load/store
if self.has_indirect_indexing:
# workaround https://github.com/halide/Halide/issues/8338
dims.append(DimensionInfo(sympy.Integer(0), 1, 1))
dims.append(DimensionInfo(sympy.S.Zero, 1, 1))
elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
# Halide assumes dimension 0 is stride == 1, so add a dummy dimension
dims.insert(
0, DimensionInfo(sympy.Integer(0), 1 if is_store else dims[0].stride, 1)
0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1)
)

if dims and not is_store:
Expand Down
16 changes: 8 additions & 8 deletions torch/_inductor/codegen/simd.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def __init__(
prefix: str,
*,
kernel: SIMDKernel,
divisor=sympy.Integer(1),
length=sympy.Integer(1),
divisor=sympy.S.One,
length=sympy.S.One,
root: IterationRangesRoot,
) -> None:
super().__init__()
Expand Down Expand Up @@ -205,7 +205,7 @@ def lookup(self, divisor, length):
return self.nodes[expr]

def construct_entries(self, lengths: List[sympy.Expr]):
divisor = sympy.Integer(1)
divisor = sympy.S.One
itervars = []
for length in reversed(lengths):
itervars.append(self.lookup(divisor, length))
Expand All @@ -224,7 +224,7 @@ def vars_and_sizes(self, index: sympy.Expr):
x.divisor, fallback=config.unbacked_symint_fallback
)
)
divisor = sympy.Integer(1)
divisor = sympy.S.One
index_vars = []
sizes = []

Expand Down Expand Up @@ -481,7 +481,7 @@ def combine_modular_indexing_pairs(self, index):
new_index,
{
tree_node.root.index_sym(): tree_node.root.lookup(
sympy.Integer(1), tree_node.root.numel
sympy.S.One, tree_node.root.numel
).symbol()
},
)
Expand Down Expand Up @@ -572,7 +572,7 @@ def getter(flat_vars):
return_getters = []
for size in length_group:
if sv.statically_known_equals(size, 1): # type: ignore[arg-type]
return_getters.append(lambda _: sympy.Integer(0))
return_getters.append(lambda _: sympy.S.Zero)
continue

while current_group < len(remaining) and sv.statically_known_equals(
Expand Down Expand Up @@ -635,7 +635,7 @@ def split_and_set_ranges(self, lengths: List[List[sympy.Expr]]):
"""
groups = [rt.numel for rt in self.range_trees]
if not self.inside_reduction:
groups[-1] = sympy.Integer(1)
groups[-1] = sympy.S.One

if len(lengths) == len(self.range_trees) and all(
V.graph.sizevars.simplify(sympy_product(x) - g) == 0
Expand Down Expand Up @@ -1564,7 +1564,7 @@ def candidate_tilings(node):
return tilings

@classmethod
def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.Integer(1)):
def select_tiling(cls, node_schedule, numel, reduction_numel=sympy.S.One):
"""
Heuristics to decide how to tile kernels.
Currently, we tile based on stride-1 dimensions.
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/simd_kernel_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self,
node_schedule: List[NodeScheduleEntry],
numel: sympy.Expr,
reduction_numel: sympy.Expr = sympy.Integer(1),
reduction_numel: sympy.Expr = sympy.S.One,
):
self.node_schedule = node_schedule
self.numel = V.graph.sizevars.simplify(numel) # numel excludes reduction_numel
Expand Down
22 changes: 9 additions & 13 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def codegen_broadcast_and_reshape(

# Reshape to add singletons.
pre_broadcast_shape = [
sympy.Integer(1) if is_broadcasting else dim
sympy.S.One if is_broadcasting else dim
for dim, is_broadcasting in zip(
self.broadcast_shape, self.broadcasting_dims
)
Expand Down Expand Up @@ -342,7 +342,7 @@ def remove_dims(it):
and V.kernel.numels[-1] != 1
):
# Need to expand rank by 1 to match rank when self.inside_reduction=True
final_shape.append(sympy.Integer(1))
final_shape.append(sympy.S.One)

return BlockPtrOptions(
params=params,
Expand Down Expand Up @@ -375,9 +375,7 @@ def format(self, name: str, roffset=True) -> str:
f = V.kernel.index_to_str
offsets = [*self.offsets]
if not roffset:
offsets = [
self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets
]
offsets = [self.replace_roffset(offset, sympy.S.Zero) for offset in offsets]
args = [
(
f"{name} + ({f(self.constant_offset)})"
Expand Down Expand Up @@ -408,9 +406,7 @@ def boundary_check(self) -> List[int]:
idx
for idx in range(len(self.shape))
if (
not sizevars.statically_known_equals(
self.strides[idx], sympy.Integer(0)
)
not sizevars.statically_known_equals(self.strides[idx], sympy.S.Zero)
and not sizevars.statically_known_multiple_of(
self.shape[idx], self.block_shape[idx]
)
Expand All @@ -437,7 +433,7 @@ def advance_roffset(self):
advance = [
(
self.replace_roffset(offset, rblock)
- self.replace_roffset(offset, sympy.Integer(0))
- self.replace_roffset(offset, sympy.S.Zero)
)
for offset in self.offsets
]
Expand Down Expand Up @@ -1655,7 +1651,7 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
Compute the cumulative size of each dimension's slice.
This proceeds from the last dim up to the second.
"""
numels = [sympy.Integer(1)]
numels = [sympy.S.One]
for dim in dims[:0:-1]:
numel = dim * numels[0]
numels.insert(0, numel)
Expand All @@ -1680,10 +1676,10 @@ def get_slice_numels(dims: List[Any]) -> List[Any]:
# Provide default values for unmatched dims and strides.
for dim in dims[1:]:
if dim not in match:
match[dim] = sympy.Integer(1)
match[dim] = sympy.S.One
for stride in strides[1:]:
if stride not in match:
match[stride] = sympy.Integer(0)
match[stride] = sympy.S.Zero

sizevars = V.graph.sizevars

Expand Down Expand Up @@ -1786,7 +1782,7 @@ def match_block_pointer() -> Optional[BlockPtrOptions]:
# For example xindex * 5 + rindex * 3 is partitioned to
# (xindex * 5, rindex * 3).
symbol = tree.symbol()
subexpr = sympy.Integer(0) + sum(
subexpr = sympy.S.Zero + sum(
expr for expr in index_terms if symbol in expr.free_symbols
)

Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_numel(self) -> sympy.Expr:
numel = V.graph.get_numel(self.name)
else:
vars: OrderedSet[sympy.Basic] = OrderedSet(self.index.free_symbols)
numel = sympy.Integer(1)
numel = sympy.S.One
for var, size in zip(self.var_names, self.size):
if var in vars:
numel = numel * size
Expand Down Expand Up @@ -328,7 +328,7 @@ def index(self):
raise NotImplementedError("WeakDep does not have an index")

def get_numel(self) -> sympy.Expr:
return sympy.Integer(1)
return sympy.S.One

def rename(self, renames: Dict[str, str]) -> "WeakDep":
if self.name in renames:
Expand Down
4 changes: 2 additions & 2 deletions torch/_inductor/index_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
if not is_integer_dtype(result_type):
return NotImplemented

result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr)
return TypedExpr(result_expr, result_type)

@staticmethod
Expand All @@ -152,7 +152,7 @@ def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
x_expr.is_nonnegative is not None
and x_expr.is_nonnegative == y_expr.is_positive
):
result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
result_expr = ModularIndexing(x.expr, sympy.S.One, y.expr)
return TypedExpr(result_expr, result_type)
return NotImplemented

Expand Down
Loading

0 comments on commit ed30fa7

Please sign in to comment.