diff --git a/hls4ml/utils/example_models.py b/hls4ml/utils/example_models.py index 5fefbd662b..657f14325b 100644 --- a/hls4ml/utils/example_models.py +++ b/hls4ml/utils/example_models.py @@ -6,17 +6,18 @@ from .config import create_config +ORGANIZATION = 'fastmachinelearning' +BRANCH = 'master' + def _load_data_config_avai(model_name): """ Check data and configuration availability for each model from this file: - https://github.com/hls-fpga-machine-learning/example-models/blob/master/available_data_config.json + https://github.com/fastmachinelearning/example-models/blob/master/available_data_config.json """ - link_to_list = ( - 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_data_config.json' - ) + link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_data_config.json' temp_file, _ = urlretrieve(link_to_list) @@ -73,12 +74,8 @@ def _load_example_data(model_name): input_file_name = filtered_name + "_input.dat" output_file_name = filtered_name + "_output.dat" - link_to_input = ( - 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + input_file_name - ) - link_to_output = ( - 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + output_file_name - ) + link_to_input = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + input_file_name + link_to_output = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + output_file_name urlretrieve(link_to_input, input_file_name) urlretrieve(link_to_output, output_file_name) @@ -91,9 +88,7 @@ def _load_example_config(model_name): config_name = filtered_name + "_config.yml" - link_to_config = ( - 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/config-files/' + config_name - ) + link_to_config = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/config-files/' + config_name # Load the configuration as dictionary from file urlretrieve(link_to_config, config_name) @@ -110,7 +105,7 @@ def fetch_example_model(model_name, backend='Vivado'): Download an example model (and example data & configuration if available) from github repo to working directory, and return the corresponding configuration: - https://github.com/hls-fpga-machine-learning/example-models + https://github.com/fastmachinelearning/example-models Use fetch_example_list() to see all the available models. @@ -122,8 +117,8 @@ def fetch_example_model(model_name, backend='Vivado'): dict: Dictionary that stores the configuration to the model """ - # Initilize the download link and model type - download_link = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/' + # Initialize the download link and model type + download_link = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/' model_type = None model_config = None @@ -131,6 +126,9 @@ def fetch_example_model(model_name, backend='Vivado'): if '.json' in model_name: model_type = 'keras' model_config = 'KerasJson' + elif '.h5' in model_name: + model_type = 'keras' + model_config = 'KerasH5' elif '.pt' in model_name: model_type = 'pytorch' model_config = 'PytorchModel' @@ -158,11 +156,12 @@ def fetch_example_model(model_name, backend='Vivado'): if _config_is_available(model_name): config = _load_example_config(model_name) + config[model_config] = model_name # Ensure that paths are correct else: config = _create_default_config(model_name, model_config, backend) # If the model is a keras model then have to download its weight file as well - if model_type == 'keras': + if model_type == 'keras' and '.json' in model_name: model_weight_name = model_name[:-5] + "_weights.h5" download_link_weight = download_link + model_type + '/' + model_weight_name @@ -174,7 +173,7 @@ def fetch_example_model(model_name, backend='Vivado'): def fetch_example_list(): - link_to_list = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_models.json' + link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_models.json' temp_file, _ = urlretrieve(link_to_list) diff --git a/test/pytest/test_fetch_example.py b/test/pytest/test_fetch_example.py new file mode 100644 index 0000000000..6e640a94a0 --- /dev/null +++ b/test/pytest/test_fetch_example.py @@ -0,0 +1,32 @@ +import ast +import io +from contextlib import redirect_stdout +from pathlib import Path + +import pytest + +import hls4ml + +test_root_path = Path(__file__).parent + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_fetch_example_utils(backend): + f = io.StringIO() + with redirect_stdout(f): + hls4ml.utils.fetch_example_list() + out = f.getvalue() + + model_list = ast.literal_eval(out) # Check if we indeed got a dictionary back + + assert 'qkeras_mnist_cnn.json' in model_list['keras'] + + # This model has an example config that is also downloaded. Stored configurations don't set "Backend" value. + config = hls4ml.utils.fetch_example_model('qkeras_mnist_cnn.json', backend=backend) + config['KerasJson'] = 'qkeras_mnist_cnn.json' + config['KerasH5'] + config['Backend'] = backend + config['OutputDir'] = str(test_root_path / f'hls4mlprj_fetch_example_{backend}') + + hls_model = hls4ml.converters.keras_to_hls(config) + hls_model.compile() # For now, it is enough if it compiles, we're only testing downloading works as expected