Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial values for the hidden/cell state for LSTM and GRU models in Pytorch #1120

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions hls4ml/backends/quartus/passes/recurrent_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@
}};\n'''

gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
gru_function_initial_state_template = (
'nnet::gru<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});'
)


class GRUConfigTemplate(LayerConfigTemplate):
Expand Down Expand Up @@ -137,15 +140,23 @@ def format(self, node):
class GRUFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(GRU, include_header=recurrent_include_list)
self.template = gru_function_template

def format(self, node):
params = self._default_function_params(node)
if params['pass_initial_states'] == 'true':
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
params['input2'] = node.get_input_variable(node.inputs[1]).name
params['w'] = node.get_weights('weight').name
params['b'] = node.get_weights('bias').name
params['wr'] = node.get_weights('recurrent_weight').name
params['br'] = node.get_weights('recurrent_bias').name
return self.template.format(**params)

if params['pass_initial_states'] == 'true':
template = gru_function_initial_state_template
else:
template = gru_function_template

return template.format(**params)


################################################
Expand Down Expand Up @@ -174,6 +185,9 @@ def format(self, node):
}};\n"""

lstm_function_template = 'nnet::lstm<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
lstm_function_initial_state_template = (
'nnet::lstm<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {weights});'
)


class LSTMConfigTemplate(LayerConfigTemplate):
Expand Down Expand Up @@ -214,11 +228,16 @@ def format(self, node):
class LSTMFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(LSTM, include_header=recurrent_include_list)
self.template = lstm_function_template

def format(self, node):
params = self._default_function_params(node)

if params['pass_initial_states'] == 'true':
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
params['input2'] = node.get_input_variable(node.inputs[1]).name
params['input3'] = node.get_input_variable(node.inputs[2]).name
params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name

types = ['i', 'f', 'c', 'o']
params['weights'] = ''
for t in types:
Expand All @@ -228,7 +247,12 @@ def format(self, node):
for t in types:
params['weights'] += 'bias_{}_{}{}'.format(t, str(node.index), ',' if t != 'o' else '')

return self.template.format(**params)
if params['pass_initial_states'] == 'true':
template = lstm_function_initial_state_template
else:
template = lstm_function_template

return template.format(**params)


################################################
Expand Down
20 changes: 18 additions & 2 deletions hls4ml/backends/vivado/passes/recurrent_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
}};\n"""

recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
recr_function_template_initial_states_lstm = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501
recr_function_template_initial_states_gru = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501

recr_include_list = ['nnet_utils/nnet_recurrent.h']

Expand Down Expand Up @@ -186,10 +188,16 @@ def format(self, node):
class RecurrentFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((LSTM, GRU), include_header=recr_include_list)
self.template = recr_function_template

def format(self, node):
params = self._default_function_params(node)
if params['pass_initial_states'] == 'true':
params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
params['input2'] = node.get_input_variable(node.inputs[1]).name
if node.class_name == 'LSTM':
params['input3'] = node.get_input_variable(node.inputs[2]).name
params['input3_t'] = node.get_input_variable(node.inputs[2]).type.name

params['w'] = node.get_weights('weight').name
params['b'] = node.get_weights('bias').name
params['wr'] = node.get_weights('recurrent_weight').name
Expand All @@ -198,4 +206,12 @@ def format(self, node):
params['recurrent_activation'] = node.get_attr('recurrent_activation')
params['recr_type'] = node.class_name.lower()

return self.template.format(**params)
if params['pass_initial_states'] == 'true':
if node.class_name == 'LSTM':
template = recr_function_template_initial_states_lstm
else:
template = recr_function_template_initial_states_gru
else:
template = recr_function_template

return template.format(**params)
2 changes: 2 additions & 0 deletions hls4ml/converters/keras/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,6 @@ def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader):
if layer['return_state']:
raise Exception('"return_state" of {} layer is not yet supported.')

layer['pass_initial_states'] = False

return layer, output_shape
22 changes: 11 additions & 11 deletions hls4ml/converters/pytorch/recurrent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler
Expand All @@ -15,14 +13,13 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas

layer["name"] = layer_name

layer['inputs'] = [input_names[0]]
if len(input_names) > 1:
warnings.warn(
'hls4ml disregards the initial value of the hidden state passed to the model, assuming that it is all zeros',
stacklevel=2,
)
layer['inputs'] = input_names
if 'IOType' in config.keys():
if len(input_names) > 1 and config['IOType'] == 'io_stream':
raise Exception('Passing initial values for the hidden state is not support for io_stream input type.')

layer['class_name'] = operation
if operation == "RNN":
if operation == 'RNN':
layer['class_name'] = 'SimpleRNN'

layer['return_sequences'] = False # parameter does not exist in pytorch
Expand All @@ -31,7 +28,7 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
if layer['class_name'] == 'SimpleRNN':
layer['activation'] = class_object.nonlinearity # Default is tanh, can also be ReLU in pytorch
else:
layer['activation'] = "tanh" # GRU and LSTM are hard-coded to use tanh in pytorch
layer['activation'] = 'tanh' # GRU and LSTM are hard-coded to use tanh in pytorch

if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
layer['recurrent_activation'] = 'sigmoid' # GRU and LSTM are hard-coded to use sigmoid in pytorch
Expand All @@ -51,7 +48,6 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas

if class_object.bidirectional:
raise Exception('hls4ml does not support birectional RNNs')

if class_object.dropout > 0:
raise Exception('hls4ml does not support RNNs with dropout')

Expand All @@ -70,5 +66,9 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
output_shape = [input_shapes[0][0], layer['n_out']]

layer['pytorch'] = True # need to switch some behaviors to match pytorch implementations
if len(input_names) == 1:
layer['pass_initial_states'] = False
else:
layer['pass_initial_states'] = True

return layer, output_shape
14 changes: 11 additions & 3 deletions hls4ml/converters/pytorch_to_hls.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,17 @@ def parse_pytorch_model(config, verbose=True):
# parse info from class object
input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
if pytorch_class in ["RNN", "GRU", "LSTM"]:
# we currently don't support the passing of the initial value of the hidden state to RNN models
input_names = [inputs_map.get(str(node.args[0]), str(node.args[0]))]
input_shapes = [output_shapes[str(node.args[0])]]
input_shapes = []
input_names = []
for i in node.args:
if isinstance(i, tuple):
for y in i:
input_shapes.append(output_shapes[str(y)])
input_names.append(inputs_map.get(str(y), str(y)))
else:
input_shapes.append(output_shapes[str(i)])
input_names.append(inputs_map.get(str(i), str(i)))

# if a 'getitem' is the input to a node, step back in the graph to find the real source of the input
elif "getitem" in node.args[0].name:
for tmp_node in traced_model.graph.nodes:
Expand Down
114 changes: 114 additions & 0 deletions hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,47 @@ void gru(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_outputs * CONFIG_T::
}
}

template <class data_T, class data2_T, class res_T, typename CONFIG_T>
void gru(data_T data[CONFIG_T::n_in], data2_T h[CONFIG_T::n_units], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_units],
const typename CONFIG_T::weight_t weights[3 * CONFIG_T::n_units * CONFIG_T::n_in],
const typename CONFIG_T::weight_t recurrent_weights[3 * CONFIG_T::n_units * CONFIG_T::n_units],
const typename CONFIG_T::bias_t bias[3 * CONFIG_T::n_units],
const typename CONFIG_T::bias_t recurrent_bias[3 * CONFIG_T::n_units]) {

hls_register data_T x[CONFIG_T::n_in];
// hls_register res_T h[CONFIG_T::n_units];

// #pragma unroll
// for (int i = 0; i < CONFIG_T::n_units; i++) {
// h[i] = 0;
// }

// Loop depedency - cannot pipeline
#pragma disable_loop_pipelining
for (int t = 0; t < CONFIG_T::n_timesteps; t++) {
// Get data at current time step
#pragma unroll
for (int j = 0; j < CONFIG_T::n_in; j++) {
x[j] = data[j + t * CONFIG_T::n_in];
}

nnet::gru_cell<data_T, res_T, CONFIG_T>(x, h, weights, recurrent_weights, bias, recurrent_bias);

if (CONFIG_T::return_sequences) {
#pragma unroll
for (int i = 0; i < CONFIG_T::n_units; i++) {
res[CONFIG_T::n_units * t + i] = h[i];
}
}
}

if (!CONFIG_T::return_sequences) {
#pragma unroll
for (int i = 0; i < (CONFIG_T::n_units); i++) {
res[i] = h[i];
}
}
}
//----------------------
// SimpleRNN
//----------------------
Expand Down Expand Up @@ -711,6 +752,79 @@ void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], res_T res[CONFIG_
}
}

template <class data_T, class data2_T, class data3_T, class res_T, class CONFIG_T>
void lstm(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], data2_T hidden_state_initial[CONFIG_T::n_out],
data3_T cell_state_initial[CONFIG_T::n_out], res_T res[CONFIG_T::n_outputs * CONFIG_T::n_out],
const typename CONFIG_T::weight_t WI[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t WF[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t WC[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t WO[CONFIG_T::n_in * CONFIG_T::n_out],
const typename CONFIG_T::weight_t RWI[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::weight_t RWF[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::weight_t RWC[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::weight_t RWO[CONFIG_T::n_out * CONFIG_T::n_out],
const typename CONFIG_T::bias_t BI[CONFIG_T::n_out], const typename CONFIG_T::bias_t BF[CONFIG_T::n_out],
const typename CONFIG_T::bias_t BC[CONFIG_T::n_out], const typename CONFIG_T::bias_t BO[CONFIG_T::n_out]) {
res_T hidden_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
res_T hidden_state_temp[CONFIG_T::n_out] hls_register;
res_T cell_state[CONFIG_T::n_out][CONFIG_T::n_timesteps + 1] hls_register;
res_T cell_state_temp[CONFIG_T::n_out] hls_register;
res_T h[CONFIG_T::n_out] hls_register;
res_T c[CONFIG_T::n_out] hls_register;
data_T in[CONFIG_T::n_in] hls_register;

// Set initially hidden state (output) to zero
INIT_LOOP:
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state[x][0] = hidden_state_initial[x];
cell_state[x][0] = cell_state_initial[x];
}

// Input dimension
#pragma disable_loop_pipelining
for (int i = 0; i < CONFIG_T::n_timesteps; i++) {
// Data at current time step
for (int x = 0; x < CONFIG_T::n_in; x++) {
in[x] = data[x + i * CONFIG_T::n_in];
}

// Hidden state at current time step
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state_temp[x] = hidden_state[x][i];
cell_state_temp[x] = cell_state[x][i];
}

// Do LSTM
lstm_cell<data_T, res_T, CONFIG_T>(in, hidden_state_temp, h, cell_state_temp, c, WI, WF, WC, WO, RWI, RWF, RWC, RWO,
BI, BF, BC, BO);

// Write result
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
hidden_state[x][i + 1] = h[x];
cell_state[x][i + 1] = c[x];
}
}

if (CONFIG_T::return_sequences == 0) {
// Output when return_sequences is false
#pragma unroll
for (int x = 0; x < CONFIG_T::n_out; x++) {
res[x] = hidden_state[x][CONFIG_T::n_timesteps];
}
} else {
// Output when return_sequences is true
#pragma unroll
for (int x = 0; x < CONFIG_T::n_timesteps; x++) {
for (int h = 0; h < CONFIG_T::n_out; h++) {
res[x * CONFIG_T::n_out + h] = hidden_state[h][x + 1];
}
}
}
}

} // namespace nnet

#endif
Loading
Loading