Skip to content

Commit

Permalink
Fix for #905 (#906)
Browse files Browse the repository at this point in the history
* fix multi clones w/ diff outs in stream io

* fix test

---------

Co-authored-by: Javier Duarte <[email protected]>
  • Loading branch information
calad0i and jmduarte authored Nov 16, 2023
1 parent a9fc0fc commit 67c39b3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
14 changes: 6 additions & 8 deletions hls4ml/backends/fpga/passes/clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,19 @@ def initialize(self):
class CloneFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(Clone, include_header=clone_include_list)
self.template = None # to be filled once number of clones known

def format(self, node):
params = self._default_function_params(node)
for i, _output in enumerate(node.outputs):
params['output' + str(i + 1)] = node.variables[node.outputs[i]].name

if self.template is None:
self.template = (
'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, '
+ ', '.join(['{output' + str(i + 1) + '}' for i in range(len(node.outputs))])
+ ');'
)
template = (
'nnet::clone_stream<{input_t}, {output_t}, {size}>({input}, '
+ ', '.join(['{output' + str(i + 1) + '}' for i in range(len(node.outputs))])
+ ');'
)

return self.template.format(**params)
return template.format(**params)


def register_clone(backend):
Expand Down
59 changes: 59 additions & 0 deletions test/pytest/test_stream_multi_clone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os
import random
from pathlib import Path

import numpy as np
import pytest
import tensorflow as tf
from keras.layers import Add, Dense
from tensorflow import keras

from hls4ml.converters import convert_from_keras_model

test_root_path = Path(__file__).parent


@pytest.fixture(scope='module')
def model():
seed = 42
os.environ['RANDOM_SEED'] = f'{seed}'
np.random.seed(seed)
tf.random.set_seed(seed)
tf.get_logger().setLevel('ERROR')
random.seed(seed)

inp = keras.Input(shape=(10,))
x = Dense(10)(inp)
y = Dense(10)(inp)
z = Dense(10)(inp)
xy = Add()([x, y]) # 5
xy = Add()([xy, y]) # 5
out = Add()([xy, z]) # 5
model = keras.Model(inp, out)
return model


@pytest.fixture(scope='module')
def data():
rng = np.random.RandomState(42)
X = rng.normal(0, 1, (1000, 10))
X = np.clip(X, -16, 15)
return X


@pytest.mark.parametrize('backend', ['Vivado', 'Quartus', 'Vitis'])
def test_multi_clone(model, data, backend: str):
output_dir = str(test_root_path / f'hls4mlprj_stream_multi_clone_{backend}')
hls_config = {'Model': {'Precision': 'fixed<32,5>', 'ReuseFactor': 1}}
model_hls = convert_from_keras_model(
model,
backend=backend,
output_dir=output_dir,
hls_config=hls_config,
io_type='io_stream', # clone only happens with stream io.
)
model_hls.compile()
r_hls = model_hls.predict(data)
r_keras = model(data).numpy()

assert np.allclose(r_hls, r_keras, atol=1e-5, rtol=0)

0 comments on commit 67c39b3

Please sign in to comment.