diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e8ace51040bf37..7da757ffd6729d 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -48,6 +48,7 @@ schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") +log = logging.getLogger(__name__) def data_type_logger(msg): @@ -2508,13 +2509,72 @@ def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: return self def __exit__(self, exc_type, exc_val, exc_tb): + self.remove_kernel_local_buffers() + super().__exit__(exc_type, exc_val, exc_tb) + + def remove_kernel_local_buffers(self) -> None: """ + Any buffers that are both created and have a last use in the + same kernel can be removed. + Note that V.graph.scheduler can be None when codegening triton template kernels. """ - if V.graph.scheduler: - V.graph.scheduler.remove_kernel_local_buffers() - super().__exit__(exc_type, exc_val, exc_tb) + scheduler = V.graph.scheduler + if not scheduler: + return + fused_node_names = OrderedSet( + scheduler.name_to_buf[buf].defining_op.get_name() + for buf in self.store_buffer_names + if buf in scheduler.name_to_buf + ) + names_to_remove = [] + for out_buf in self.store_buffer_names: + if out_buf not in scheduler.name_to_buf: + # Aux buffers created during kernel codegen + names_to_remove.append(out_buf) + continue + users = scheduler.name_to_buf[out_buf].users + assert users is not None + users = OrderedSet(user.get_name() for user in users if not user.is_weak) + if users.issubset(fused_node_names): + names_to_remove.append(out_buf) + + def remove_filter(n: str) -> bool: + return ( + n not in self.must_keep_buffers + and n not in self.args.input_buffers + and n not in scheduler.mutation_renames + and n not in scheduler.mutation_real_name + ) + + names_to_remove = [*filter(remove_filter, names_to_remove)] + + for name in names_to_remove: + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if isinstance(buf, str) and buf.startswith("REMOVED"): + continue + remove = all(n in names_to_remove for n in buf.other_names) + if remove: + self.remove_inplace_buffer(name) + self.inplaced_to_remove.add(name) + else: + self.remove_buffer(name) + + def remove_buffer(self, name: str) -> None: + # Assign a special value instead of deleting the entry + # because we still rely on output_buffers's length to + # generate unique arg name. + log.debug("remove_buffer(%r)", name) + self.args.output_buffers[name] = "REMOVED" + self.removed_buffers.add(name) + + def remove_inplace_buffer(self, name: str) -> None: + log.debug("removing_inplace_buffer(%r)", name) + inner_name = self.args.inplace_buffers[name].inner_name + self.args.inplace_buffers[name] = inner_name.replace("in_out_ptr", "REMOVED") + self.removed_buffers.add(name) def rename_indexing(self, index) -> sympy.Expr: # adds the necessary kernel args for index expressions diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 44cc970bdaefcd..8263cf5380ade3 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3331,67 +3331,6 @@ def free_buffers(self) -> None: self.buffer_names_to_free.clear() - def remove_kernel_local_buffers(self) -> None: - """ - Any buffers that are both created and have a last use in the - same kernel can be removed. - """ - - fused_node_names = OrderedSet( - self.name_to_buf[buf].defining_op.get_name() - for buf in V.kernel.store_buffer_names - if buf in self.name_to_buf - ) - names_to_remove = [] - for out_buf in V.kernel.store_buffer_names: - if out_buf not in self.name_to_buf: - # Aux buffers created during kernel codegen - names_to_remove.append(out_buf) - continue - users = self.name_to_buf[out_buf].users - assert users is not None - users = OrderedSet(user.get_name() for user in users if not user.is_weak) - if users.issubset(fused_node_names): - names_to_remove.append(out_buf) - - def remove_filter(n: str) -> bool: - return ( - n not in V.kernel.must_keep_buffers - and n not in V.kernel.args.input_buffers - and n not in self.mutation_renames - and n not in self.mutation_real_name - ) - - names_to_remove = list(filter(remove_filter, names_to_remove)) - - for name in names_to_remove: - if name in V.kernel.args.inplace_buffers: - buf = V.kernel.args.inplace_buffers[name] - if isinstance(buf, str) and buf.startswith("REMOVED"): - continue - remove = all(n in names_to_remove for n in buf.other_names) - if remove: - self.remove_inplace_buffer(name) - V.kernel.inplaced_to_remove.add(name) - else: - self.remove_buffer(name) - - def remove_buffer(self, name: str) -> None: - # Assign a special value instead of deleting the entry - # because we still rely on output_buffers's length to - # generate unique arg name. - log.debug("remove_buffer(%r)", name) - V.kernel.args.output_buffers[name] = "REMOVED" - V.kernel.removed_buffers.add(name) - - def remove_inplace_buffer(self, name: str) -> None: - log.debug("removing_inplace_buffer(%r)", name) - inner_name = V.kernel.args.inplace_buffers[name].inner_name - V.kernel.args.inplace_buffers[name] = inner_name.replace( - "in_out_ptr", "REMOVED" - ) - V.kernel.removed_buffers.add(name) - def flush(self) -> None: for backend in self.backends.values(): backend.flush()