Re-land D66465376 #2637

Open · wants to merge 1 commit into base: main

Conversation

TroyGarden (Contributor)

Summary:
Re-land diff D66465376

NOTE: apply `@torch.jit.ignore` to the forward method to get rid of the jit script error with `TensorDict`:

```
from typing import Union

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def test_td_scripting(self) -> None:
    class TestModule(torch.nn.Module):
        @torch.jit.ignore  # <----- test fails without this ignore
        def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor:
            if isinstance(x, TensorDict):
                keys = list(x.keys())
                return torch.cat([x[key]._values for key in keys], dim=0)
            else:
                return x._values

    m = TestModule()
    gm = torch.fx.symbolic_trace(m)
    jm = torch.jit.script(gm)
    values = torch.tensor([0, 1, 2, 3, 2, 3, 4])
    kjt = KeyedJaggedTensor.from_offsets_sync(
        keys=["f1", "f2", "f3"],
        values=values,
        offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
    )
    # assert_close replaces the deprecated torch.testing.assert_allclose
    torch.testing.assert_close(jm(kjt), values)
```
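
For context, a minimal sketch (not part of this PR, reusing the same `TestModule` and trace/script steps as above) of the failure the decorator works around: without `@torch.jit.ignore`, `torch.jit.script` errors out on the fx-traced module because TorchScript cannot resolve the Python-only `TensorDict` type annotation on forward.

```
# Hypothetical repro sketch: assumes TestModule is defined exactly as above,
# but WITHOUT @torch.jit.ignore on forward.
m = TestModule()
gm = torch.fx.symbolic_trace(m)
try:
    torch.jit.script(gm)  # TorchScript rejects the unresolvable TensorDict annotation
except Exception as err:
    print(f"scripting failed as expected: {err}")
```

With the decorator in place, scripting succeeds because `torch.jit.ignore` leaves forward as a plain Python call, which is also why `jm(kjt)` still returns the expected values when invoked in-process.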

Differential Revision: D66460392

@facebook-github-bot added the CLA Signed label on Dec 14, 2024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D66460392

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 14, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 15, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 30, 2024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D66460392

TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 30, 2024
TroyGarden added a commit to TroyGarden/torchrec that referenced this pull request Dec 31, 2024