Skip to content

Remove reference cycle - with exceptions #228

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion torchdynamo/convert_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def transform(instructions, code_options):
)
tracer.run()
output = tracer.output
output.cleanup()
assert output.output_instructions
instructions[:] = output.output_instructions
code_options.update(output.code_options)
Expand Down
8 changes: 8 additions & 0 deletions torchdynamo/output_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def restore_graphstate(self, state):
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
for node in reversed(list(self.graph.nodes)):
if node not in graph_nodes:
# Erasing node alone does not remove the meta information
# So, remove the help tensor explicitly
if "example_value" in node.meta:
del node.meta["example_value"]
self.graph.erase_node(node)

def count_calls(self):
Expand Down Expand Up @@ -374,3 +378,7 @@ def cleanup(self):
# Cleanup graphargs
for graph_arg in self.graphargs:
graph_arg.erase()

for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
8 changes: 8 additions & 0 deletions torchdynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,14 @@ def run(self):
f"{self.lineno} {typestr(e)}\n"
)
raise
finally:
# Cleanup the outputGraph to delete the held tensors. We perform the
# cleanup only for InstructionTranslator and not
# InliningInstructionTranslator. The InliningInstructionTranslator
# mutates the output object and is restored to original state if
# there was an exception.
if isinstance(self, InstructionTranslator):
self.output.cleanup()

def push(self, val):
assert val is None or isinstance(
Expand Down