Skip to content

Commit

Permalink
[inductor] Move remove_kernel_local_buffers to Kernel (pytorch#139370)
Browse files Browse the repository at this point in the history
This method mutates the kernel, so it fits better in that class.

Pull Request resolved: pytorch#139370
Approved by: https://github.com/shunting314
ghstack dependencies: pytorch#139364, pytorch#139365
  • Loading branch information
jansel authored and pytorchmergebot committed Nov 4, 2024
1 parent 66d5e24 commit 3d633f1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 64 deletions.
66 changes: 63 additions & 3 deletions torch/_inductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@


# Artifact logger obtained for this module under the artifact name "schedule".
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
# Regular module-level logger, keyed by this module's name.
log = logging.getLogger(__name__)


def data_type_logger(msg):
Expand Down Expand Up @@ -2508,13 +2509,72 @@ def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
return self

def __exit__(self, exc_type, exc_val, exc_tb):
# Context-manager exit hook: first let the kernel drop the buffers it
# created that are not needed outside of it, then hand the (unchanged)
# exception triple to the parent class's __exit__.
self.remove_kernel_local_buffers()
super().__exit__(exc_type, exc_val, exc_tb)

def remove_kernel_local_buffers(self) -> None:
"""
Any buffers that are both created and have a last use in the
same kernel can be removed.

Note that V.graph.scheduler can be None when generating code for
triton template kernels; in that case this method returns without
removing anything.
"""
# NOTE(review): the next three lines look like the pre-change body of
# __exit__ carried over by the diff rendering -- they use exc_type /
# exc_val / exc_tb, which are not parameters of this method. Confirm
# against the committed file before relying on them.
if V.graph.scheduler:
V.graph.scheduler.remove_kernel_local_buffers()
super().__exit__(exc_type, exc_val, exc_tb)
scheduler = V.graph.scheduler
if not scheduler:
return
# Names of the ops that define the buffers stored by this kernel,
# restricted to buffers the scheduler knows about.
fused_node_names = OrderedSet(
scheduler.name_to_buf[buf].defining_op.get_name()
for buf in self.store_buffer_names
if buf in scheduler.name_to_buf
)
# Collect removal candidates from the buffers this kernel stores to.
names_to_remove = []
for out_buf in self.store_buffer_names:
if out_buf not in scheduler.name_to_buf:
# Aux buffers created during kernel codegen
names_to_remove.append(out_buf)
continue
users = scheduler.name_to_buf[out_buf].users
assert users is not None
# Only users not flagged `is_weak` are considered; the buffer is a
# candidate when all of them are among fused_node_names.
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
if users.issubset(fused_node_names):
names_to_remove.append(out_buf)

# A candidate survives the filter only if it is not a must-keep buffer,
# not one of the kernel's input buffers, and not present in either of the
# scheduler's mutation_renames / mutation_real_name mappings.
def remove_filter(n: str) -> bool:
return (
n not in self.must_keep_buffers
and n not in self.args.input_buffers
and n not in scheduler.mutation_renames
and n not in scheduler.mutation_real_name
)

names_to_remove = [*filter(remove_filter, names_to_remove)]

# In-place buffers are handled separately from plain output buffers.
for name in names_to_remove:
if name in self.args.inplace_buffers:
buf = self.args.inplace_buffers[name]
# A string entry starting with "REMOVED" is the marker written by
# remove_inplace_buffer, i.e. this entry was already removed.
if isinstance(buf, str) and buf.startswith("REMOVED"):
continue
# The in-place buffer is removed only when every name listed in
# buf.other_names is itself scheduled for removal.
remove = all(n in names_to_remove for n in buf.other_names)
if remove:
self.remove_inplace_buffer(name)
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the next line is guarded by `if remove` -- confirm in the file.
self.inplaced_to_remove.add(name)
else:
self.remove_buffer(name)

def remove_buffer(self, name: str) -> None:
    """
    Mark the output buffer ``name`` as removed from this kernel.

    The ``output_buffers`` entry is overwritten with the special value
    "REMOVED" instead of being deleted, because the length of
    ``output_buffers`` is still relied on to generate unique arg names.
    """
    log.debug("remove_buffer(%r)", name)
    outputs = self.args.output_buffers
    outputs[name] = "REMOVED"
    removed = self.removed_buffers
    removed.add(name)

def remove_inplace_buffer(self, name: str) -> None:
    """
    Mark the in-place buffer ``name`` as removed from this kernel.

    The ``inplace_buffers`` entry is replaced by its ``inner_name`` with
    "in_out_ptr" rewritten to "REMOVED", and ``name`` is recorded in
    ``removed_buffers``.
    """
    log.debug("removing_inplace_buffer(%r)", name)
    inplace = self.args.inplace_buffers
    placeholder = inplace[name].inner_name.replace("in_out_ptr", "REMOVED")
    inplace[name] = placeholder
    self.removed_buffers.add(name)

def rename_indexing(self, index) -> sympy.Expr:
# adds the necessary kernel args for index expressions
Expand Down
61 changes: 0 additions & 61 deletions torch/_inductor/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,67 +3331,6 @@ def free_buffers(self) -> None:

self.buffer_names_to_free.clear()

def remove_kernel_local_buffers(self) -> None:
"""
Any buffers that are both created and have a last use in the
same kernel can be removed.

This version operates on the kernel reached through V.kernel and on
this object's own name_to_buf / mutation_renames / mutation_real_name.
"""

# Names of the ops that define the buffers stored by the current kernel,
# restricted to buffers present in self.name_to_buf.
fused_node_names = OrderedSet(
self.name_to_buf[buf].defining_op.get_name()
for buf in V.kernel.store_buffer_names
if buf in self.name_to_buf
)
# Collect removal candidates from the buffers the kernel stores to.
names_to_remove = []
for out_buf in V.kernel.store_buffer_names:
if out_buf not in self.name_to_buf:
# Aux buffers created during kernel codegen
names_to_remove.append(out_buf)
continue
users = self.name_to_buf[out_buf].users
assert users is not None
# Only users not flagged `is_weak` are considered; the buffer is a
# candidate when all of them are among fused_node_names.
users = OrderedSet(user.get_name() for user in users if not user.is_weak)
if users.issubset(fused_node_names):
names_to_remove.append(out_buf)

# A candidate survives the filter only if it is not a must-keep buffer,
# not one of the kernel's input buffers, and not present in either of the
# mutation_renames / mutation_real_name mappings.
def remove_filter(n: str) -> bool:
return (
n not in V.kernel.must_keep_buffers
and n not in V.kernel.args.input_buffers
and n not in self.mutation_renames
and n not in self.mutation_real_name
)

names_to_remove = list(filter(remove_filter, names_to_remove))

# In-place buffers are handled separately from plain output buffers.
for name in names_to_remove:
if name in V.kernel.args.inplace_buffers:
buf = V.kernel.args.inplace_buffers[name]
# A string entry starting with "REMOVED" is the marker written by
# remove_inplace_buffer, i.e. this entry was already removed.
if isinstance(buf, str) and buf.startswith("REMOVED"):
continue
# The in-place buffer is removed only when every name listed in
# buf.other_names is itself scheduled for removal.
remove = all(n in names_to_remove for n in buf.other_names)
if remove:
self.remove_inplace_buffer(name)
# NOTE(review): indentation is lost in this view, so it is not visible
# whether the next line is guarded by `if remove` -- confirm in the file.
V.kernel.inplaced_to_remove.add(name)
else:
self.remove_buffer(name)

def remove_buffer(self, name: str) -> None:
    """
    Mark the output buffer ``name`` as removed on the current ``V.kernel``.

    The ``output_buffers`` entry is overwritten with the special value
    "REMOVED" instead of being deleted, because the length of
    ``output_buffers`` is still relied on to generate unique arg names.
    """
    log.debug("remove_buffer(%r)", name)
    kernel = V.kernel
    kernel.args.output_buffers[name] = "REMOVED"
    kernel.removed_buffers.add(name)

def remove_inplace_buffer(self, name: str) -> None:
    """
    Mark the in-place buffer ``name`` as removed on the current ``V.kernel``.

    The ``inplace_buffers`` entry is replaced by its ``inner_name`` with
    "in_out_ptr" rewritten to "REMOVED", and ``name`` is recorded in the
    kernel's ``removed_buffers``.
    """
    log.debug("removing_inplace_buffer(%r)", name)
    kernel = V.kernel
    inplace = kernel.args.inplace_buffers
    placeholder = inplace[name].inner_name.replace("in_out_ptr", "REMOVED")
    inplace[name] = placeholder
    kernel.removed_buffers.add(name)

def flush(self) -> None:
for backend in self.backends.values():
backend.flush()
Expand Down

0 comments on commit 3d633f1

Please sign in to comment.