Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Same behaviour in compiled and non-compiled versions of _new_unsafe #1197

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def _new_unsafe(
nested: bool = True,
**kwargs: dict[str, Any] | None,
) -> TensorDict:
if is_compiling():
if is_compiling() and cls is TensorDict:
# If the cls is not TensorDict, we must escape this to keep the same class.
# That's unfortunate because as of now it graph breaks but that's the best we can do.
return TensorDict(
source,
batch_size=batch_size,
Expand Down
23 changes: 23 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,29 @@ def locked_op(tc):
tc_op_c = locked_op_c(data)
assert (tc_op == tc_op_c).all()

def test_td_new_unsafe(self, mode):

class MyTd(TensorDict):
pass

def func_td():
return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

@torch.compile(fullgraph=True, mode=mode)
def func_c_td():
return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

def func_mytd():
return MyTd._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

# This will graph break
@torch.compile(mode=mode)
def func_c_mytd():
return MyTd._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

assert type(func_td()) is type(func_c_td())
assert type(func_mytd()) is type(func_c_mytd())


@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
Expand Down
Loading