FIFO depth optimizer for Vitis backend #1037

Open
wants to merge 47 commits into
base: main
5ed7dd2
Init depthwise resource implementation for streaming interface
Jul 11, 2024
441c1b8
Init fifo optimization file for vitis backend
Jul 15, 2024
8ba0211
Register fifo opt flow in vitis backend
Jul 15, 2024
bd59846
Init changes in build_prj.tcl and modification files in vitis writer
Jul 15, 2024
ab4c232
Fix vitis writer by adding project.tcl modifer
Jul 15, 2024
1e583c5
Fix build_prj.tcl to synthesize with the large FIFOs
Jul 15, 2024
127da7c
Fix if statement in cosim tcl script
Jul 16, 2024
a88fb0d
Clean the optimizer file
Jul 16, 2024
47281c4
Implement the optmized depths parsing
Jul 16, 2024
57e8ffe
Implement setter for new depths
Jul 16, 2024
26f54c8
Fix csv file name parsing
Jul 16, 2024
78a8933
Fix name parsing, deeply hardcoded for now
Jul 16, 2024
0b7d4b3
Clean documentation and files
Jul 17, 2024
6bbd2a2
Remove unused function
Jul 17, 2024
67a00bf
Add documentation and runtime checks
Jul 17, 2024
b492308
Add documentation
Jul 17, 2024
d0918b5
Include extracting optimized depths
Jul 17, 2024
baa81f5
Fix documentation
Jul 17, 2024
62a933f
Add function to override Vivado test bench
Jul 18, 2024
eb85e41
Fix hls4ml docs
Jul 18, 2024
ba47496
Undo changes in sepconv stream
steltze Jul 18, 2024
3ab3b61
Format code
steltze Jul 18, 2024
7ca3438
Run pre-commit
steltze Jul 18, 2024
d2a5aa6
Remove unused imports
steltze Jul 18, 2024
e76fcde
Run pre-commit
steltze Jul 18, 2024
3328cf9
Remove comment
steltze Jul 18, 2024
01d5851
Fix typo and documentation
steltze Jul 19, 2024
a133606
Remove commented out code
steltze Jul 29, 2024
a077c33
Init unit test
steltze Jul 29, 2024
8daca5a
Use proper model for unit test to profile fifos
steltze Jul 30, 2024
7c7e4d3
Fix json generator to include before and after depths
steltze Jul 30, 2024
5482221
Set up full test
steltze Jul 31, 2024
3335ffa
Set up exception tests
steltze Jul 31, 2024
ede391e
Clean test
steltze Jul 31, 2024
25ca08a
Fix full test
steltze Jul 31, 2024
34949bb
Clean test
steltze Jul 31, 2024
1a1f347
Run precommit
steltze Jul 31, 2024
18f9385
Force the cosimulation to execute twice
steltze Jul 31, 2024
e1d80a5
Skip tests
steltze Jul 31, 2024
0c4f958
Update documentation
steltze Aug 1, 2024
92dc849
Fix conflict, use built-in os function
steltze Aug 2, 2024
e7b4caa
Setup onnx pytest
steltze Nov 27, 2024
a2557fd
Rebase and fix optimizer after main branch changes
steltze Dec 2, 2024
496a7e3
Update documentation
steltze Dec 2, 2024
81b3acd
Run precommit
steltze Dec 2, 2024
e86e11d
Fix qonnx test by optimizing away the input quantization
steltze Dec 11, 2024
63aa3df
Run precommit
steltze Dec 11, 2024
18 changes: 16 additions & 2 deletions docs/advanced/fifo_depth.rst
@@ -21,10 +21,10 @@ First, we can define a simple neural network in Keras
     from tensorflow.keras.models import Sequential

     model = Sequential()
-    model.add(Dense(64, input_shape=(16,), name='fc1', activation='relu')
+    model.add(Dense(64, input_shape=(16,), name='fc1', activation='relu'))
     model.add(Dense(32, name='fc2', activation='relu'))
     model.add(Dense(32, name='fc3', activation='relu'))
-    model.add(Dense(5, name='fc3', activation='softmax'))
+    model.add(Dense(5, name='fc4', activation='softmax'))

Then, we can convert the model, including the flow

@@ -47,3 +47,17 @@ Then, we can convert the model, including the flow
     hls_model.build(reset=False, csim=True, synth=True, cosim=True)

 For more details and results, see `H. Borras et al., "Open-source FPGA-ML codesign for the MLPerf Tiny Benchmark" (2022) <https://arxiv.org/abs/2206.11791>`_.
+
+Similarly, the FIFO buffers can be optimized while using the `Vitis` backend with the following changes
+
+.. code-block:: Python
+
+    config['Flows'] = ['vitis:fifo_depth_optimization']
+    hls4ml.model.optimizer.get_optimizer('vitis:fifo_depth_optimization').configure(profiling_fifo_depth=100_000)
+
+    hls_model = hls4ml.converters.convert_from_keras_model(model,
+                                                           io_type='io_stream',
+                                                           hls_config=config,
+                                                           output_dir='hls4mlprj_fifo_depth_opt',
+                                                           part='xc7z020clg400-1',
+                                                           backend='Vitis')
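
A note for readers of the new docs section: per `generate_depths_file` in this PR, the flow also writes a `fifo_depths.json` report into the output directory, mapping each FIFO name to its `initial` and `optimized` depth. The following is a minimal sketch for inspecting that report. The helper names and the example values are hypothetical and not part of hls4ml.

```python
import json
from pathlib import Path


def load_fifo_depths(output_dir):
    """Load the report written by the optimizer pass into a plain dict."""
    with open(Path(output_dir) / "fifo_depths.json") as f:
        return json.load(f)


def summarize_fifo_depths(depths):
    """Return (name, initial, optimized, saved) rows, largest saving first."""
    rows = [
        (name, entry["initial"], entry["optimized"], entry["initial"] - entry["optimized"])
        for name, entry in depths.items()
    ]
    return sorted(rows, key=lambda row: row[3], reverse=True)


if __name__ == "__main__":
    # Hypothetical values, only to illustrate the structure of the file.
    example = {
        "layer2_out": {"initial": 64, "optimized": 1},
        "layer4_out": {"initial": 32, "optimized": 40},
    }
    for name, initial, optimized, saved in summarize_fifo_depths(example):
        print(f"{name}: {initial} -> {optimized} (saved {saved})")
```

A negative saving is possible, since the optimizer docstring states that a FIFO may end up deeper than the initial hls4ml value in order to prevent deadlocks.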
247 changes: 247 additions & 0 deletions hls4ml/backends/vitis/passes/fifo_depth_optimization.py
@@ -0,0 +1,247 @@
import json
import os

from hls4ml.model.optimizer.optimizer import ConfigurableOptimizerPass, ModelOptimizerPass


def initialize_large_fifos(model, profiling_fifo_depth):
    """Set all FIFO depths equal to a large value so that they can be profiled.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        profiling_fifo_depth (int): A large non-negative integer, must be larger than the max expected depth of the FIFOs.

    Returns:
        Dict[str, int]: A dictionary containing FIFO names as keys and their initial depths as values is returned for
        comparison with the optimized depths.
    """

    # filter all the output variables and keep only the internal FIFOs, excluding output objects that are not FIFOs and the
    # input and output FIFOs as they can't be profiled and are implementation dependent, i.e. AXI Stream, AXI Master or
    # connected to another IP
    vars_to_profile = {
        output_variable_name: output_variable
        for output_variable_name, output_variable in model.output_vars.items()
        if ("VivadoStreamVariable" in str(type(output_variable)))
        and output_variable != model.get_output_variables()[0]
        and output_variable != model.get_input_variables()[0]
    }

    # initialize all the fifos to `profiling_fifo_depth` so that they will be automatically implemented in BRAMs and so
    # they will be profiled. Alternatively, "config_dataflow -override_user_fifo_depth profiling_fifo_depth" can be
    # used inside build_prj.tcl to override all FIFO depths with the specified value
    initial_fifo_depths = {}
    for output_variable in vars_to_profile.values():
        if output_variable.pragma:
            initial_fifo_depths[output_variable.name] = int(output_variable.pragma[1])
            output_variable.pragma = (output_variable.pragma[0], profiling_fifo_depth)
    return initial_fifo_depths


def override_test_bench(model):
    """In order for the FIFO depth profiling to produce correct results, it is necessary for the cosimulation to
    call the top function - Vitis IP at **least twice**. The test bench produced by the Vivado Writer is
    overwritten by adding a for-loop over the top function.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
    """
    indent = "    "
    path_to_old_test_bench = f"{model.config.get_output_dir()}/{model.config.get_project_name()}_test.cpp"
    path_to_new_test_bench = f"{model.config.get_output_dir()}/{model.config.get_project_name()}_new_test.cpp"

    newline = ""
    second_part_of_testbench = False
    with open(path_to_old_test_bench) as old_test_bench:
        file_iterator = iter(old_test_bench)
        for line in file_iterator:

            if "// hls-fpga-machine-learning insert zero" in line:
                newline += indent + indent + "const unsigned BATCH_SIZE = 2;\n"
                newline += (
                    indent
                    + indent
                    + "for(unsigned batch_iteration = 0; batch_iteration < BATCH_SIZE; ++batch_iteration) {\n"
                )
                newline += line
                second_part_of_testbench = True
            elif ("// hls-fpga-machine-learning insert tb-output" in line) and second_part_of_testbench:
                newline += line
                newline += next(file_iterator)
                newline += indent + "}\n"
            else:
                newline += line

    with open(path_to_new_test_bench, "w+") as new_test_bench:
        new_test_bench.write(newline)

    # replace the old test bench with the new test bench that includes a for-loop
    os.replace(path_to_new_test_bench, path_to_old_test_bench)
    return
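
Review note: the rewriting logic in `override_test_bench` can be exercised without touching the file system. The sketch below mirrors it on a list of lines; `wrap_top_call_in_loop` and its parameters are hypothetical names, and it assumes (as the function above does) that the line following the `insert tb-output` marker is the last one that belongs inside the loop.

```python
ZERO_MARKER = "// hls-fpga-machine-learning insert zero"
OUTPUT_MARKER = "// hls-fpga-machine-learning insert tb-output"


def wrap_top_call_in_loop(lines, batch_size=2, indent="    "):
    """Return a copy of `lines` with the zero-input section wrapped in a for-loop."""
    out = []
    seen_zero_marker = False
    iterator = iter(lines)
    for line in iterator:
        if ZERO_MARKER in line:
            # Open the loop right before the zero-input section.
            out.append(f"{indent * 2}const unsigned BATCH_SIZE = {batch_size};\n")
            out.append(
                f"{indent * 2}for(unsigned batch_iteration = 0; "
                "batch_iteration < BATCH_SIZE; ++batch_iteration) {\n"
            )
            out.append(line)
            seen_zero_marker = True
        elif OUTPUT_MARKER in line and seen_zero_marker:
            out.append(line)
            # Default to "" so a marker on the very last line does not raise StopIteration.
            out.append(next(iterator, ""))
            out.append(f"{indent}}}\n")
        else:
            out.append(line)
    return out
```

Working on lines rather than on a growing string also makes it easy to unit-test the transformation separately from the `os.replace` step.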


def execute_cosim_to_profile_fifos(model):
    """Execute a cosimulation with a test bench that calls the top function - Vitis IP at **least twice**,
    to properly profile the max FIFO depths. The function will momentarily replace the initial test bench
    with a suitable one for the optimization, and after the optimizer pass, the original test bench is reinitialized.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
    """
    model.write()

    override_test_bench(model)

    model.build(
        reset=False,
        csim=False,
        synth=True,
        cosim=True,
        validation=False,
        export=False,
        vsynth=False,
        fifo_opt=True,
    )

    return


def get_vitis_optimized_fifo_depths(model):
    """Parse the files generated by the cosimulation to retrieve the optimized depths for the FIFOs.
    Attention, only the FIFOs between the layers are profiled!

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.

    Returns:
        Dict[str, int]: A dictionary that contains the FIFO names as keys and the optimized depths as values.
    """
    # channel.zip is generated after the cosimulation and contains the chan_status*.csv files
    # in the chan_status*.csv files the max depth achieved during cosimulation can be found at the last (4th) line
    path_to_zip_file = (
        model.config.get_output_dir()
        + "/"
        + model.config.get_project_name()
        + "_prj"
        + "/solution1/.autopilot/db/channel_depth_info/"
    )

    os.system(f"unzip -q -o {path_to_zip_file}channel.zip -d {path_to_zip_file}")

    # the channel_info.csv file contains the mapping of each fifo name (i.e. layer4_out_U) to the respective
    # chan_status*.csv file
    names_file_path = (
        model.config.get_output_dir()
        + "/"
        + model.config.get_project_name()
        + "_prj"
        + "/solution1/.autopilot/db/channel_info.csv"
    )

    csv_fifo_depth_files = {}
    with open(names_file_path) as names_file:
        for line in names_file:
            layer_name = line.split(",")[1]
            csv_file_name = line.split(",")[3][:-1]
            csv_fifo_depth_files[layer_name] = csv_file_name

    optimized_fifo_depths = {}
    for layer_name, file_name in csv_fifo_depth_files.items():
        with open(path_to_zip_file + file_name) as chan_status_file:
            lines = chan_status_file.readlines()
            # remove "_U" from the layer name string and keep the last line of the file that contains the max depth
            optimized_fifo_depths[layer_name[:-2]] = int(lines[-1])

    return optimized_fifo_depths
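
Review note: the function above shells out to `unzip`, which has to be installed on the host. A possible alternative, sketched here with only the standard library, uses `zipfile` and `csv`. The function name `read_max_fifo_depths` is hypothetical, and the column positions (name in column 1, file name in column 3) and the "max depth on the last line" convention are taken from the code in this PR rather than from any Vitis documentation.

```python
import csv
import zipfile
from pathlib import Path


def read_max_fifo_depths(channel_info_csv, channel_zip):
    """Map each FIFO name to the max depth recorded in its chan_status*.csv file."""
    channel_zip = Path(channel_zip)
    extract_dir = channel_zip.parent

    # Extract the chan_status*.csv files next to the archive, as the PR does with `unzip`.
    with zipfile.ZipFile(channel_zip) as archive:
        archive.extractall(extract_dir)

    depths = {}
    with open(channel_info_csv, newline="") as names_file:
        for row in csv.reader(names_file):
            if len(row) < 4:
                continue  # skip malformed or empty rows
            fifo_name, status_file = row[1], row[3].strip()
            if fifo_name.endswith("_U"):
                fifo_name = fifo_name[:-2]
            lines = (extract_dir / status_file).read_text().splitlines()
            depths[fifo_name] = int(lines[-1])
    return depths
```

Stripping `_U` only when it is present avoids silently truncating names that do not carry the suffix.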


def generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths):
    """Generate a json file with the names of the FIFOs, the initial depths set by hls4ml and their optimized depths,
    for post-processing. The json file is not used by the rest of the pipeline, it is only produced for the user.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        initial_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the initial
            depths as values.
        optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
            depths as values.
    """
    depths = {}
    for fifo_name in initial_fifo_depths.keys():
        depths[fifo_name] = {}
        depths[fifo_name]['initial'] = initial_fifo_depths[fifo_name]
        depths[fifo_name]['optimized'] = optimized_fifo_depths[fifo_name]

    with open(model.config.get_output_dir() + "/fifo_depths.json", "w") as f:
        json.dump(depths, f, indent=4)


def set_optimized_fifo_depths(model, optimized_fifo_depths):
    """Set the new optimized FIFO depths.

    Args:
        model (ModelGraph): The model to which FIFO depth optimization is applied.
        optimized_fifo_depths (Dict[str, int]): A dictionary that contains the FIFO names as keys and the optimized
            depths as values.
    """

    # iterate through the layer output FIFOs
    for output_variable in model.output_vars.values():
        if "VivadoStreamVariable" in str(type(output_variable)):
            if output_variable.pragma:

                if output_variable.name not in optimized_fifo_depths.keys():
                    continue

                filtered_depth = optimized_fifo_depths[output_variable.name]
                output_variable.pragma = (output_variable.pragma[0], filtered_depth)
    return


class FifoDepthOptimization(ConfigurableOptimizerPass, ModelOptimizerPass):
    def __init__(self):
        pass

    def transform(self, model):
        """Perform FIFO depth optimization between the FIFOs of all layers to reduce resource utilization as the
        initial FIFOs set by hls4ml might be larger than required. At the end of the optimization the FIFOs will
        have the largest depths achieved during cosimulation without causing any deadlocks between the layers
        (producer-consumer), thus no additional delays between the layers. In some cases, this optimization
        might lead to bigger FIFOs than initially set by the hls4ml tool in order to prevent deadlocks.

        Args:
            model (ModelGraph): The model to which FIFO depth optimization is applied.

        Raises:
            ValueError: If the FIFO depth for profiling provided by the user is not a non-negative integer.
            RuntimeError: If the IO type is not set to "io_stream".

        Returns:
            bool: The execution state of the Optimizer Pass
        """

        # use `large_fifo_depth = 0` to keep the default fifo depth
        # consider replacing 100_000 either with a very large value, greater than the total BRAM storage space,
        # or via Vitis 2023.2 C simulation
        profiling_fifo_depth = getattr(self, "profiling_fifo_depth", 100_000)

        if not isinstance(profiling_fifo_depth, int) or profiling_fifo_depth < 0:
            raise ValueError("The FIFO depth for profiling (profiling_fifo_depth variable) must be a non-negative integer.")

        # check axi-stream or io-stream
        if not (model.config.get_config_value("IOType") == "io_stream"):
            raise RuntimeError("To use this optimization you have to set `IOType` field to `io_stream` in the HLS config.")

        initial_fifo_depths = initialize_large_fifos(model, profiling_fifo_depth)

        execute_cosim_to_profile_fifos(model)

        optimized_fifo_depths = get_vitis_optimized_fifo_depths(model)

        generate_depths_file(model, initial_fifo_depths, optimized_fifo_depths)

        set_optimized_fifo_depths(model, optimized_fifo_depths)

        print("[hls4ml] - FIFO optimization completed")
        return False
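
Review note on the `ValueError` check in `transform`: in Python, `bool` is a subclass of `int`, so `isinstance(True, int)` is true and `profiling_fifo_depth=True` would pass the check as depth 1. A stricter variant could look like the sketch below; `validate_profiling_fifo_depth` is a hypothetical helper, not part of this PR.

```python
def validate_profiling_fifo_depth(value):
    """Return `value` if it is a non-negative integer (and not a bool), else raise ValueError."""
    # bool is rejected explicitly because it is a subclass of int.
    if isinstance(value, bool) or not isinstance(value, int) or value < 0:
        raise ValueError(
            "The FIFO depth for profiling (profiling_fifo_depth variable) must be a non-negative integer."
        )
    return value
```

Whether rejecting booleans is worth the extra condition is a judgment call; the existing check already covers negative numbers, floats and strings.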
33 changes: 30 additions & 3 deletions hls4ml/backends/vitis/vitis_backend.py
@@ -34,6 +34,13 @@ def _register_flows(self):

         self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)

+        # Register the fifo depth optimization flow which is different from the one for vivado
+        fifo_depth_opt_passes = [
+            'vitis:fifo_depth_optimization'
+        ] + writer_passes  # After optimization, a new project will be written
+
+        register_flow('fifo_depth_optimization', fifo_depth_opt_passes, requires=['vitis:ip'], backend=self.name)
+
     def create_initial_config(
         self,
         part='xcvu13p-flga2577-2-e',
@@ -76,7 +83,18 @@ def create_initial_config(

         return config

-    def build(self, model, reset=False, csim=True, synth=True, cosim=False, validation=False, export=False, vsynth=False):
+    def build(
+        self,
+        model,
+        reset=False,
+        csim=True,
+        synth=True,
+        cosim=False,
+        validation=False,
+        export=False,
+        vsynth=False,
+        fifo_opt=False,
+    ):
         if 'linux' in sys.platform:
             found = os.system('command -v vitis_hls > /dev/null')
             if found != 0:
@@ -87,8 +105,17 @@ def build(self, model, reset=False, csim=True, synth=True, cosim=False, validati
         os.system(
             (
                 'vitis_hls -f build_prj.tcl "reset={reset} csim={csim} synth={synth} cosim={cosim} '
-                'validation={validation} export={export} vsynth={vsynth}"'
-            ).format(reset=reset, csim=csim, synth=synth, cosim=cosim, validation=validation, export=export, vsynth=vsynth)
+                'validation={validation} export={export} vsynth={vsynth} fifo_opt={fifo_opt}"'
+            ).format(
+                reset=reset,
+                csim=csim,
+                synth=synth,
+                cosim=cosim,
+                validation=validation,
+                export=export,
+                vsynth=vsynth,
+                fifo_opt=fifo_opt,
+            )
         )
         os.chdir(curr_dir)

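
Review note: the `build` call formats Python booleans directly into the argument string, so `build_prj.tcl` receives tokens such as `fifo_opt=True`. The sketch below shows the same string being assembled from keyword arguments, which avoids repeating every option name three times. `build_tcl_args` and `build_command` are hypothetical helpers and not part of hls4ml.

```python
def build_tcl_args(**options):
    """Render keyword options as the space-separated key=value string passed to build_prj.tcl."""
    return " ".join(f"{key}={value}" for key, value in options.items())


def build_command(**options):
    """Return the vitis_hls command line for the given build options."""
    return f'vitis_hls -f build_prj.tcl "{build_tcl_args(**options)}"'


if __name__ == "__main__":
    print(build_command(reset=False, csim=False, synth=True, cosim=True, fifo_opt=True))
```

Keyword arguments keep their insertion order, so the options appear in the command in the order they are passed.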
6 changes: 5 additions & 1 deletion hls4ml/templates/vivado/build_prj.tcl
@@ -179,6 +179,7 @@ if {$opt(csim)} {

 if {$opt(synth)} {
     puts "***** C/RTL SYNTHESIS *****"
+
     set time_start [clock clicks -milliseconds]
     csynth_design
     set time_end [clock clicks -milliseconds]
@@ -195,7 +196,10 @@ if {$opt(cosim)} {

     if {$opt(fifo_opt)} {
         puts "\[hls4ml\] - FIFO optimization started"
-        add_vcd_instructions_tcl
+
+        if {[string equal "$backend" "vivado"] || [string equal $backend "vivadoaccelerator"]} {
+            add_vcd_instructions_tcl
+        }
     }

     remove_recursive_log_wave
26 changes: 26 additions & 0 deletions hls4ml/writer/vitis_writer.py
@@ -24,10 +24,36 @@ def write_nnet_utils_overrides(self, model):
         for h in headers:
             copy(srcpath + h, dstpath + h)

+    def write_board_script(self, model):
+        '''
+        Write the tcl scripts and kernel sources to create a Vitis IPI
+        '''
+
+        ###################
+        # project.tcl
+        ###################
+
+        f = open(f'{model.config.get_output_dir()}/project.tcl', 'w')
+        f.write('variable project_name\n')
+        f.write(f'set project_name "{model.config.get_project_name()}"\n')
+        f.write('variable backend\n')
+        f.write('set backend "vitis"\n')
+        f.write('variable part\n')
+        f.write('set part "{}"\n'.format(model.config.get_config_value('Part')))
+        f.write('variable clock_period\n')
+        f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod')))
+        f.write('variable clock_uncertainty\n')
+        f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '12.5%')))
+        f.write('variable version\n')
+        f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0')))
+        f.close()
+        return
+
     def write_hls(self, model):
         """
         Write the HLS project. Calls the steps from VivadoWriter, adapted for Vitis
         """
         super().write_hls(model)
         self.write_nnet_utils_overrides(model)
         self.write_tar(model)
+        self.write_board_script(model)
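
Review note: `write_board_script` opens `project.tcl` without a context manager, so the handle stays open if one of the `get_config_value` calls raises. One possible restructuring, sketched below, renders the text first and writes it in a single `with` block. `render_project_tcl` and `write_project_tcl` are hypothetical helpers; the variable names and default values are copied from the PR.

```python
from pathlib import Path


def render_project_tcl(project_name, part, clock_period, clock_uncertainty="12.5%", version="1.0.0", backend="vitis"):
    """Return the contents of project.tcl as a string."""
    settings = [
        ("project_name", f'"{project_name}"'),
        ("backend", f'"{backend}"'),
        ("part", f'"{part}"'),
        ("clock_period", str(clock_period)),
        ("clock_uncertainty", str(clock_uncertainty)),
        ("version", f'"{version}"'),
    ]
    # Each setting is declared and then assigned, matching the order used in the PR.
    return "".join(f"variable {name}\nset {name} {value}\n" for name, value in settings)


def write_project_tcl(output_dir, **settings):
    """Write project.tcl into `output_dir`; the file is closed even if writing fails."""
    with open(Path(output_dir) / "project.tcl", "w") as f:
        f.write(render_project_tcl(**settings))
```

Separating rendering from writing also lets the generated Tcl be checked in a unit test without creating a project directory.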