Add VBE support for PositionWeightedModuleCollection (pytorch#2647)
Summary:
Pull Request resolved: pytorch#2647

As titled: we have seen wins from position encoding in modeling and would like to use PositionWeightedModuleCollection to reduce the cost. See https://fb.workplace.com/groups/204375858345877/permalink/884618276988295/

I have a stack locally that shows NE equivalence between PositionWeightedModuleCollection and position encoding in modeling:
{F1974047979}

Given that IG has adopted VBE, I am adding the necessary plumbing for VBE in PositionWeightedModuleCollection.

**Diffs will land after the code freeze but are published first to get the review underway.**

Reviewed By: TroyGarden

Differential Revision: D67526005

fbshipit-source-id: bf245d87f4e91998bcd31e2c79f120ec22736ab4
AlbertDachiChen authored and facebook-github-bot committed Jan 6, 2025
1 parent 5255984 commit 504642a
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion torchrec/modules/feature_processor_.py
@@ -10,7 +10,7 @@
 #!/usr/bin/env python3

 import abc
-from typing import Dict, Optional
+from typing import Dict, List, Optional

 import torch

@@ -150,6 +150,13 @@ def get_weights_list(
     return torch.cat(weights_list) if weights_list else features.weights_or_none()


+@torch.fx.wrap
+def get_stride_per_key_per_rank(kjt: KeyedJaggedTensor) -> Optional[List[List[int]]]:
+    if not kjt.variable_stride_per_key():
+        return None
+    return kjt.stride_per_key_per_rank()
+
+
 class PositionWeightedModuleCollection(FeatureProcessorsCollection, CopyMixIn):
     def __init__(
         self, max_feature_lengths: Dict[str, int], device: Optional[torch.device] = None
@@ -193,6 +200,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor:
             offsets=features.offsets(),
             stride=features.stride(),
             length_per_key=features.length_per_key(),
+            stride_per_key_per_rank=get_stride_per_key_per_rank(features),
         )

     def copy(self, device: torch.device) -> nn.Module:
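
For readers who want to see what the new helper's guard does without a TorchRec install, here is a minimal sketch. `StubKJT` is a hypothetical stand-in, not the real `KeyedJaggedTensor`: it models only the two methods the helper calls, and its rule for when an input counts as variable-stride is an assumption made for illustration. The helper body mirrors the one in the diff, with the `@torch.fx.wrap` decorator omitted so the snippet has no dependencies.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubKJT:
    """Hypothetical stand-in for KeyedJaggedTensor (illustration only)."""

    per_key_per_rank: Optional[List[List[int]]] = None

    def variable_stride_per_key(self) -> bool:
        # Assumption: an input is variable-stride when it carries
        # per-key, per-rank batch sizes.
        return self.per_key_per_rank is not None

    def stride_per_key_per_rank(self) -> List[List[int]]:
        # Assumption: fall back to an empty list when nothing was provided.
        return self.per_key_per_rank if self.per_key_per_rank is not None else []


def get_stride_per_key_per_rank(kjt: StubKJT) -> Optional[List[List[int]]]:
    # Same guard as the helper in the diff: forward the strides only for
    # variable-stride inputs, otherwise pass None through.
    if not kjt.variable_stride_per_key():
        return None
    return kjt.stride_per_key_per_rank()


fixed_batch = StubKJT()
vbe_batch = StubKJT(per_key_per_rank=[[2, 3], [1, 4]])

print(get_stride_per_key_per_rank(fixed_batch))
print(get_stride_per_key_per_rank(vbe_batch))
```

In the real module, the helper's result is passed as the `stride_per_key_per_rank` argument when `forward` rebuilds the output `KeyedJaggedTensor`, so a fixed-batch input yields `None` for that argument and a variable-stride input keeps its per-key, per-rank strides.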
