Skip to content

Commit

Permalink
Support sending using lengths to TBE instead of just offsets (pytorch…
Browse files Browse the repository at this point in the history
…#2557)

Summary: Pull Request resolved: pytorch/torchrec#2557

Differential Revision: D64906767
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Nov 26, 2024
1 parent cffa05a commit baa759b
Showing 1 changed file with 11 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -913,7 +913,7 @@ def _update_tablewise_cache_miss(

self.table_wise_cache_miss[i] += miss_count

def forward(
def _forward_impl(
self,
indices: Tensor,
offsets: Tensor,
Expand Down Expand Up @@ -1016,6 +1016,16 @@ def forward(
fp8_exponent_bias=self.fp8_exponent_bias,
)

def forward(
self,
indices: Tensor,
offsets: Tensor,
per_sample_weights: Optional[Tensor] = None,
) -> Tensor:
return self._forward_impl(
indices=indices, offsets=offsets, per_sample_weights=per_sample_weights
)

def initialize_logical_weights_placements_and_offsets(
self,
) -> None:
Expand Down

0 comments on commit baa759b

Please sign in to comment.