Skip to content

Commit

Permalink
[inductor] Simplify remove_kernel_local_buffers (pytorch#139452)
Browse files Browse the repository at this point in the history
I plan to reuse `can_buffer_be_removed_through_fusion` in some heuristics.

Pull Request resolved: pytorch#139452
Approved by: https://github.com/shunting314
ghstack dependencies: pytorch#139364, pytorch#139365, pytorch#139370
  • Loading branch information
jansel authored and pytorchmergebot committed Nov 4, 2024
1 parent 3d633f1 commit b6fb135
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
31 changes: 10 additions & 21 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2528,27 +2528,16 @@ def remove_kernel_local_buffers(self) -> None:
for buf in self.store_buffer_names
if buf in scheduler.name_to_buf
)
names_to_remove = []
for out_buf in self.store_buffer_names:
if out_buf not in scheduler.name_to_buf:
# Aux buffers created during kernel codegen
names_to_remove.append(out_buf)
continue
users = scheduler.name_to_buf[out_buf].users
assert users is not None
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
if users.issubset(fused_node_names):
names_to_remove.append(out_buf)

def remove_filter(n: str) -> bool:
return (
n not in self.must_keep_buffers
and n not in self.args.input_buffers
and n not in scheduler.mutation_renames
and n not in scheduler.mutation_real_name
)

names_to_remove = [*filter(remove_filter, names_to_remove)]
names_to_remove: OrderedSet[str] = OrderedSet()
for name in self.store_buffer_names:
if (
name not in self.must_keep_buffers
and name not in self.args.input_buffers
and scheduler.can_buffer_be_removed_through_fusion(
name, fused_node_names
)
):
names_to_remove.add(name)

for name in names_to_remove:
if name in self.args.inplace_buffers:
Expand Down
13 changes: 13 additions & 0 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3399,6 +3399,19 @@ def get_order(n: torch.fx.Node) -> int:
_, last = max(origins, key=operator.itemgetter(0))
V.graph.wrapper_code.enter_context(last)

def can_buffer_be_removed_through_fusion(
    self, name: str, fused_node_names: OrderedSet[str]
) -> bool:
    """Return True if buffer *name* becomes dead once its consumers are fused.

    A buffer can be dropped when every non-weak user is one of the nodes in
    `fused_node_names` and the buffer is not involved in a mutation rename
    (either direction), since mutated buffers must stay materialized.
    """
    entry = self.name_to_buf.get(name)
    if entry is None:
        # No scheduler entry for this name (e.g. an aux buffer created
        # during kernel codegen) -- nothing we can reason about, keep it.
        return False
    consumers_fused = all(
        user.is_weak or user.get_name() in fused_node_names
        for user in entry.users
    )
    return (
        consumers_fused
        and name not in self.mutation_renames
        and name not in self.mutation_real_name
    )

def codegen(self) -> None:
    """Run code generation for the scheduled graph.

    Thin wrapper that delegates to `_codegen`; `dynamo_timed` presumably
    records the wall time of the pass under the "Scheduler.codegen" label
    -- NOTE(review): confirm against dynamo_timed's definition.
    """
    with dynamo_timed("Scheduler.codegen"):
        return self._codegen()
Expand Down

0 comments on commit b6fb135

Please sign in to comment.