Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 30, 2025
1 parent 5efe950 commit 90e8f8b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def _new_unsafe(
nested: bool = True,
**kwargs: dict[str, Any] | None,
) -> TensorDict:
if is_compiling():
if is_compiling() and cls is TensorDict:
# If the cls is not TensorDict, we must escape this to keep the same class.
# That's unfortunate because as of now it graph breaks but that's the best we can do.
return TensorDict(
source,
batch_size=batch_size,
Expand Down
23 changes: 23 additions & 0 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,29 @@ def locked_op(tc):
tc_op_c = locked_op_c(data)
assert (tc_op == tc_op_c).all()

def test_td_new_unsafe(self, mode):

class MyTd(TensorDict):
pass

def func_td():
return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

@torch.compile(fullgraph=True, mode=mode)
def func_c_td():
return TensorDict._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

def func_mytd():
return MyTd._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

# This will graph break
@torch.compile(mode=mode)
def func_c_mytd():
return MyTd._new_unsafe(a=torch.randn(3), batch_size=torch.Size(()))

assert type(func_td()) is type(func_c_td())
assert type(func_mytd()) is type(func_c_mytd())


@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
Expand Down

0 comments on commit 90e8f8b

Please sign in to comment.