|
11 | 11 | )
|
12 | 12 | from future.utils import viewitems
|
13 | 13 | import numpy as np
|
| 14 | +from collections import defaultdict |
14 | 15 |
|
15 | 16 | import logging
|
16 | 17 | logger = logging.getLogger(__name__)
|
17 | 18 |
|
| 19 | + |
| 20 | +def get_concatenated_feature_to_index(blobs_to_concat): |
| 21 | + concat_feature_to_index = defaultdict(list) |
| 22 | + start_pos = 0 |
| 23 | + for scalar in blobs_to_concat: |
| 24 | + num_dims = scalar.dtype.shape[0] |
| 25 | + if hasattr(scalar, 'metadata') \ |
| 26 | + and hasattr(scalar.metadata, 'feature_specs') \ |
| 27 | + and hasattr(scalar.metadata.feature_specs, 'feature_to_index') \ |
| 28 | + and isinstance(scalar.metadata.feature_specs.feature_to_index, dict): # noqa B950 |
| 29 | + for k, v in scalar.metadata.feature_specs.feature_to_index.items(): |
| 30 | + concat_feature_to_index[k].extend([start_pos + vi for vi in v]) |
| 31 | + start_pos += num_dims |
| 32 | + return dict(concat_feature_to_index) if concat_feature_to_index.keys() else None |
| 33 | + |
| 34 | + |
18 | 35 | class Concat(ModelLayer):
|
19 | 36 | """
|
20 | 37 | Construct Concat layer
|
@@ -95,6 +112,19 @@ def __init__(self, model, input_record, axis=1, add_axis=0,
|
95 | 112 | (np.float32, output_dims),
|
96 | 113 | self.get_next_blob_reference('output'))
|
97 | 114 |
|
| 115 | + record_to_concat = input_record.fields.values() |
| 116 | + concated_feature_to_index = get_concatenated_feature_to_index( |
| 117 | + record_to_concat |
| 118 | + ) |
| 119 | + if concated_feature_to_index: |
| 120 | + metadata = schema.Metadata( |
| 121 | + feature_specs=schema.FeatureSpec( |
| 122 | + feature_to_index=concated_feature_to_index |
| 123 | + ) |
| 124 | + ) |
| 125 | + self.output_schema.set_metadata(metadata) |
| 126 | + |
| 127 | + |
98 | 128 | def add_ops(self, net):
|
99 | 129 | net.Concat(
|
100 | 130 | self.input_record.field_blobs(),
|
|
0 commit comments