diff --git a/torchrec/ir/serializer.py b/torchrec/ir/serializer.py
index 3ca5c274b..1c1fb79d2 100644
--- a/torchrec/ir/serializer.py
+++ b/torchrec/ir/serializer.py
@@ -89,17 +89,17 @@ def ebc_meta_forward(
     features: KeyedJaggedTensor,
 ) -> KeyedTensor:
     batch_size = features.stride()
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
 
@@ -110,17 +110,17 @@ def fpebc_meta_forward(
 ) -> KeyedTensor:
     batch_size = features.stride()
     ebc = fpebc._embedding_bag_collection
-    dim = sum(ebc._lengths_per_embedding)
+    dims = ebc._lengths_per_embedding
     arg_list = [
         features.values(),
         features.weights_or_none(),
         features.lengths_or_none(),
         features.offsets_or_none(),
     ]  # if want to include the weights: `+ [bag.weight for bag in self.embedding_bags.values()]`
-    output = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dim)
+    outputs = torch.ops.torchrec.ir_custom_op(arg_list, batch_size, dims)
     return KeyedTensor(
         keys=ebc._embedding_names,
-        values=output,
+        values=torch.cat(outputs, dim=1),
         length_per_key=ebc._lengths_per_embedding,
     )
 
diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index c7e295d6b..35ed5c8ee 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -28,30 +28,32 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
-def ir_custom_op_impl(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
+def get_device(tensors: List[Optional[torch.Tensor]]) -> Optional[torch.device]:
+    """
+    Returns the device of the first non-None tensor in the list.
+    """
     for t in tensors:
         if t is not None:
-            device = t.device
-            break
-    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+            return t.device
+    return None
+
+
+@torch.library.custom_op("torchrec::ir_custom_op", mutates_args={})
+def ir_custom_op_impl(
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
 @torch.library.register_fake("torchrec::ir_custom_op")
 def ir_custom_op_fake(
-    tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int
-) -> torch.Tensor:
-    device = None
-    for t in tensors:
-        if t is not None:
-            device = t.device
-            break
-    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}")
-    return torch.empty(batch_size, dim, device=device)
+    tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
+) -> List[torch.Tensor]:
+    device = get_device(tensors)
+    logger.info(f"ir_custom_op_fake -> ({batch_size}, {dims}) {device}")
+    return [torch.empty(batch_size, dim, device=device) for dim in dims]
 
 
 def encapsulate_ir_modules(
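
Note (editor's sketch, not part of the patch): this change alters the op's contract from returning a single pre-flattened (batch_size, sum(dims)) tensor to returning one (batch_size, dim) tensor per embedding table, which the meta-forward callers then torch.cat(..., dim=1) back into the KeyedTensor values. A minimal, self-contained illustration of the new shape contract follows; the fake_ir_custom_op name and the sample sizes are hypothetical stand-ins mirroring ir_custom_op_fake above:

    from typing import List, Optional

    import torch

    def fake_ir_custom_op(
        tensors: List[Optional[torch.Tensor]], batch_size: int, dims: List[int]
    ) -> List[torch.Tensor]:
        # Same behavior as ir_custom_op_fake in the patch: one empty
        # (batch_size, dim) placeholder per embedding table, allocated on the
        # device of the first non-None input tensor.
        device = next((t.device for t in tensors if t is not None), None)
        return [torch.empty(batch_size, dim, device=device) for dim in dims]

    # Two tables with output dims 16 and 32, batch size 4.
    outputs = fake_ir_custom_op([torch.zeros(8), None], batch_size=4, dims=[16, 32])
    values = torch.cat(outputs, dim=1)
    assert values.shape == (4, 48)  # sum(dims) == 48, matching length_per_key

Returning per-table tensors keeps each table's output dimension explicit in the exported graph rather than collapsing them into a single summed dim, which is likely what keeps length_per_key consistent with the op's outputs on deserialization.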