From 44d3cabba1fb64f35c566b57e21701be5d3fc93f Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Sat, 2 Dec 2023 17:30:06 -0800 Subject: [PATCH 1/4] better repalce_node fn --- hls4ml/model/graph.py | 25 +++++++---- test/pytest/test_repack_precision.py | 27 ------------ test/pytest/test_repack_stream.py | 62 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 34 deletions(-) delete mode 100644 test/pytest/test_repack_precision.py create mode 100644 test/pytest/test_repack_stream.py diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index c44fd8f02e..8d55568f75 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -573,13 +573,24 @@ def replace_node(self, old_node, new_node): new_node (Layer): The new node """ - prev_node = self.graph.get(old_node.inputs[0]) - next_node = next((x for x in self.graph.values() if x.inputs[0] == old_node.outputs[0]), None) - if next_node is not None: - next_node.inputs[0] = new_node.outputs[0] - if prev_node is not None: - if new_node.inputs is None or len(new_node.inputs) == 0: # Check if already rewired - new_node.inputs = [prev_node.outputs[0]] + + assert len(new_node.inputs) == len( + old_node.inputs + ), f'{new_node.name} and {old_node.name} have different number of inputs' + assert len(new_node.outputs) == len( + old_node.outputs + ), f'{new_node.name} and {old_node.name} have different number of outputs' + + repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)} + repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)}) + + for node in self.graph.values(): + for i, n in enumerate(node.inputs): + if n in repl: + node.inputs[i] = repl[n] + for i, n in enumerate(node.outputs): + if n in repl: + node.outputs[i] = repl[n] self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) self._update_model_outputs() diff --git a/test/pytest/test_repack_precision.py b/test/pytest/test_repack_precision.py deleted file mode 100644 index 9ac2fd97f9..0000000000 --- a/test/pytest/test_repack_precision.py +++ /dev/null @@ -1,27 +0,0 @@ -from tensorflow import keras - -from hls4ml.converters import convert_from_keras_model - - -def test_repack_precision(): - inp = keras.Input(shape=(3, 3), name='inp') - out = keras.layers.Reshape((3, 3), name='reshape')(inp) - out = keras.layers.Conv1D(2, 2, name='conv')(out) - model = keras.Model(inp, out) - - layer_conf = { - 'inp': {'Precision': 'fixed<20,10>'}, - 'reshape': {'Precision': 'fixed<20,10>'}, - 'conv': {'Precision': 'fixed<20,10>'}, - } - - hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf} - - # Repack only happens in io_stream - model_hls = convert_from_keras_model(model, hls_config=hls_config, io_type='io_stream') - assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph' - repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision - assert repack_precision.integer == 10, 'Precision mismatch' - assert repack_precision.fractional == 10, 'Precision mismatch' - assert repack_precision.width == 20, 'Precision mismatch' - assert repack_precision.signed is True, 'Precision mismatch' diff --git a/test/pytest/test_repack_stream.py b/test/pytest/test_repack_stream.py new file mode 100644 index 0000000000..b239d266c7 --- /dev/null +++ b/test/pytest/test_repack_stream.py @@ -0,0 +1,62 @@ +from pathlib import Path + +import numpy as np +import pytest +from tensorflow import keras + +from hls4ml.converters import convert_from_keras_model + +# test_root_path = Path(__file__).parent +test_root_path = Path('/tmp') + + +def test_repack_precision(): + inp = keras.Input(shape=(3, 3), name='inp') + out = keras.layers.Reshape((3, 3), name='reshape')(inp) + out = keras.layers.Conv1D(2, 2, name='conv')(out) + model = keras.Model(inp, out) + + layer_conf = { + 'inp': {'Precision': 'fixed<20,10>'}, + 'reshape': {'Precision': 'fixed<20,10>'}, + 'conv': {'Precision': 'fixed<20,10>'}, + } + + hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf} + + # Repack only happens in io_stream + model_hls = convert_from_keras_model(model, hls_config=hls_config, io_type='io_stream') + assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph' + repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision + assert repack_precision.integer == 10, 'Precision mismatch' + assert repack_precision.fractional == 10, 'Precision mismatch' + assert repack_precision.width == 20, 'Precision mismatch' + assert repack_precision.signed is True, 'Precision mismatch' + + +@pytest.mark.parametrize('backend', ['vivado', 'vitis', 'quartus']) +@pytest.mark.parametrize('strategy', ['Latency', 'Resource']) +def test_repack(backend: str, strategy: str): + inp1 = keras.Input(shape=(4,), name='inp1') + inp2 = keras.Input(shape=(4,), name='inp2') + r1 = keras.layers.Reshape((2, 2), name='reshape1')(inp1) + r2 = keras.layers.Reshape((2, 2), name='reshape2')(inp2) + out = keras.layers.Concatenate(name='concat')([r1, r2]) + model = keras.Model([inp1, inp2], out) + + hls_config = {'Model': {'Precision': 'ap_ufixed<8,8>', 'ReuseFactor': 1}, 'Strategy': strategy} + model_hls = convert_from_keras_model( + model, + io_type='io_stream', + backend='Vivado', + hls_config=hls_config, + output_dir=str(test_root_path / f'{backend}_{strategy}'), + ) + model_hls.compile() + inp_data = [ + np.random.randint(0, 2**8, (100, 4)).astype(np.float32), + np.random.randint(0, 2**8, (100, 4)).astype(np.float32), + ] + out_target = np.concatenate([inp_data[0].reshape(100, 2, 2), inp_data[1].reshape(100, 2, 2)], axis=-1) + out_data: np.ndarray = model_hls.predict(inp_data) # type: ignore + assert np.all(out_data.reshape(out_target.shape) == out_target), 'Concatenate failed: mismatching output' From b821fa7ad34e1d42eac50aee5dc3a957fb4e6c6e Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Sat, 2 Dec 2023 17:46:00 -0800 Subject: [PATCH 2/4] undo some autoformatting --- hls4ml/model/graph.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 8d55568f75..f4614a3582 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -574,12 +574,12 @@ def replace_node(self, old_node, new_node): """ - assert len(new_node.inputs) == len( - old_node.inputs - ), f'{new_node.name} and {old_node.name} have different number of inputs' - assert len(new_node.outputs) == len( - old_node.outputs - ), f'{new_node.name} and {old_node.name} have different number of outputs' + # fmt: off + assert len(new_node.inputs) == len(old_node.inputs), \ + f'{new_node.name} and {old_node.name} have different number of inputs' + assert len(new_node.outputs) == len(old_node.outputs), \ + f'{new_node.name} and {old_node.name} have different number of outputs' + # fmt: on repl = {old_name: new_name for old_name, new_name in zip(old_node.outputs, new_node.outputs)} repl.update({old_name: new_name for old_name, new_name in zip(old_node.inputs, new_node.inputs)}) @@ -655,7 +655,9 @@ def compile(self): Users should call this function if they want to use `predict` functionality for simulation. """ self.write() + self._compile() + def _compile(self): lib_name = self.config.backend.compile(self) if self._top_function_lib is not None: if platform.system() == "Linux": From 4f53d06a0320b42997eb88a7c89cd6348cbb6525 Mon Sep 17 00:00:00 2001 From: Chang Sun Date: Fri, 15 Dec 2023 01:18:17 +0100 Subject: [PATCH 3/4] fix test path --- test/pytest/test_repack_stream.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/pytest/test_repack_stream.py b/test/pytest/test_repack_stream.py index b239d266c7..2a7efaf189 100644 --- a/test/pytest/test_repack_stream.py +++ b/test/pytest/test_repack_stream.py @@ -6,8 +6,7 @@ from hls4ml.converters import convert_from_keras_model -# test_root_path = Path(__file__).parent -test_root_path = Path('/tmp') +test_root_path = Path(__file__).parent def test_repack_precision(): From 40d54617123c1791cb4ab8fe96069c8f6edd201b Mon Sep 17 00:00:00 2001 From: Vladimir Loncar Date: Tue, 19 Dec 2023 17:34:15 +0100 Subject: [PATCH 4/4] Fix output dirs in repack tests --- test/pytest/test_repack_stream.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/pytest/test_repack_stream.py b/test/pytest/test_repack_stream.py index 2a7efaf189..12d44a66b7 100644 --- a/test/pytest/test_repack_stream.py +++ b/test/pytest/test_repack_stream.py @@ -9,7 +9,8 @@ test_root_path = Path(__file__).parent -def test_repack_precision(): +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_repack_precision(backend: str): inp = keras.Input(shape=(3, 3), name='inp') out = keras.layers.Reshape((3, 3), name='reshape')(inp) out = keras.layers.Conv1D(2, 2, name='conv')(out) @@ -24,7 +25,14 @@ def test_repack_precision(): hls_config = {'Model': {'Precision': 'fixed<2,1>', 'ReuseFactor': 1}, 'LayerName': layer_conf} # Repack only happens in io_stream - model_hls = convert_from_keras_model(model, hls_config=hls_config, io_type='io_stream') + model_hls = convert_from_keras_model( + model, + backend=backend, + output_dir=str(test_root_path / f'hls4mlprj_repack_precision_{backend}'), + hls_config=hls_config, + io_type='io_stream', + ) + model_hls.write() # Not needed for this test, but useful for debugging assert 'repack_reshape' in model_hls.graph, 'repack_reshape not found in graph' repack_precision = model_hls.graph['repack_reshape'].attributes['result_t'].precision assert repack_precision.integer == 10, 'Precision mismatch' @@ -33,7 +41,7 @@ def test_repack_precision(): assert repack_precision.signed is True, 'Precision mismatch' -@pytest.mark.parametrize('backend', ['vivado', 'vitis', 'quartus']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) @pytest.mark.parametrize('strategy', ['Latency', 'Resource']) def test_repack(backend: str, strategy: str): inp1 = keras.Input(shape=(4,), name='inp1') @@ -47,9 +55,9 @@ def test_repack(backend: str, strategy: str): model_hls = convert_from_keras_model( model, io_type='io_stream', - backend='Vivado', + backend=backend, hls_config=hls_config, - output_dir=str(test_root_path / f'{backend}_{strategy}'), + output_dir=str(test_root_path / f'hls4mlprj_repack_{backend}_{strategy}'), ) model_hls.compile() inp_data = [