From f7b59942199dd50090b1f6a64568f89b14b3c5bd Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Fri, 26 Jul 2024 18:35:10 -0700
Subject: [PATCH] change ir_custom_op output to list of tensors (#2246)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2246

# context
* the original implementation of the "ir_custom_op" strategy has a logic flaw:
* it passes the sum of the dims into the op, lets the op return one contiguous tensor, and then splits that tensor back into multiple tensors
* from the dynamic shape (ds) perspective, there is a sum(ds_i) before the op and another split back into (ds_i) after it; the range calculations for these ds are unnecessary and create a lot of complexity
* it's better to keep these ds transparent going into and coming out of the op, i.e. the op takes the per-embedding dims and returns one tensor per dim (see the sketch below)
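
As a quick illustration of the new contract (not part of the diff): a minimal sketch of the call site after this change. The `dims`, `batch_size`, and `arg_list` values below are made up, and it assumes `torchrec.ir.utils` has been imported so that `torchrec::ir_custom_op` is registered.

```python
from typing import List, Optional

import torch
import torchrec.ir.utils  # noqa: F401 -- registers torch.ops.torchrec.ir_custom_op

dims: List[int] = [8, 16]  # hypothetical per-embedding dims (previously collapsed via sum(dims))
batch_size: int = 4
arg_list: List[Optional[torch.Tensor]] = [torch.zeros(12), None, None, None]

# before: the op took sum(dims) and returned a single (batch_size, sum(dims)) tensor
# that had to be split downstream, forcing a sum/split over the dynamic shapes.
# after: the op takes the list of dims and returns one (batch_size, dim) tensor per
# entry, so each dynamic shape passes through the op unchanged.
outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
assert [tuple(t.shape) for t in outputs] == [(4, 8), (4, 16)]

# callers that still need the contiguous layout concatenate along dim=1
values = torch.cat(outputs, dim=1)  # shape (4, 24)
```

This final `torch.cat(outputs, dim=1)` is exactly what `ebc_meta_forward` / `fpebc_meta_forward` now do in the serializer when building the `KeyedTensor` values.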

Differential Revision: D53558783
---
 torchrec/ir/serializer.py | 12 ++++++------
 torchrec/ir/utils.py      | 38 ++++++++++++++++++++------------------
 2 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py
index 3ca5c274b..1c1fb79d2 100644
--- a/torchrec/ir/serializer.py
+++ b/torchrec/ir/serializer.py
@@ -89,17 +89,17 @@ def ebc_meta_forward(
     features: KeyedJaggedTensor,
 ) -> KeyedTensor:
     batch_size = features.stride()
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
 
@@ -110,17 +110,17 @@ def fpebc_meta_forward(
 ) -> KeyedTensor:
     batch_size = features.stride()
     ebc = fpebc._embedding_bag_collection
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
 
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index c7e295d6b..35ed5c8ee 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -28,30 +28,32 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
-def ir_custom_op_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
+def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
+    """
+    Returns the device of the first non-None tensor in the list.
+    """
     for t in tensors:
         if t is not None:
-            device = t.device
-            break
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+            return t.device
+    return None
+
+
+@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
+def ir_custom_op_impl(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
 @torch.library.register_fake("torchrec::ir_custom_op")
 def ir_custom_op_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
-    for t in tensors:
-        if t is not None:
-            device = t.device
-            break
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
 def encapsulate_ir_modules(