Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compiler Integration of Concat Operator #17

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fetch-repos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ FINN_EXP_COMMIT="0724be21111a21f0d81a072fccc1c446e053f851"
BREVITAS_COMMIT="d4834bd2a0fad3c1fbc0ff7e1346d5dcb3797ea4"
PYVERILATOR_COMMIT="ce0a08c20cb8c1d1e84181d6f392390f846adbd1"
CNPY_COMMIT="4e8810b1a8637695171ed346ce68f6984e585ef4"
HLSLIB_COMMIT="16e5847a5e3ef76cffe84c8fad2f010d593457d3"
HLSLIB_COMMIT="2c066e87f5b8d309693c5d46c206473ca20ac68c"
OMX_COMMIT="0b59762f9e4c4f7e5aa535ee9bc29f292434ca7a"
AVNET_BDF_COMMIT="2d49cfc25766f07792c0b314489f21fe916b639b"
XIL_BDF_COMMIT="8cf4bb674a919ac34e3d99d8d71a9e60af93d14e"
Expand Down
71 changes: 53 additions & 18 deletions src/finn/custom_op/fpgadataflow/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import math
import numpy as np
import warnings
from qonnx.core.datatype import DataType
from qonnx.util.basic import roundup_to_integer_multiple

Expand All @@ -36,17 +38,18 @@

class StreamingConcat(HWCustomOp):
"""Abstraction layer for HW implementation of Concat.
Only supports concatenating along the last axis."""
Only supports concatenating along the last (channel) axis."""

def __init__(self, onnx_node, **kwargs):
super().__init__(onnx_node, **kwargs)

def get_nodeattr_types(self):
my_attrs = {
"SIMD": ("i", True, 0),
# number of elements from each stream to concat
"ElemsPerStream": ("ints", True, []),
# FINN DataTypes for inputs; output datatype inferred from input
"inputDataType": ("s", True, ""),
"ChannelsPerStream": ("ints", True, []),
# FINN DataTypes for inputs; output datatype inferred from inputs
"inputDataTypes": ("strings", True, [""]),
# number of input vectors for non-concat axes, examples:
# [1] is a single vector (like a FC layer with batch=1)
# [4] is four vectors (like a FC layer with batch=4)
Expand All @@ -57,29 +60,36 @@ def get_nodeattr_types(self):
return my_attrs

def get_n_inputs(self):
return len(self.get_nodeattr("ElemsPerStream"))
return len(self.get_nodeattr("ChannelsPerStream"))

def get_total_elems(self):
elems_per_stream = self.get_nodeattr("ElemsPerStream")
elems_per_stream = self.get_nodeattr("ChannelsPerStream")
return int(np.sum(elems_per_stream))

def get_normal_input_shape(self, ind=0):
elems_per_stream = self.get_nodeattr("ElemsPerStream")
elems_per_stream = self.get_nodeattr("ChannelsPerStream")
elems = elems_per_stream[ind]
vecs = list(self.get_nodeattr("numInputVectors"))
ishape = tuple(vecs + [elems])
return ishape

def get_folded_input_shape(self, ind=0):
return self.get_normal_input_shape(ind)
simd = self.get_nodeattr("SIMD")
folds = self.get_nodeattr("ChannelsPerStream")[ind] // simd
vecs = list(self.get_nodeattr("numInputVectors"))
return tuple(vecs + [folds, simd])

def get_normal_output_shape(self, ind=0):
total_elems = self.get_total_elems()
vecs = list(self.get_nodeattr("numInputVectors"))
return tuple(vecs + [total_elems])

def get_folded_output_shape(self, ind=0):
return self.get_normal_output_shape()
total_elems = self.get_total_elems()
simd = self.get_nodeattr("SIMD")
folds = total_elems // simd
vecs = list(self.get_nodeattr("numInputVectors"))
return tuple(vecs + [folds, simd])

def make_shape_compatible_op(self, model):
# check all input shapes
Expand All @@ -94,7 +104,16 @@ def infer_node_datatype(self, model):
# check all input datatypes
for i, inp in enumerate(self.onnx_node.input):
idt = model.get_tensor_datatype(inp)
assert idt == self.get_input_datatype()
if idt != self.get_input_datatype(i):
warn_str = "inputDataType changing for %s: %s -> %s " % (
self.onnx_node.name,
str(self.get_input_datatype(i)),
str(idt),
)
warnings.warn(warn_str)
old_datatypes_attr = self.get_nodeattr("inputDataTypes")
old_datatypes_attr[i] = idt.name
self.set_nodeattr("inputDataTypes", old_datatypes_attr)
odt = self.get_output_datatype()
model.set_tensor_datatype(self.onnx_node.output[0], odt)

Expand All @@ -103,21 +122,37 @@ def verify_node(self):

def get_input_datatype(self, ind=0):
# input dt identical for all inputs
return DataType[self.get_nodeattr("inputDataType")]
return DataType[self.get_nodeattr("inputDataTypes")[ind]]

def get_output_datatype(self, ind=0):
return self.get_input_datatype()
# infer output datatype from declared inputDataTypes
min_input = 0
max_input = 0
for i in range(len(self.get_nodeattr("inputDataTypes"))):
idt = self.get_input_datatype(i)
if idt.min() < min_input:
min_input = idt.min()
if idt.max() > max_input:
max_input = idt.max()
# if the input range is always greater than 0, then acc_max <= 2^P - 1
if min_input >= 0:
out_bit_width = math.ceil(np.log2(max_input + 1))
odt = DataType[f"UINT{out_bit_width}"]
# if the input range is signed, then acc_min >= -2^{P-1} and acc_max <=
# 2^{P - 1} - 1, which means 2^{P - 1} >= max(-acc_min, 1 + acc_max)
else:
max_abs_input = max(-min_input, 1 + max_input)
out_bit_width = math.ceil(np.log2(max_abs_input) + 1)
odt = DataType[f"INT{out_bit_width}"]
return odt

def get_instream_width(self, ind=0):
elems_per_stream = self.get_nodeattr("ElemsPerStream")
elems = elems_per_stream[ind]
ibits = self.get_input_datatype().bitwidth()
return elems * ibits
ibits = self.get_input_datatype(ind).bitwidth()
return ibits * self.get_nodeattr("SIMD")

def get_outstream_width(self, ind=0):
obits = self.get_output_datatype().bitwidth()
total_elems = self.get_total_elems()
out_width = total_elems * obits
out_width = obits * self.get_nodeattr("SIMD")
return out_width

def get_number_output_values(self):
Expand Down
Loading