-
Notifications
You must be signed in to change notification settings - Fork 466
FIFO depth optimizer for Vitis backend #1037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vloncar
merged 54 commits into
fastmachinelearning:main
from
steltze:fifo_depth_opt_vitis
Mar 3, 2025
Merged
Changes from all commits
Commits
Show all changes
54 commits
Select commit
Hold shift + click to select a range
5ed7dd2
Init depthwise resource implementation for streaming interface
441c1b8
Init fifo optimization file for vitis backend
8ba0211
Register fifo opt flow in vitis backend
bd59846
Init changes in build_prj.tcl and modification files in vitis writer
ab4c232
Fix vitis writer by adding project.tcl modifer
1e583c5
Fix build_prj.tcl to synthesize with the large FIFOs
127da7c
Fix if statement in cosim tcl script
a88fb0d
Clean the optimizer file
47281c4
Implement the optmized depths parsing
57e8ffe
Implement setter for new depths
26f54c8
Fix csv file name parsing
78a8933
Fix name parsing, deeply hardcoded for now
0b7d4b3
Clean documentation and files
6bbd2a2
Remove unused function
67a00bf
Add documentation and runtime checks
b492308
Add documentation
d0918b5
Include extracting optimized depths
baa81f5
Fix documentation
62a933f
Add function to override Vivado test bench
eb85e41
Fix hls4ml docs
ba47496
Undo changes in sepconv stream
steltze 3ab3b61
Format code
steltze 7ca3438
Run pre-commit
steltze d2a5aa6
Remove unused imports
steltze e76fcde
Run pre-commit
steltze 3328cf9
Remove comment
steltze 01d5851
Fix typo and documentation
steltze a133606
Remove commented out code
steltze a077c33
Init unit test
steltze 8daca5a
Use proper model for unit test to profile fifos
steltze 7c7e4d3
Fix json generator to include before and after depths
steltze 5482221
Set up full test
steltze 3335ffa
Set up exception tests
steltze ede391e
Clean test
steltze 25ca08a
Fix full test
steltze 34949bb
Clean test
steltze 1a1f347
Run precommit
steltze 18f9385
Force the cosimulation to execute twice
steltze e1d80a5
Skip tests
steltze 0c4f958
Update documentation
steltze 92dc849
Fix conflict, use built-in os function
steltze e7b4caa
Setup onnx pytest
steltze a2557fd
Rebase and fix optimizer after main branch changes
steltze 496a7e3
Update documentation
steltze 81b3acd
Run precommit
steltze e86e11d
Fix qonnx test by optimizing away the input quantization
steltze 63aa3df
Run precommit
steltze 76cd4ee
Address review comments
steltze aee5921
Fix c-test for loop
steltze d8b363a
Correct comment
steltze 613099e
Merge branch 'main' into fifo_depth_opt_vitis
JanFSchulte 128f297
Merge remote-tracking branch 'upstream/main' into stelios_vitis_fifo_opt
vloncar 56b7808
Streamlining some changes to better fit the codebase (but mostly cosm…
vloncar 55a72b6
Merge branch 'main' into fifo_depth_opt_vitis
vloncar File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
195 changes: 195 additions & 0 deletions
195
hls4ml/backends/vitis/passes/fifo_depth_optimization.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
import json | ||
import zipfile | ||
|
||
from hls4ml.model.optimizer.optimizer import ConfigurableOptimizerPass, ModelOptimizerPass | ||
|
||
|
||
def initialize_large_fifos(model, profiling_fifo_depth): | ||
"""Set all FIFO depths equal to a large value so that they can be profiled. | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
profiling_fifo_depth (int): A large non-negative integer, must be larger than the max expected depth of the FIFOs. | ||
|
||
Returns: | ||
Dict[str, int]: A dictionary containing FIFO names as keys and their initial depths as values is returned for | ||
comparison with the optimized depths. | ||
""" | ||
|
||
# filter all the output variables and keep only the internal FIFOs, excluding output objects that are not FIFOs and the | ||
# input and output FIFOs as they can't be profiled and are implementation dependant i.e AXI Stream, AXI Master or | ||
# connected to another IP | ||
vars_to_profile = { | ||
output_variable_name: output_variable | ||
for output_variable_name, output_variable in model.output_vars.items() | ||
if ('StreamVariable' in str(type(output_variable))) | ||
and output_variable != model.get_output_variables()[0] | ||
and output_variable != model.get_input_variables()[0] | ||
} | ||
|
||
# initialize all the fifos to `profiling_fifo_depth` so that they will be automatically implemented in BRAMs and so | ||
# they will be profiled. Alternatively, "config_dataflow -override_user_fifo_depth profiling_fifo_depth" can be | ||
# used inside build_prj.tcl to override all FIFO depths with the specified value | ||
initial_fifo_depths = {} | ||
for output_variable in vars_to_profile.values(): | ||
if output_variable.pragma: | ||
initial_fifo_depths[output_variable.name] = int(output_variable.pragma[1]) | ||
output_variable.pragma = (output_variable.pragma[0], profiling_fifo_depth) | ||
return initial_fifo_depths | ||
|
||
|
||
def execute_cosim_to_profile_fifos(model): | ||
"""Execute a co-simulation with a test-bench that calls the top function to properly profile the max FIFO depths. | ||
Note that the top function needs to execute **least twice**, so user-provided input must have at least two samples. | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
""" | ||
model.write() | ||
|
||
model.build( | ||
reset=False, | ||
csim=False, | ||
synth=True, | ||
cosim=True, | ||
validation=False, | ||
export=False, | ||
vsynth=False, | ||
fifo_opt=True, | ||
) | ||
|
||
|
||
def get_vitis_optimized_fifo_depths(model): | ||
"""Parse the files generated by the co-simulation to retrieve the optimized depths for the FIFOs. | ||
Attention, only the FIFOs between the layers are profiled! | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
|
||
Returns: | ||
Dict[str, int]: A dictionary that contains the FIFO names as keys and the optimized depths as values. | ||
""" | ||
# channel.zip is generated after the co-simulation and contains the chan_status*.csv files | ||
# in the chan_status*.csv files the max depth achieved during co-simulation can be found at the last (4th) line | ||
path_to_zip_file = ( | ||
model.config.get_output_dir() | ||
+ '/' | ||
+ model.config.get_project_name() | ||
+ '_prj' | ||
+ '/solution1/.autopilot/db/channel_depth_info/' | ||
) | ||
|
||
with zipfile.ZipFile(f'{path_to_zip_file}channel.zip', 'r') as zip_ref: | ||
zip_ref.extractall(path_to_zip_file) | ||
|
||
# the channel_info.csv file contains the mapping of each fifo name (i.e layer4_out_U) to the respective | ||
# chan_status*.csv file | ||
names_file_path = ( | ||
model.config.get_output_dir() | ||
+ '/' | ||
+ model.config.get_project_name() | ||
+ '_prj' | ||
+ '/solution1/.autopilot/db/channel_info.csv' | ||
) | ||
|
||
csv_fifo_depth_files = {} | ||
with open(names_file_path) as names_file: | ||
for line in names_file: | ||
layer_name = line.split(',')[1] | ||
csv_file_name = line.split(',')[3][:-1] | ||
csv_fifo_depth_files[layer_name] = csv_file_name | ||
|
||
optmized_fifo_depths = {} | ||
for layer_name, file_name in csv_fifo_depth_files.items(): | ||
with open(path_to_zip_file + file_name) as chan_status_file: | ||
lines = chan_status_file.readlines() | ||
optmized_fifo_depths[layer_name[:-2]] = int( | ||
lines[-1] | ||
) # remove "_U" from the layer name string and keep the last line of the file that contains the max depth | ||
|
||
return optmized_fifo_depths | ||
|
||
|
||
def generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths): | ||
"""Generate a json file with the names of the FIFOs, the initial depths set by hls4ml and their optimized depths, | ||
for post-processing. The json file is not used by the rest of the pipeline, it is only produced for the user. | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
initial_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the initial | ||
depths as values. | ||
optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized | ||
depths as values. | ||
""" | ||
depths = {} | ||
for fifo_name in initial_fifo_depths.keys(): | ||
depths[fifo_name] = {} | ||
depths[fifo_name]['initial'] = initial_fifo_depths[fifo_name] | ||
depths[fifo_name]['optimized'] = optimized_fifo_depths[fifo_name] | ||
|
||
with open(model.config.get_output_dir() + '/fifo_depths.json', 'w') as f: | ||
json.dump(depths, f, indent=4) | ||
|
||
|
||
def set_optimized_fifo_depths(model, optimized_fifo_depths): | ||
"""Set the new optimized FIFO depths. | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized | ||
depths as values. | ||
""" | ||
|
||
# iterate through the layer output FIFOs | ||
for output_variable in model.output_vars.values(): | ||
if 'StreamVariable' in str(type(output_variable)): | ||
if output_variable.pragma: | ||
|
||
if output_variable.name not in optimized_fifo_depths.keys(): | ||
continue | ||
|
||
filtered_depth = optimized_fifo_depths[output_variable.name] | ||
output_variable.pragma = (output_variable.pragma[0], filtered_depth) | ||
|
||
|
||
class FifoDepthOptimization(ConfigurableOptimizerPass, ModelOptimizerPass): | ||
def __init__(self): | ||
# use `profiling_fifo_depth = 0` to keep the default fifo depth | ||
# consider changing 100_000 either with a very very large value > of any total bram storage space | ||
# or via vitis 2023.2 c-simulation | ||
self.profiling_fifo_depth = 100_000 | ||
|
||
def transform(self, model): | ||
"""Perform FIFO depth optimization between the FIFOs of all layers to reduce resource utilization as the | ||
initial FIFOs set by hls4ml might be larger than required. At the end of the optimization the FIFOs will | ||
have the largest depths achieved during co-simulation without causing any deadlocks between the layers | ||
(producer-consumer), thus no additional delays between the layers. In some cases, this optimization | ||
might lead to bigger FIFOs than initially set by the hls4ml tool in order to prevent deadlocks. | ||
|
||
Args: | ||
model (ModelGraph): The model to which FIFO depth optimization is applied. | ||
|
||
Raises: | ||
ValueError: If the FIFO depth for profiling provided by the user is not a non-negative integer. | ||
RuntimeError: If the IO type is not set to "io_stream". | ||
|
||
Returns: | ||
bool: The execution state of the Optimizer Pass | ||
""" | ||
|
||
if not isinstance(self.profiling_fifo_depth, int) or self.profiling_fifo_depth <= 0: | ||
raise ValueError('The FIFO depth for profiling (profiling_fifo_depth variable) must be a non-negative integer.') | ||
|
||
# check axi-stream or io-stream | ||
if not (model.config.get_config_value('IOType') == 'io_stream'): | ||
raise RuntimeError('To use this optimization you have to set `IOType` field to `io_stream` in the HLS config.') | ||
|
||
initial_fifo_depths = initialize_large_fifos(model, self.profiling_fifo_depth) | ||
execute_cosim_to_profile_fifos(model) | ||
optimized_fifo_depths = get_vitis_optimized_fifo_depths(model) | ||
generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths) | ||
set_optimized_fifo_depths(model, optimized_fifo_depths) | ||
|
||
print('FIFO optimization completed') | ||
|
||
return False |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.