From 91ee88e1f489d6eeb5f38466ddaad13945574693 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 6 Dec 2024 12:02:50 -0500 Subject: [PATCH 1/3] fixes to parsing of pytorch models when using torch functionals --- hls4ml/converters/pytorch/pooling.py | 20 ++++++++++++-------- hls4ml/converters/pytorch/reshape.py | 22 ++++++++++++++++------ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index 8256a9ff87..3757b2c82e 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -90,15 +90,19 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1] else: - layer['stride_height'] = node.kwargs['stride'] - layer['stride_width'] = node.kwargs['stride'] - if type(node.kwargs['kernel_size']) is tuple: - layer['pool_height'] = node.kwargs['kernel_size'][0] - layer['pool_width'] = node.kwargs['kernel_size'][1] + if node.kwargs['stride'] is None: + # if stride is not set it is supposed to default to the kernel size + layer['stride_height'] = node.args[1] + layer['stride_width'] = node.args[1] + else: + layer['stride_height'] = node.kwargs['stride'] + layer['stride_width'] = node.kwargs['stride'] + if type(node.args[1]) is tuple: + layer['pool_height'] = node.args[1][0] + layer['pool_width'] = node.args[1][1] else: - layer['pool_height'] = node.kwargs['kernel_size'] - layer['pool_width'] = node.kwargs['kernel_size'] - + layer['pool_height'] = node.args[1] + layer['pool_width'] = node.args[1] if type(node.kwargs['padding']) is tuple: padding = node.kwargs['padding'] else: diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index 37191135a1..7e43c5b328 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -93,13 +93,23 @@ def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, layer['class_name'] = 'Reshape' layer['name'] = layer_name layer['inputs'] = input_names - - start_dim = class_object.start_dim - end_dim = class_object.end_dim - if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]): - end_dim = len(input_shapes[0]) + if node.op == "call_module": + start_dim = class_object.start_dim + end_dim = class_object.end_dim + if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]): + end_dim = len(input_shapes[0]) + else: + end_dim = end_dim + 1 else: - end_dim = end_dim + 1 + start_dim = node.args[1] + if len(node.args) == 3: + end_dim = node.args[2] + else: + end_dim = -1 + if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]): + end_dim = len(input_shapes[0]) + else: + end_dim = end_dim + 1 layer['target_shape'] = ( input_shapes[0][0:start_dim] + [np.prod(input_shapes[0][start_dim:end_dim])] + input_shapes[0][end_dim:] From 2afae66d3680db3cea957927eeea12c1fbd84693 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 6 Dec 2024 13:45:05 -0500 Subject: [PATCH 2/3] fix quotation marks --- hls4ml/converters/pytorch/pooling.py | 1 + hls4ml/converters/pytorch/reshape.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index 3757b2c82e..f6b39c9010 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -86,6 +86,7 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, padding = [class_object.padding, class_object.padding] else: + print(node.args) if type(node.kwargs['stride']) is tuple: layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1] diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py index 7e43c5b328..3d415e7832 100644 --- a/hls4ml/converters/pytorch/reshape.py +++ b/hls4ml/converters/pytorch/reshape.py @@ -93,7 +93,7 @@ def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, layer['class_name'] = 'Reshape' layer['name'] = layer_name layer['inputs'] = input_names - if node.op == "call_module": + if node.op == 'call_module': start_dim = class_object.start_dim end_dim = class_object.end_dim if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]): From a0a573ef5d4c06d0e61ea2b32853b572c8f8c4c1 Mon Sep 17 00:00:00 2001 From: Jan-Frederik Schulte Date: Fri, 6 Dec 2024 13:45:23 -0500 Subject: [PATCH 3/3] fix quotation marks --- hls4ml/converters/pytorch/pooling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py index f6b39c9010..3757b2c82e 100644 --- a/hls4ml/converters/pytorch/pooling.py +++ b/hls4ml/converters/pytorch/pooling.py @@ -86,7 +86,6 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node, padding = [class_object.padding, class_object.padding] else: - print(node.args) if type(node.kwargs['stride']) is tuple: layer['stride_height'] = node.kwargs['stride'][0] layer['stride_width'] = node.kwargs['stride'][1]