Skip to content

Commit

Permalink
Register serialized name to KJT/JT (#1713)
Browse files Browse the repository at this point in the history
Summary:

Trying to reland D51312977. Hopefully this wont break any torch
packages as D53139358 recently did a similar change where called
`register_pytree_node` with an extra new argument, `flatten_with_keys_fn`,
which was added recently in D52547850.

Differential Revision: D53857843
  • Loading branch information
angelayi authored and facebook-github-bot committed Feb 16, 2024
1 parent bedeba5 commit a695a2f
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import abc
import operator
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -635,7 +636,11 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten


register_pytree_node(
JaggedTensor, _jt_flatten, _jt_unflatten, flatten_with_keys_fn=_jt_flatten_with_keys
JaggedTensor,
_jt_flatten,
_jt_unflatten,
flatten_with_keys_fn=_jt_flatten_with_keys,
serialized_type_name="torchrec.JaggedTensor",
)
register_pytree_flatten_spec(JaggedTensor, _jt_flatten_spec)

Expand Down Expand Up @@ -2094,6 +2099,7 @@ def _kjt_flatten_spec(
_kjt_flatten,
_kjt_unflatten,
flatten_with_keys_fn=_kjt_flatten_with_keys,
serialized_type_name="torchrec.KeyedJaggedTensor",
)
register_pytree_flatten_spec(KeyedJaggedTensor, _kjt_flatten_spec)

Expand Down

0 comments on commit a695a2f

Please sign in to comment.