16
16
import json
17
17
import os
18
18
import pathlib
19
+ from typing import Optional
19
20
20
21
import numpy as np
21
22
import tritonclient .grpc .model_config_pb2 as model_config
@@ -222,7 +223,9 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1):
222
223
return config
223
224
224
225
225
- def _hugectr_config (name , hugectr_params , max_batch_size = None ):
226
+ def _hugectr_config (
227
+ name : str , parameters : dict , max_batch_size : Optional [int ] = None
228
+ ) -> model_config .ModelConfig :
226
229
"""Create a config for a HugeCTR model.
227
230
228
231
Parameters
@@ -239,63 +242,29 @@ def _hugectr_config(name, hugectr_params, max_batch_size=None):
239
242
config
240
243
Dictionary representation of hugectr config.
241
244
"""
242
- config = model_config .ModelConfig (name = name , backend = "hugectr" , max_batch_size = max_batch_size )
243
-
244
- config .input .append (
245
- model_config .ModelInput (name = "DES" , data_type = model_config .TYPE_FP32 , dims = [- 1 ])
246
- )
247
-
248
- config .input .append (
249
- model_config .ModelInput (name = "CATCOLUMN" , data_type = model_config .TYPE_INT64 , dims = [- 1 ])
250
- )
251
-
252
- config .input .append (
253
- model_config .ModelInput (name = "ROWINDEX" , data_type = model_config .TYPE_INT32 , dims = [- 1 ])
245
+ config = model_config .ModelConfig (
246
+ name = name ,
247
+ backend = "hugectr" ,
248
+ max_batch_size = max_batch_size ,
249
+ input = [
250
+ model_config .ModelInput (name = "DES" , data_type = model_config .TYPE_FP32 , dims = [- 1 ]),
251
+ model_config .ModelInput (name = "CATCOLUMN" , data_type = model_config .TYPE_INT64 , dims = [- 1 ]),
252
+ model_config .ModelInput (name = "ROWINDEX" , data_type = model_config .TYPE_INT32 , dims = [- 1 ]),
253
+ ],
254
+ output = [
255
+ model_config .ModelOutput (name = "OUTPUT0" , data_type = model_config .TYPE_FP32 , dims = [- 1 ])
256
+ ],
257
+ instance_group = [model_config .ModelInstanceGroup (gpus = [0 ], count = 1 , kind = 1 )],
254
258
)
255
259
256
- config .output .append (
257
- model_config .ModelOutput (name = "OUTPUT0" , data_type = model_config .TYPE_FP32 , dims = [- 1 ])
258
- )
259
-
260
- config .instance_group .append (model_config .ModelInstanceGroup (gpus = [0 ], count = 1 , kind = 1 ))
261
-
262
- config_hugectr = model_config .ModelParameter (string_value = hugectr_params ["config" ])
263
- config .parameters ["config" ].CopyFrom (config_hugectr )
264
-
265
- gpucache_val = hugectr_params ["gpucache" ]
266
- gpucache = model_config .ModelParameter (string_value = gpucache_val )
267
- config .parameters ["gpucache" ].CopyFrom (gpucache )
268
-
269
- gpucacheper_val = str (hugectr_params ["gpucacheper" ])
270
- gpucacheper = model_config .ModelParameter (string_value = gpucacheper_val )
271
- config .parameters ["gpucacheper" ].CopyFrom (gpucacheper )
272
-
273
- label_dim = model_config .ModelParameter (string_value = str (hugectr_params ["label_dim" ]))
274
- config .parameters ["label_dim" ].CopyFrom (label_dim )
275
-
276
- slots = model_config .ModelParameter (string_value = str (hugectr_params ["slots" ]))
277
- config .parameters ["slots" ].CopyFrom (slots )
278
-
279
- des_feature_num = model_config .ModelParameter (
280
- string_value = str (hugectr_params ["des_feature_num" ])
281
- )
282
- config .parameters ["des_feature_num" ].CopyFrom (des_feature_num )
283
-
284
- cat_feature_num = model_config .ModelParameter (
285
- string_value = str (hugectr_params ["cat_feature_num" ])
286
- )
287
- config .parameters ["cat_feature_num" ].CopyFrom (cat_feature_num )
288
-
289
- max_nnz = model_config .ModelParameter (string_value = str (hugectr_params ["max_nnz" ]))
290
- config .parameters ["max_nnz" ].CopyFrom (max_nnz )
291
-
292
- embedding_vector_size = model_config .ModelParameter (
293
- string_value = str (hugectr_params ["embedding_vector_size" ])
294
- )
295
- config .parameters ["embedding_vector_size" ].CopyFrom (embedding_vector_size )
260
+ for parameter_key , parameter_value in parameters .items ():
261
+ if parameter_value is None :
262
+ continue
296
263
297
- embeddingkey_long_type_val = hugectr_params ["embeddingkey_long_type" ]
298
- embeddingkey_long_type = model_config .ModelParameter (string_value = embeddingkey_long_type_val )
299
- config .parameters ["embeddingkey_long_type" ].CopyFrom (embeddingkey_long_type )
264
+ if isinstance (parameter_value , list ):
265
+ config .parameters [parameter_key ].string_value = json .dumps (parameter_value )
266
+ elif isinstance (parameter_value , bool ):
267
+ config .parameters [parameter_key ].string_value = str (parameter_value ).lower ()
268
+ config .parameters [parameter_key ].string_value = str (parameter_value )
300
269
301
270
return config
0 commit comments