Small syntactic changes for dynamo compatibility
Summary:
Dynamo has some gaps in its support for generators, list comprehensions, etc.
Avoiding them for now with syntactic changes.

The previous diff was reverted because recat was created on the target device from the start;
the per-item writes then went directly to that device, which broke freya training (freya does not appear to support per-item changes).

In this diff recat is created on "cpu", the same as the List[int] in the original version.

Reviewed By: MatthewWEdwards

Differential Revision: D54192498

fbshipit-source-id: bdfffc5f207fa200c9b207969225c2bcbdf94e7a
Ivan Kobzarev authored and facebook-github-bot committed Feb 28, 2024
1 parent fdfc4e0 commit f1c716a
Showing 3 changed files with 33 additions and 15 deletions.
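
For context, the hunks below apply two patterns described in the summary: replacing generator/comprehension-style constructs with explicit loops, and preallocating a CPU tensor that is filled by index instead of appending to a List[int]. A minimal standalone sketch of both patterns, with made-up helper names rather than the torchrec code itself:

from typing import List

import torch


# Pattern 1: explicit loop instead of a generator / list comprehension
# (e.g. instead of `keys * n`).
def repeat_keys(keys: List[str], n: int) -> List[str]:
    out: List[str] = []
    for _ in range(n):
        out.extend(keys)
    return out


# Pattern 2: preallocate a tensor (on "cpu" by default) and fill it by index,
# instead of building a List[int] item by item.
def build_indices(n: int) -> torch.Tensor:
    out = torch.empty(n, dtype=torch.int32)
    for i in range(n):
        out[i] = n - 1 - i
    return out


print(repeat_keys(["f1", "f2"], 2))  # ['f1', 'f2', 'f1', 'f2']
print(build_indices(4))              # tensor([3, 2, 1, 0], dtype=torch.int32)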
17 changes: 10 additions & 7 deletions torchrec/distributed/dist_data.py
@@ -81,17 +81,20 @@ def _get_recat(
     if local_split == 0:
         return None

-    recat: List[int] = []
+    feature_order: List[int] = []
+    for x in range(num_splits // stagger):
+        for y in range(stagger):
+            feature_order.append(x + num_splits // stagger * y)

-    feature_order: List[int] = [
-        x + num_splits // stagger * y
-        for x in range(num_splits // stagger)
-        for y in range(stagger)
-    ]
+    recat: torch.Tensor = torch.empty(
+        local_split * len(feature_order), dtype=torch.int32
+    )

+    _i = 0
     for i in range(local_split):
         for j in feature_order:  # range(num_splits):
-            recat.append(i + j * local_split)
+            recat[_i] = i + j * local_split
+            _i += 1

     # variable batch size
     if batch_size_per_rank is not None and any(
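
As an aside (not part of the commit), a self-contained sketch of the recat index computation in the new code above, with a hypothetical function name and without the device and variable-batch-size handling:

from typing import List

import torch


def recat_indices(local_split: int, num_splits: int, stagger: int = 1) -> torch.Tensor:
    # Staggered feature order, built with explicit loops.
    feature_order: List[int] = []
    for x in range(num_splits // stagger):
        for y in range(stagger):
            feature_order.append(x + num_splits // stagger * y)

    # Preallocate on CPU and fill by index instead of list.append().
    recat = torch.empty(local_split * len(feature_order), dtype=torch.int32)
    _i = 0
    for i in range(local_split):
        for j in feature_order:
            recat[_i] = i + j * local_split
            _i += 1
    return recat


# local_split=2, num_splits=4: interleave 4 splits of 2 rows each.
print(recat_indices(2, 4))  # tensor([0, 2, 4, 6, 1, 3, 5, 7], dtype=torch.int32)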
18 changes: 16 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -87,6 +87,15 @@ def _fx_wrap_stride_per_key_per_rank(
     )


+@torch.fx.wrap
+def _fx_wrap_gen_list_n_times(ls: List[str], n: int) -> List[str]:
+    # Syntax for dynamo (instead of generator kjt.keys() * num_buckets)
+    ret: List[str] = []
+    for _ in range(n):
+        ret.extend(ls)
+    return ret
+
+
 def bucketize_kjt_before_all2all(
     kjt: KeyedJaggedTensor,
     num_buckets: int,
@@ -143,7 +152,7 @@ def bucketize_kjt_before_all2all(
     return (
         KeyedJaggedTensor(
             # duplicate keys will be resolved by AllToAll
-            keys=kjt.keys() * num_buckets,
+            keys=_fx_wrap_gen_list_n_times(kjt.keys(), num_buckets),
             values=bucketized_indices,
             weights=pos if bucketize_pos else bucketized_weights,
             lengths=bucketized_lengths.view(-1),
@@ -371,7 +380,12 @@ def _wait_impl(self) -> KJTList:
         Returns:
             KJTList: synced `KJTList`.
         """
-        kjts = [w.wait() for w in self.awaitables]
+
+        # Syntax: no list comprehension usage for dynamo
+        kjts = []
+        for w in self.awaitables:
+            kjts.append(w.wait())
+
         _set_sharding_context_post_a2a(kjts, self.ctx)
         return KJTList(kjts)

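As an aside (not part of the commit), a minimal illustration of what @torch.fx.wrap buys for helpers like _fx_wrap_gen_list_n_times above: the wrapped call is recorded as a single leaf node during FX symbolic tracing instead of being traced through. The helper and module below are invented for the example:

from typing import List

import torch
import torch.fx


@torch.fx.wrap
def tile_tensor(t: torch.Tensor, n: int) -> torch.Tensor:
    # Python-level loop; because of the wrap, FX records one call_function
    # node for tile_tensor rather than tracing through the loop.
    parts: List[torch.Tensor] = []
    for _ in range(n):
        parts.append(t)
    return torch.cat(parts)


class Example(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return tile_tensor(x, 3) * 2


gm = torch.fx.symbolic_trace(Example())
print(gm.graph)           # contains a single call_function node for tile_tensor
print(gm(torch.ones(2)))  # tensor([2., 2., 2., 2., 2., 2.])
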
13 changes: 7 additions & 6 deletions torchrec/distributed/embeddingbag.py
@@ -856,12 +856,13 @@ def compute_and_output_dist(
     ) -> LazyAwaitable[KeyedTensor]:
         batch_size_per_feature_pre_a2a = []
         awaitables = []
-        for lookup, dist, sharding_context, features in zip(
-            self._lookups,
-            self._output_dists,
-            ctx.sharding_contexts,
-            input,
-        ):
+
+        # No usage of zip for dynamo
+        for i in range(len(self._lookups)):
+            lookup = self._lookups[i]
+            dist = self._output_dists[i]
+            sharding_context = ctx.sharding_contexts[i]
+            features = input[i]
             awaitables.append(dist(lookup(features), sharding_context))
             if sharding_context:
                 batch_size_per_feature_pre_a2a.extend(
