Add fused_params to mch sharders (#1649)
Summary:

Adds fused_params to the zch sharders, mainly so that the cache load factor passed through the zch sharder's fused_params can be represented in planner stats.
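A minimal sketch of the delegation this commit introduces (the class names here are hypothetical stand-ins, not the TorchRec API): the managed-collision sharder wraps an embedding sharder as self._e_sharder and now forwards that sharder's fused_params, so a planner reading sharder.fused_params can pick up cache_load_factor:

    from typing import Any, Dict, Optional

    class EmbeddingSharder:
        # Stand-in for the wrapped embedding sharder, which already
        # carries fused_params (e.g. {"cache_load_factor": 0.2}).
        def __init__(self, fused_params: Optional[Dict[str, Any]] = None) -> None:
            self._fused_params = fused_params

        @property
        def fused_params(self) -> Optional[Dict[str, Any]]:
            return self._fused_params

    class ManagedCollisionSharder:
        # Stand-in for the mc sharder; before this commit it exposed
        # no fused_params property of its own.
        def __init__(self, e_sharder: EmbeddingSharder) -> None:
            self._e_sharder = e_sharder

        @property
        def fused_params(self) -> Optional[Dict[str, Any]]:
            # Delegate to the wrapped embedding sharder, mirroring the diff below.
            return self._e_sharder.fused_params

    sharder = ManagedCollisionSharder(EmbeddingSharder({"cache_load_factor": 0.2}))
    clf = (sharder.fused_params or {}).get("cache_load_factor")
    print(f"planner stats can now report cache_load_factor={clf}")  # 0.2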

Reviewed By: dstaay-fb

Differential Revision: D52921362
henrylhtsang authored and facebook-github-bot committed Feb 5, 2024
1 parent 02771aa commit e57f1f0
Showing 1 changed file with 6 additions and 1 deletion.
torchrec/distributed/mc_embedding_modules.py (6 additions, 1 deletion)
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch.autograd.profiler import record_function
@@ -276,3 +276,8 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
                 set(self._mc_sharder.sharding_types(compute_device_type)),
             )
         )
+
+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecated once the planner gets cache_load_factor from ParameterConstraints
+        return self._e_sharder.fused_params
