change ir_custom_op output to list of tensors (pytorch#2246)
Summary:
Pull Request resolved: pytorch#2246

# context
* the original implementation of the "ir_custom_op" strategy has a logic flaw:
* it passes the sum of the per-key dims into the op, lets the op return one contiguous tensor, and then splits that tensor back into multiple tensors
* from the dynamic shape (ds) perspective, this puts a sum(ds_i) before the op and another split back into the individual (ds_i) after it; the range calculations for these ds are unnecessary and add a lot of complexity
* it is better to let each ds pass transparently into and out of the op, i.e. the op takes the list of dims and returns one tensor per dim (see the sketch below)
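
As a rough illustration of the two data flows described above, here is a minimal sketch with made-up dims (plain torch, not torchrec code):

import torch

dims = [8, 16, 24]   # per-key output dims; illustrative values only
batch_size = 4

# old strategy: the op receives sum(dims) and returns one contiguous tensor,
# which the caller then splits back into per-key tensors
merged = torch.empty(batch_size, sum(dims))
per_key_old = torch.split(merged, dims, dim=1)

# new strategy: the op receives the list of dims and returns one tensor per key,
# so each dynamic dim passes straight through without a sum/split round trip
per_key_new = [torch.empty(batch_size, d) for d in dims]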

Differential Revision: D53558783
TroyGarden authored and facebook-github-bot committed Jul 27, 2024
1 parent c89e9df commit f7b5994
Showing 2 changed files with 26 additions and 24 deletions.
12 changes: 6 additions & 6 deletions torchrec/ir/serializer.py
@@ -89,17 +89,17 @@ def ebc_meta_forward(
     features: KeyedJaggedTensor,
 ) -> KeyedTensor:
     batch_size = features.stride()
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )

@@ -110,17 +110,17 @@ def fpebc_meta_forward(
 ) -> KeyedTensor:
     batch_size = features.stride()
     ebc = fpebc._embedding_bag_collection
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ] # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )

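To make the new call site concrete, here is a small self-contained sketch (plain torch, no torchrec dependency, with illustrative sizes) showing that concatenating the per-dim outputs along dim=1 reproduces the (batch_size, sum(dims)) layout of the old single-tensor output:

import torch

batch_size, dims = 4, [8, 16, 24]                      # illustrative values, not from the PR
outputs = [torch.empty(batch_size, d) for d in dims]   # shape of what the op now returns
values = torch.cat(outputs, dim=1)                     # what KeyedTensor receives as `values`
assert values.shape == (batch_size, sum(dims))         # identical layout to the old output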
38 changes: 20 additions & 18 deletions torchrec/ir/utils.py
@@ -28,30 +28,32 @@
 logger: logging.Logger = logging.getLogger(__name__)


-@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
-def ir_custom_op_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
+def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
+    """
+    Returns the device of the first non-None tensor in the list.
+    """
     for t in tensors:
         if t is not None:
-            device = t.device
-            break
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+            return t.device
+    return None
+
+
+@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
+def ir_custom_op_impl(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 @torch.library.register_fake("torchrec::ir_custom_op")
 def ir_custom_op_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
-    for t in tensors:
-        if t is not None:
-            device = t.device
-            break
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]


 def encapsulate_ir_modules(
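For reference, here is a standalone sketch of the registration pattern used above (a custom op that returns a list of tensors plus a matching fake kernel). It is registered under a throwaway "demo" namespace to avoid colliding with the real torchrec op, and it assumes a PyTorch version where torch.library.custom_op and torch.library.register_fake are available:

from typing import List, Optional

import torch


@torch.library.custom_op("demo::fan_out", mutates_args={})
def fan_out(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # real kernel: allocate one (batch_size, dim) tensor per requested dim
    device = next((t.device for t in tensors if t is not None), None)
    return [torch.empty(batch_size, dim, device=device) for dim in dims]


@torch.library.register_fake("demo::fan_out")
def fan_out_fake(
    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
) -> List[torch.Tensor]:
    # fake kernel mirrors the real one, so during tracing each output keeps its own
    # (batch_size, dim) shape instead of one summed dimension that must be split later
    device = next((t.device for t in tensors if t is not None), None)
    return [torch.empty(batch_size, dim, device=device) for dim in dims]


outs = torch.ops.demo.fan_out([torch.ones(2, 3)], 2, [8, 16])
print([tuple(o.shape) for o in outs])   # [(2, 8), (2, 16)]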
