Skip to content

Commit 4d99847

Browse files
Add slot_sizes parameter
1 parent c923a27 commit 4d99847

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

merlin/systems/dag/ops/hugectr.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
176176
for layer in model_json["layers"]
177177
if layer["type"] == "DistributedSlotSparseEmbeddingHash"
178178
]
179+
full_slots = [x["sparse_embedding_hparam"]["slot_size_array"] for x in sparse_layers]
179180
num_cat_columns = sum(x["slot_num"] for x in data_layer["sparse"])
180181
vec_size = [x["sparse_embedding_hparam"]["embedding_vec_size"] for x in sparse_layers]
181182

@@ -214,7 +215,7 @@ def export(self, path, input_schema, output_schema, node_id=None, params=None, v
214215
self.hugectr_params["embedding_vector_size"] = vec_size[0]
215216
self.hugectr_params["slots"] = num_cat_columns
216217
self.hugectr_params["label_dim"] = data_layer["label"]["label_dim"]
217-
218+
self.hugectr_params["slot_sizes"] = full_slots
218219
config = _hugectr_config(node_name, self.hugectr_params, max_batch_size=self.max_batch_size)
219220

220221
with open(os.path.join(node_export_path, "config.pbtxt"), "w", encoding="utf-8") as o:

0 commit comments

Comments
 (0)