diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index efcabb9f1..277f5f4f3 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -426,6 +426,8 @@ def sharded_tbes_weights_spec( # "ebc.tbes.1.1.table_1.weight_qscale":WeightSpec("ebc.embedding_bags.table_1.weight_qscale", [500, 0], [500, 2]) # "ebc.tbes.1.1.table_1.weight_qbias":WeightSpec("ebc.embedding_bags.table_1.weight_qbias", [500, 0], [500, 2]) # } + # In the format of ebc.tbes.i.j.table_k.weight, where i is the index of the TBE, j is the index of the embedding bag within TBE i, k is the index of the original table set in the ebc embedding_configs + # e.g. ebc.tbes.1.1.table_1.weight, it represents second embedding bag within the second TBE. This part of weight is from a shard of table_1 ret: Dict[str, WeightSpec] = {} for module_fqn, module in sharded_model.named_modules():