From b29587aadd070011f1de259acc10db7d90513851 Mon Sep 17 00:00:00 2001 From: Pooja Babu Date: Wed, 24 Jan 2024 12:09:58 +0100 Subject: [PATCH] Template changes --- .../printers/nest_gpu_function_call_printer.py | 1 - .../resources_nest_gpu/directives | 1 - .../point_neuron/@NEURON_NAME@.cu.jinja2 | 8 +++++--- .../directives/SetScalParamAndVar.jinja2 | 13 +++++++++++++ pynestml/utils/ast_utils.py | 18 ++++++++++++++++++ 5 files changed, 36 insertions(+), 5 deletions(-) delete mode 120000 pynestml/codegeneration/resources_nest_gpu/directives create mode 100644 pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2 diff --git a/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py b/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py index 355571ec7..79ac2646e 100644 --- a/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py +++ b/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py @@ -48,7 +48,6 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> if function_name == PredefinedFunctions.TIME_RESOLUTION: # context dependent; we assume the template contains the necessary definitions return 'h' - #return 'NESTGPUTimeResolution' if function_name == PredefinedFunctions.TIME_STEPS: return '(int)round({!s}/NESTGPUTimeResolution)' diff --git a/pynestml/codegeneration/resources_nest_gpu/directives b/pynestml/codegeneration/resources_nest_gpu/directives deleted file mode 120000 index 6b3b2eee8..000000000 --- a/pynestml/codegeneration/resources_nest_gpu/directives +++ /dev/null @@ -1 +0,0 @@ -/Users/pooja/nestml/master/pynestml/codegeneration/resources_nest/point_neuron/directives \ No newline at end of file diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 index 68ca962fa..8f2799cab 100644 --- a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 +++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 @@ -26,6 +26,8 @@ #include "{{ neuronName }}.h" #include "spike_buffer.h" +{%- import 'directives/SetScalParamAndVar.jinja2' as set_scal_param_var with context %} + {%- if uses_analytic_solver %} using namespace {{ neuronName }}_ns; @@ -176,7 +178,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, // Parameters {%- for variable_symbol in neuron.get_parameter_symbols() %} {%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} - SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}} + SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}} {%- endfor %} // Internal variables @@ -194,7 +196,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, // State variables {%- for variable_symbol in neuron.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} - SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}}); + SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}}); {%- endfor %} {%- endif %} @@ -212,7 +214,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, port_input_port_step_ = 1; {# TODO #} -{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#} +{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay"); #} return 0; } diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2 new file mode 100644 index 000000000..22cc846e3 --- /dev/null +++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2 @@ -0,0 +1,13 @@ +{# + Initialization of param or var if they have a declaring expression + @param expr ASTExpression declaring expression of the variable +#} +{%- macro SetScalParamAndVar(expr) -%} +{%- if utils.is_declaring_expression_parameter(expr) %} + *GetScalParam(0, n_node, "{{expr}}") +{%- elif utils.is_declaring_expression_state_varible(expr) %} + *GetScalVar(0, n_node, "{{expr}}") +{%- else %} + {{printer_no_origin.print(expr)}} +{%- endif %} +{%- endmacro -%} \ No newline at end of file diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index 6a107ceb2..d2fc208c8 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -2234,4 +2234,22 @@ def adjusted_state_symbols(cls, neuron: ASTNeuron): diff = list(set(neuron.get_state_symbols()) - set(extract_list)) return diff + extract_list return neuron.get_state_symbols() + + + @classmethod + def is_declaring_expression_parameter(cls, expr: ASTExpression) -> bool: + if isinstance(expr, ASTSimpleExpression): + if expr.is_variable(): + symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE) + if symbol and symbol.is_parameters(): + return True + return False + @classmethod + def is_declaring_expression_state_varible(cls, expr: ASTExpression) -> bool: + if isinstance(expr, ASTSimpleExpression): + if expr.is_variable(): + symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE) + if symbol and symbol.is_state(): + return True + return False \ No newline at end of file