Skip to content

Commit

Permalink
Added ability to set decoder outputs via config for opennmt models (o…
Browse files Browse the repository at this point in the history
  • Loading branch information
Anna Grebneva authored Feb 21, 2022
1 parent 0590913 commit 2483e0c
Showing 1 changed file with 15 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _process(self, output_callback, calculate_metrics, progress_reporter, metric
metrics_result = self._get_metrics_result(batch_input_ids, batch_annotation, batch_prediction,
calculate_metrics)
if output_callback:
output_callback(batch_raw_prediction[0], metrics_result=metrics_result,
output_callback(batch_raw_prediction, metrics_result=metrics_result,
element_identifiers=batch_identifiers, dataset_indices=batch_input_ids)
self._update_progress(progress_reporter, metric_config, batch_id, len(batch_prediction), csv_file)

Expand Down Expand Up @@ -427,6 +427,13 @@ class DecoderDLSDKModel(CommonOpenNMTDecoder, CommonDLSDKModel):
state_inputs = ['h_0', 'c_0', 'memory', 'mem_len', 'input_feed.1']
state_outputs = ['h_1', 'c_1', '', '', 'input_feed']

def __init__(self, network_info, launcher, suffix, delayed_model_loading=False):
if network_info.get('outputs'):
self.output_layers = network_info['outputs']
if network_info.get('return_outputs'):
self.return_layers = network_info['return_outputs']
super().__init__(network_info, launcher, suffix, delayed_model_loading)


class GeneratorDLSDKModel(CommonDLSDKModel):
default_model_suffix = 'generator'
Expand All @@ -450,6 +457,13 @@ class DecoderOVModel(CommonOpenNMTDecoder, CommonOVModel):
state_inputs = ['h_0', 'c_0', 'memory', 'mem_len', 'input_feed.1']
state_outputs = ['h_1/sink_port_0', 'c_1/sink_port_0', '', '', 'input_feed/sink_port_0']

def __init__(self, network_info, launcher, suffix, delayed_model_loading=False):
if network_info.get('outputs'):
self.output_layers = network_info['outputs']
if network_info.get('return_outputs'):
self.return_layers = network_info['return_outputs']
super().__init__(network_info, launcher, suffix, delayed_model_loading)


class GeneratorOVModel(CommonOVModel):
default_model_suffix = 'generator'
Expand Down

0 comments on commit 2483e0c

Please sign in to comment.