Fix the type mismatch issue during TGIF publish (#1796)
Summary:
Pull Request resolved: #1796

Annotate _embedding_names_per_rank_per_sharding to ensure its type is List[List[List[str]]]. The annotation has to live in a separate function with an input parameter so that it won't be dropped during symbolic trace; a minimal sketch of this pattern is included just before the file diff below.

Reviewed By: s4ayub

Differential Revision: D54442150

fbshipit-source-id: 8940a31c405eb6a5cb947c057ca4e77a08b6f245
Min Yu authored and facebook-github-bot committed Mar 18, 2024
1 parent da1c013 commit 1970969
Showing 1 changed file with 35 additions and 1 deletion.
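
The summary above hinges on a torch.fx detail: a function registered with torch.fx.wrap is only recorded as a call_function node when at least one of its arguments is a Proxy, so an annotation helper called with plain Python lists alone is evaluated at trace time and effectively dropped from the graph. The following minimal sketch is not part of this commit and uses hypothetical names (annotate_names, Toy); a single tensor stands in for the commit's List[List[torch.Tensor]] dummy. It shows the same wrap-with-a-traced-input pattern surviving symbolic trace:

import torch
import torch.fx
from typing import List


# Mirror the commit's pattern: wrap the annotation helper and thread a traced
# tensor through it so the call is kept in the FX graph.
@torch.fx.has_side_effect
@torch.fx.wrap
def annotate_names(names: List[str], dummy: torch.Tensor) -> List[str]:
    return torch.jit.annotate(List[str], names)


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `x` is a Proxy during tracing, so the wrapped call is recorded
        # as a node instead of being evaluated and folded away.
        names = annotate_names(["f1", "f2"], x)
        return x


gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # the graph keeps a call_function node for annotate_names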
36 changes: 35 additions & 1 deletion torchrec/distributed/quant_embedding.py
@@ -134,6 +134,9 @@ def _construct_jagged_tensors_tw(
    for i in range(len(embeddings)):
        embeddings_i: torch.Tensor = embeddings[i]
        features_i: KeyedJaggedTensor = features[i]
        if features_i.lengths().numel() == 0:
            # No table on the rank, skip.
            continue

        lengths = features_i.lengths().view(-1, features_i.stride())
        values = features_i.values()
@@ -287,6 +290,34 @@ def _construct_jagged_tensors(
        return _construct_jagged_tensors_tw(embeddings, features, need_indices)


# Wrap the annotation in a separate function with an input parameter so that it won't be dropped during symbolic trace.
# Note that the input parameter is necessary even though it is unused; without it, this call would be optimized away.
@torch.fx.has_side_effect
@torch.fx.wrap
def annotate_embedding_names(
    embedding_names: List[str],
    dummy: List[List[torch.Tensor]],
) -> List[str]:
    return torch.jit.annotate(List[str], embedding_names)


def format_embedding_names_per_rank_per_sharding(
    embedding_names_per_rank_per_sharding: List[List[List[str]]],
    dummy: List[List[torch.Tensor]],
) -> List[List[List[str]]]:
    annotated_embedding_names_per_rank_per_sharding: List[List[List[str]]] = []
    for embedding_names_per_rank in embedding_names_per_rank_per_sharding:
        annotated_embedding_names_per_rank: List[List[str]] = []
        for embedding_names in embedding_names_per_rank:
            annotated_embedding_names_per_rank.append(
                annotate_embedding_names(embedding_names, dummy)
            )
        annotated_embedding_names_per_rank_per_sharding.append(
            annotated_embedding_names_per_rank
        )
    return annotated_embedding_names_per_rank_per_sharding


@torch.fx.wrap
def output_jt_dict(
    sharding_types: List[str],
@@ -709,11 +740,14 @@ def output_dist(
                # pyre-ignore
                sharding_ctx.features_before_input_dist
            )

        return output_jt_dict(
            sharding_types=list(self._sharding_type_to_sharding.keys()),
            emb_per_sharding=emb_per_sharding,
            features_per_sharding=features_per_sharding,
            embedding_names_per_rank_per_sharding=self._embedding_names_per_rank_per_sharding,
            embedding_names_per_rank_per_sharding=format_embedding_names_per_rank_per_sharding(
                self._embedding_names_per_rank_per_sharding, output
            ),
            need_indices=self._need_indices,
            features_before_input_dist_per_sharding=features_before_input_dist_per_sharding,
            unbucketize_tensor_idxs_per_sharding=unbucketize_tensor_idxs_per_sharding,
