Skip to content

Commit

Permalink
Merge pull request #919 from vloncar/fetch_example
Browse files Browse the repository at this point in the history
Fix fetching models from example-models repo
  • Loading branch information
jmitrevs authored Nov 15, 2023
2 parents 2cd8333 + 23e73ef commit 6a92562
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 18 deletions.
35 changes: 17 additions & 18 deletions hls4ml/utils/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@

from .config import create_config

ORGANIZATION = 'fastmachinelearning'
BRANCH = 'master'


def _load_data_config_avai(model_name):
"""
Check data and configuration availability for each model from this file:
https://github.com/hls-fpga-machine-learning/example-models/blob/master/available_data_config.json
https://github.com/fastmachinelearning/example-models/blob/master/available_data_config.json
"""

link_to_list = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_data_config.json'
)
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_data_config.json'

temp_file, _ = urlretrieve(link_to_list)

Expand Down Expand Up @@ -73,12 +74,8 @@ def _load_example_data(model_name):
input_file_name = filtered_name + "_input.dat"
output_file_name = filtered_name + "_output.dat"

link_to_input = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + input_file_name
)
link_to_output = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/data/' + output_file_name
)
link_to_input = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + input_file_name
link_to_output = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/data/' + output_file_name

urlretrieve(link_to_input, input_file_name)
urlretrieve(link_to_output, output_file_name)
Expand All @@ -91,9 +88,7 @@ def _load_example_config(model_name):

config_name = filtered_name + "_config.yml"

link_to_config = (
'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/config-files/' + config_name
)
link_to_config = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/config-files/' + config_name

# Load the configuration as dictionary from file
urlretrieve(link_to_config, config_name)
Expand All @@ -110,7 +105,7 @@ def fetch_example_model(model_name, backend='Vivado'):
Download an example model (and example data & configuration if available) from github repo to working directory,
and return the corresponding configuration:
https://github.com/hls-fpga-machine-learning/example-models
https://github.com/fastmachinelearning/example-models
Use fetch_example_list() to see all the available models.
Expand All @@ -122,15 +117,18 @@ def fetch_example_model(model_name, backend='Vivado'):
dict: Dictionary that stores the configuration to the model
"""

# Initilize the download link and model type
download_link = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/'
# Initialize the download link and model type
download_link = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/'
model_type = None
model_config = None

# Check for model's type to update link
if '.json' in model_name:
model_type = 'keras'
model_config = 'KerasJson'
elif '.h5' in model_name:
model_type = 'keras'
model_config = 'KerasH5'
elif '.pt' in model_name:
model_type = 'pytorch'
model_config = 'PytorchModel'
Expand Down Expand Up @@ -158,11 +156,12 @@ def fetch_example_model(model_name, backend='Vivado'):

if _config_is_available(model_name):
config = _load_example_config(model_name)
config[model_config] = model_name # Ensure that paths are correct
else:
config = _create_default_config(model_name, model_config, backend)

# If the model is a keras model then have to download its weight file as well
if model_type == 'keras':
if model_type == 'keras' and '.json' in model_name:
model_weight_name = model_name[:-5] + "_weights.h5"

download_link_weight = download_link + model_type + '/' + model_weight_name
Expand All @@ -174,7 +173,7 @@ def fetch_example_model(model_name, backend='Vivado'):


def fetch_example_list():
link_to_list = 'https://raw.githubusercontent.com/hls-fpga-machine-learning/example-models/master/available_models.json'
link_to_list = f'https://raw.githubusercontent.com/{ORGANIZATION}/example-models/{BRANCH}/available_models.json'

temp_file, _ = urlretrieve(link_to_list)

Expand Down
32 changes: 32 additions & 0 deletions test/pytest/test_fetch_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import ast
import io
from contextlib import redirect_stdout
from pathlib import Path

import pytest

import hls4ml

test_root_path = Path(__file__).parent


@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
def test_fetch_example_utils(backend):
f = io.StringIO()
with redirect_stdout(f):
hls4ml.utils.fetch_example_list()
out = f.getvalue()

model_list = ast.literal_eval(out) # Check if we indeed got a dictionary back

assert 'qkeras_mnist_cnn.json' in model_list['keras']

# This model has an example config that is also downloaded. Stored configurations don't set "Backend" value.
config = hls4ml.utils.fetch_example_model('qkeras_mnist_cnn.json', backend=backend)
config['KerasJson'] = 'qkeras_mnist_cnn.json'
config['KerasH5']
config['Backend'] = backend
config['OutputDir'] = str(test_root_path / f'hls4mlprj_fetch_example_{backend}')

hls_model = hls4ml.converters.keras_to_hls(config)
hls_model.compile() # For now, it is enough if it compiles, we're only testing downloading works as expected

0 comments on commit 6a92562

Please sign in to comment.