EmbeddingShardingContext fields no default_factory for dynamo (#1712)
Summary:
Pull Request resolved: #1712

Dynamo does not support dataclass fields that use default_factory with lists, so avoid them for now by specifying all arguments explicitly in __init__.

Reviewed By: Microve

Differential Revision: D53854370

fbshipit-source-id: b469f4a8acbcddbc2b9dca43765e11bd99429a3a
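
For illustration only (not part of this commit), a minimal sketch of the pattern the diff applies; the class names and the single field here are placeholders, and the limitation is the one tracked in the issue linked in the code comments below:

from dataclasses import dataclass, field
from typing import List, Optional


# Before: Dynamo cannot handle dataclass fields declared with
# field(default_factory=list) (pytorch/pytorch#120108).
@dataclass
class ShardingContextBefore:
    batch_size_per_rank: List[int] = field(default_factory=list)


# After: the same default (a fresh empty list per instance) is expressed with
# an explicit __init__ and Optional arguments, which Dynamo can trace.
class ShardingContextAfter:
    def __init__(self, batch_size_per_rank: Optional[List[int]] = None) -> None:
        self.batch_size_per_rank: List[int] = (
            batch_size_per_rank if batch_size_per_rank is not None else []
        )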
Ivan Kobzarev authored and facebook-github-bot committed Feb 28, 2024
1 parent f1fb67a commit fdfc4e0
Showing 2 changed files with 65 additions and 14 deletions.
33 changes: 27 additions & 6 deletions torchrec/distributed/embedding_sharding.py
@@ -7,7 +7,7 @@

 import abc
 import copy
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import torch
@@ -614,12 +614,33 @@ def _wait_impl(self) -> Awaitable[ListOfKJTList]:
 W = TypeVar("W")


-@dataclass
 class EmbeddingShardingContext(Multistreamable):
-    batch_size_per_rank: List[int] = field(default_factory=list)
-    batch_size_per_rank_per_feature: List[List[int]] = field(default_factory=list)
-    batch_size_per_feature_pre_a2a: List[int] = field(default_factory=list)
-    variable_batch_per_feature: bool = False
+    # Torch Dynamo does not support default_factory=list:
+    # https://github.com/pytorch/pytorch/issues/120108
+    # TODO(ivankobzarev) Make this a dataclass once supported
+
+    def __init__(
+        self,
+        batch_size_per_rank: Optional[List[int]] = None,
+        batch_size_per_rank_per_feature: Optional[List[List[int]]] = None,
+        batch_size_per_feature_pre_a2a: Optional[List[int]] = None,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        super().__init__()
+        self.batch_size_per_rank: List[int] = (
+            batch_size_per_rank if batch_size_per_rank is not None else []
+        )
+        self.batch_size_per_rank_per_feature: List[List[int]] = (
+            batch_size_per_rank_per_feature
+            if batch_size_per_rank_per_feature is not None
+            else []
+        )
+        self.batch_size_per_feature_pre_a2a: List[int] = (
+            batch_size_per_feature_pre_a2a
+            if batch_size_per_feature_pre_a2a is not None
+            else []
+        )
+        self.variable_batch_per_feature: bool = variable_batch_per_feature

     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         pass
46 changes: 38 additions & 8 deletions torchrec/distributed/sharding/sequence_sharding.py
@@ -5,7 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import List, Optional

 import torch
@@ -15,7 +15,6 @@
 from torchrec.streamable import Multistreamable


-@dataclass
 class SequenceShardingContext(EmbeddingShardingContext):
     """
     Stores KJTAllToAll context and reuses it in SequenceEmbeddingsAllToAll.
@@ -32,12 +31,43 @@ class SequenceShardingContext(EmbeddingShardingContext):
         input dist.
     """

-    features_before_input_dist: Optional[KeyedJaggedTensor] = None
-    input_splits: List[int] = field(default_factory=list)
-    output_splits: List[int] = field(default_factory=list)
-    sparse_features_recat: Optional[torch.Tensor] = None
-    unbucketize_permute_tensor: Optional[torch.Tensor] = None
-    lengths_after_input_dist: Optional[torch.Tensor] = None
+    # Torch Dynamo does not support default_factory=list:
+    # https://github.com/pytorch/pytorch/issues/120108
+    # TODO(ivankobzarev): Make this a dataclass once supported
+
+    def __init__(
+        self,
+        # Fields of EmbeddingShardingContext
+        batch_size_per_rank: Optional[List[int]] = None,
+        batch_size_per_rank_per_feature: Optional[List[List[int]]] = None,
+        batch_size_per_feature_pre_a2a: Optional[List[int]] = None,
+        variable_batch_per_feature: bool = False,
+        # Fields of SequenceShardingContext
+        features_before_input_dist: Optional[KeyedJaggedTensor] = None,
+        input_splits: Optional[List[int]] = None,
+        output_splits: Optional[List[int]] = None,
+        sparse_features_recat: Optional[torch.Tensor] = None,
+        unbucketize_permute_tensor: Optional[torch.Tensor] = None,
+        lengths_after_input_dist: Optional[torch.Tensor] = None,
+    ) -> None:
+        super().__init__(
+            batch_size_per_rank,
+            batch_size_per_rank_per_feature,
+            batch_size_per_feature_pre_a2a,
+            variable_batch_per_feature,
+        )
+        self.features_before_input_dist: Optional[
+            KeyedJaggedTensor
+        ] = features_before_input_dist
+        self.input_splits: List[int] = input_splits if input_splits is not None else []
+        self.output_splits: List[int] = (
+            output_splits if output_splits is not None else []
+        )
+        self.sparse_features_recat: Optional[torch.Tensor] = sparse_features_recat
+        self.unbucketize_permute_tensor: Optional[
+            torch.Tensor
+        ] = unbucketize_permute_tensor
+        self.lengths_after_input_dist: Optional[torch.Tensor] = lengths_after_input_dist

     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         if self.features_before_input_dist is not None:
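A hedged usage sketch (not part of the commit) based on the constructors in the diff above: call sites should be unaffected, since every field remains optional and falls back to an empty list or None, matching the old dataclass defaults. The argument values are made up for illustration.

from torchrec.distributed.embedding_sharding import EmbeddingShardingContext
from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext

# Omitted arguments fall back to empty lists, mirroring the previous
# field(default_factory=list) behavior.
ctx = EmbeddingShardingContext(batch_size_per_rank=[4, 4])
assert ctx.batch_size_per_rank_per_feature == []

seq_ctx = SequenceShardingContext(
    batch_size_per_rank=[4, 4],
    input_splits=[2, 2],
    output_splits=[2, 2],
)
assert seq_ctx.features_before_input_dist is None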
