Skip to content

Commit

Permalink
Warning for to_dict when torch.exporting (#2247)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2247

to_dict for non-strict torch.export is extremely slow and can almost always be avoided. This diff adds a warning to notify the user of this case.

Reviewed By: TroyGarden

Differential Revision: D60293485

fbshipit-source-id: 980f9911de4509d6259865dbc85965e2c66589c6
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Jul 29, 2024
1 parent ddcfd64 commit 981a37b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# pyre-strict

import abc
import logging

import operator

Expand Down Expand Up @@ -51,6 +52,8 @@
except ImportError:
pass

logger: logging.Logger = logging.getLogger()


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
if is_torchdynamo_compiling():
Expand Down Expand Up @@ -2226,6 +2229,10 @@ def __getitem__(self, key: str) -> JaggedTensor:
)

def to_dict(self) -> Dict[str, JaggedTensor]:
if not torch.jit.is_scripting() and is_non_strict_exporting():
logger.warn(
"Trying to non-strict torch.export KJT to_dict, which is extremely slow and not recommended!"
)
_jt_dict = _maybe_compute_kjt_to_jt_dict(
stride=self.stride(),
stride_per_key=self.stride_per_key(),
Expand Down

0 comments on commit 981a37b

Please sign in to comment.