Skip to content

Commit ac9f0a6

Browse files
Wakeupbuddyfacebook-github-bot
authored andcommitted
refactor preproc, support dense in TumHistory layer
Summary: Pull Request resolved: pytorch#11131 Reviewed By: xianjiec Differential Revision: D9358415 fbshipit-source-id: 38bf0e597e22d540d9e985ac8da730f80971d745
1 parent 3e85685 commit ac9f0a6

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

caffe2/python/layers/concat.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,27 @@
1111
)
1212
from future.utils import viewitems
1313
import numpy as np
14+
from collections import defaultdict
1415

1516
import logging
1617
logger = logging.getLogger(__name__)
1718

19+
20+
def get_concatenated_feature_to_index(blobs_to_concat):
21+
concat_feature_to_index = defaultdict(list)
22+
start_pos = 0
23+
for scalar in blobs_to_concat:
24+
num_dims = scalar.dtype.shape[0]
25+
if hasattr(scalar, 'metadata') \
26+
and hasattr(scalar.metadata, 'feature_specs') \
27+
and hasattr(scalar.metadata.feature_specs, 'feature_to_index') \
28+
and isinstance(scalar.metadata.feature_specs.feature_to_index, dict): # noqa B950
29+
for k, v in scalar.metadata.feature_specs.feature_to_index.items():
30+
concat_feature_to_index[k].extend([start_pos + vi for vi in v])
31+
start_pos += num_dims
32+
return dict(concat_feature_to_index) if concat_feature_to_index.keys() else None
33+
34+
1835
class Concat(ModelLayer):
1936
"""
2037
Construct Concat layer
@@ -95,6 +112,19 @@ def __init__(self, model, input_record, axis=1, add_axis=0,
95112
(np.float32, output_dims),
96113
self.get_next_blob_reference('output'))
97114

115+
record_to_concat = input_record.fields.values()
116+
concated_feature_to_index = get_concatenated_feature_to_index(
117+
record_to_concat
118+
)
119+
if concated_feature_to_index:
120+
metadata = schema.Metadata(
121+
feature_specs=schema.FeatureSpec(
122+
feature_to_index=concated_feature_to_index
123+
)
124+
)
125+
self.output_schema.set_metadata(metadata)
126+
127+
98128
def add_ops(self, net):
99129
net.Concat(
100130
self.input_record.field_blobs(),

0 commit comments

Comments
 (0)