Skip to content

Commit

Permalink
Merge pull request #934 from calad0i/replace_node_improvment
Browse files Browse the repository at this point in the history
better repalce_node fn
  • Loading branch information
jmitrevs authored Dec 19, 2023
2 parents 7916ff5 + 40d5461 commit 93e759c
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 34 deletions.
27 changes: 20 additions & 7 deletions hls4ml/model/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,13 +577,24 @@ def replace_node(self, old_node, new_node):
new_node (Layer): The new node
"""
prev_node = self.graph.get(old_node.inputs[0])
next_node = next((x for x in self.graph.values() if x.inputs[0] == old_node.outputs[0]), None)
if next_node is not None:
next_node.inputs[0] = new_node.outputs[0]
if prev_node is not None:
if new_node.inputs is None or len(new_node.inputs) == 0: # Check if already rewired
new_node.inputs = [prev_node.outputs[0]]

# fmt: off
assert len(new_node.inputs) == len(old_node.inputs), \
f'{new_node.name} and {old_node.name} have different number of inputs'
assert len(new_node.outputs) == len(old_node.outputs), \
f'{new_node.name} and {old_node.name} have different number of outputs'
# fmt: on

repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)}
repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)})

for node in self.graph.values():
for i, n in enumerate(node.inputs):
if n in repl:
node.inputs[i] = repl[n]
for i, n in enumerate(node.outputs):
if n in repl:
node.outputs[i] = repl[n]

self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items())
self._update_model_outputs()
Expand Down Expand Up @@ -648,7 +659,9 @@ def compile(self):
Users should call this function if they want to use `predict` functionality for simulation.
"""
self.write()
self._compile()

def _compile(self):
lib_name = self.config.backend.compile(self)
if self._top_function_lib is not None:
if platform.system() == "Linux":
Expand Down
27 changes: 0 additions & 27 deletions test/pytest/test_repack_precision.py

This file was deleted.

69 changes: 69 additions & 0 deletions test/pytest/test_repack_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from pathlib import Path

import numpy as np
import pytest
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

test_root_path = Path(__file__).parent


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
def test_repack_precision(backend: str):
inp = keras.Input(shape=(3, 3), name='inp')
out = keras.layers.Reshape((3, 3), name='reshape')(inp)
out = keras.layers.Conv1D(2, 2, name='conv')(out)
model = keras.Model(inp, out)

layer_conf = {
'inp': {'Precision': 'fixed<20,10>'},
'reshape': {'Precision': 'fixed<20,10>'},
'conv': {'Precision': 'fixed<20,10>'},
}

hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf}

# Repack only happens in io_stream
model_hls = convert_from_keras_model(
model,
backend=backend,
output_dir=str(test_root_path / f'hls4mlprj_repack_precision_{backend}'),
hls_config=hls_config,
io_type='io_stream',
)
model_hls.write() # Not needed for this test, but useful for debugging
assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph'
repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision
assert repack_precision.integer == 10, 'Precision mismatch'
assert repack_precision.fractional == 10, 'Precision mismatch'
assert repack_precision.width == 20, 'Precision mismatch'
assert repack_precision.signed is True, 'Precision mismatch'


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
@pytest.mark.parametrize('strategy', ['Latency', 'Resource'])
def test_repack(backend: str, strategy: str):
inp1 = keras.Input(shape=(4,), name='inp1')
inp2 = keras.Input(shape=(4,), name='inp2')
r1 = keras.layers.Reshape((2, 2), name='reshape1')(inp1)
r2 = keras.layers.Reshape((2, 2), name='reshape2')(inp2)
out = keras.layers.Concatenate(name='concat')([r1, r2])
model = keras.Model([inp1, inp2], out)

hls_config = {'Model': {'Precision': 'ap_ufixed<8,8>', 'ReuseFactor': 1}, 'Strategy': strategy}
model_hls = convert_from_keras_model(
model,
io_type='io_stream',
backend=backend,
hls_config=hls_config,
output_dir=str(test_root_path / f'hls4mlprj_repack_{backend}_{strategy}'),
)
model_hls.compile()
inp_data = [
np.random.randint(0, 2**8, (100, 4)).astype(np.float32),
np.random.randint(0, 2**8, (100, 4)).astype(np.float32),
]
out_target = np.concatenate([inp_data[0].reshape(100, 2, 2), inp_data[1].reshape(100, 2, 2)], axis=-1)
out_data: np.ndarray = model_hls.predict(inp_data) # type: ignore
assert np.all(out_data.reshape(out_target.shape) == out_target), 'Concatenate failed: mismatching output'

0 comments on commit 93e759c

Please sign in to comment.