Skip to content

Commit 3b3f1d7

Browse files
Update _hugectr_config to be more concise
1 parent 8cbaf90 commit 3b3f1d7

File tree

1 file changed

+25
-56
lines changed

1 file changed

+25
-56
lines changed

merlin/systems/dag/ops/hugectr.py

Lines changed: 25 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import os
1818
import pathlib
19+
from typing import Optional
1920

2021
import numpy as np
2122
import tritonclient.grpc.model_config_pb2 as model_config
@@ -222,7 +223,9 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1):
222223
return config
223224

224225

225-
def _hugectr_config(name, hugectr_params, max_batch_size=None):
226+
def _hugectr_config(
227+
name: str, parameters: dict, max_batch_size: Optional[int] = None
228+
) -> model_config.ModelConfig:
226229
"""Create a config for a HugeCTR model.
227230
228231
Parameters
@@ -239,63 +242,29 @@ def _hugectr_config(name, hugectr_params, max_batch_size=None):
239242
config
240243
Dictionary representation of hugectr config.
241244
"""
242-
config = model_config.ModelConfig(name=name, backend="hugectr", max_batch_size=max_batch_size)
243-
244-
config.input.append(
245-
model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1])
246-
)
247-
248-
config.input.append(
249-
model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1])
250-
)
251-
252-
config.input.append(
253-
model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1])
245+
config = model_config.ModelConfig(
246+
name=name,
247+
backend="hugectr",
248+
max_batch_size=max_batch_size,
249+
input=[
250+
model_config.ModelInput(name="DES", data_type=model_config.TYPE_FP32, dims=[-1]),
251+
model_config.ModelInput(name="CATCOLUMN", data_type=model_config.TYPE_INT64, dims=[-1]),
252+
model_config.ModelInput(name="ROWINDEX", data_type=model_config.TYPE_INT32, dims=[-1]),
253+
],
254+
output=[
255+
model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1])
256+
],
257+
instance_group=[model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1)],
254258
)
255259

256-
config.output.append(
257-
model_config.ModelOutput(name="OUTPUT0", data_type=model_config.TYPE_FP32, dims=[-1])
258-
)
259-
260-
config.instance_group.append(model_config.ModelInstanceGroup(gpus=[0], count=1, kind=1))
261-
262-
config_hugectr = model_config.ModelParameter(string_value=hugectr_params["config"])
263-
config.parameters["config"].CopyFrom(config_hugectr)
264-
265-
gpucache_val = hugectr_params["gpucache"]
266-
gpucache = model_config.ModelParameter(string_value=gpucache_val)
267-
config.parameters["gpucache"].CopyFrom(gpucache)
268-
269-
gpucacheper_val = str(hugectr_params["gpucacheper"])
270-
gpucacheper = model_config.ModelParameter(string_value=gpucacheper_val)
271-
config.parameters["gpucacheper"].CopyFrom(gpucacheper)
272-
273-
label_dim = model_config.ModelParameter(string_value=str(hugectr_params["label_dim"]))
274-
config.parameters["label_dim"].CopyFrom(label_dim)
275-
276-
slots = model_config.ModelParameter(string_value=str(hugectr_params["slots"]))
277-
config.parameters["slots"].CopyFrom(slots)
278-
279-
des_feature_num = model_config.ModelParameter(
280-
string_value=str(hugectr_params["des_feature_num"])
281-
)
282-
config.parameters["des_feature_num"].CopyFrom(des_feature_num)
283-
284-
cat_feature_num = model_config.ModelParameter(
285-
string_value=str(hugectr_params["cat_feature_num"])
286-
)
287-
config.parameters["cat_feature_num"].CopyFrom(cat_feature_num)
288-
289-
max_nnz = model_config.ModelParameter(string_value=str(hugectr_params["max_nnz"]))
290-
config.parameters["max_nnz"].CopyFrom(max_nnz)
291-
292-
embedding_vector_size = model_config.ModelParameter(
293-
string_value=str(hugectr_params["embedding_vector_size"])
294-
)
295-
config.parameters["embedding_vector_size"].CopyFrom(embedding_vector_size)
260+
for parameter_key, parameter_value in parameters.items():
261+
if parameter_value is None:
262+
continue
296263

297-
embeddingkey_long_type_val = hugectr_params["embeddingkey_long_type"]
298-
embeddingkey_long_type = model_config.ModelParameter(string_value=embeddingkey_long_type_val)
299-
config.parameters["embeddingkey_long_type"].CopyFrom(embeddingkey_long_type)
264+
if isinstance(parameter_value, list):
265+
config.parameters[parameter_key].string_value = json.dumps(parameter_value)
266+
elif isinstance(parameter_value, bool):
267+
config.parameters[parameter_key].string_value = str(parameter_value).lower()
268+
config.parameters[parameter_key].string_value = str(parameter_value)
300269

301270
return config

0 commit comments

Comments
 (0)