From b29587aadd070011f1de259acc10db7d90513851 Mon Sep 17 00:00:00 2001
From: Pooja Babu
Date: Wed, 24 Jan 2024 12:09:58 +0100
Subject: [PATCH] Template changes
---
.../printers/nest_gpu_function_call_printer.py | 1 -
.../resources_nest_gpu/directives | 1 -
.../point_neuron/@NEURON_NAME@.cu.jinja2 | 8 +++++---
.../directives/SetScalParamAndVar.jinja2 | 13 +++++++++++++
pynestml/utils/ast_utils.py | 18 ++++++++++++++++++
5 files changed, 36 insertions(+), 5 deletions(-)
delete mode 120000 pynestml/codegeneration/resources_nest_gpu/directives
create mode 100644 pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2
diff --git a/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py b/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py
index 355571ec7..79ac2646e 100644
--- a/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py
+++ b/pynestml/codegeneration/printers/nest_gpu_function_call_printer.py
@@ -48,7 +48,6 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->
if function_name == PredefinedFunctions.TIME_RESOLUTION:
# context dependent; we assume the template contains the necessary definitions
return 'h'
- #return 'NESTGPUTimeResolution'
if function_name == PredefinedFunctions.TIME_STEPS:
return '(int)round({!s}/NESTGPUTimeResolution)'
diff --git a/pynestml/codegeneration/resources_nest_gpu/directives b/pynestml/codegeneration/resources_nest_gpu/directives
deleted file mode 120000
index 6b3b2eee8..000000000
--- a/pynestml/codegeneration/resources_nest_gpu/directives
+++ /dev/null
@@ -1 +0,0 @@
-/Users/pooja/nestml/master/pynestml/codegeneration/resources_nest/point_neuron/directives
\ No newline at end of file
diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
index 68ca962fa..8f2799cab 100644
--- a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
+++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
@@ -26,6 +26,8 @@
#include "{{ neuronName }}.h"
#include "spike_buffer.h"
+{%- import 'directives/SetScalParamAndVar.jinja2' as set_scal_param_var with context %}
+
{%- if uses_analytic_solver %}
using namespace {{ neuronName }}_ns;
@@ -176,7 +178,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
// Parameters
{%- for variable_symbol in neuron.get_parameter_symbols() %}
{%- set variable = utils.get_parameter_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
- SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}}
+ SetScalParam(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}}); // as {{variable_symbol.get_type_symbol().print_symbol()}}
{%- endfor %}
// Internal variables
@@ -194,7 +196,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
// State variables
{%- for variable_symbol in neuron.get_state_symbols() %}
{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %}
- SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{printer_no_origin.print(variable_symbol.get_declaring_expression())}});
+ SetScalVar(0, n_node, "{{ printer_no_origin.print(variable) }}", {{set_scal_param_var.SetScalParamAndVar(variable_symbol.get_declaring_expression())}});
{%- endfor %}
{%- endif %}
@@ -212,7 +214,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
port_input_port_step_ = 1;
{# TODO #}
-{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay");#}
+{# den_delay_arr_ = GetParamArr() + GetScalParamIdx("den_delay"); #}
return 0;
}
diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2
new file mode 100644
index 000000000..22cc846e3
--- /dev/null
+++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/directives/SetScalParamAndVar.jinja2
@@ -0,0 +1,13 @@
+{#
+ Initialization of param or var if they have a declaring expression
+ @param expr ASTExpression declaring expression of the variable
+#}
+{%- macro SetScalParamAndVar(expr) -%}
+{%- if utils.is_declaring_expression_parameter(expr) %}
+ *GetScalParam(0, n_node, "{{expr}}")
+{%- elif utils.is_declaring_expression_state_varible(expr) %}
+ *GetScalVar(0, n_node, "{{expr}}")
+{%- else %}
+ {{printer_no_origin.print(expr)}}
+{%- endif %}
+{%- endmacro -%}
\ No newline at end of file
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index 6a107ceb2..d2fc208c8 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -2234,4 +2234,22 @@ def adjusted_state_symbols(cls, neuron: ASTNeuron):
diff = list(set(neuron.get_state_symbols()) - set(extract_list))
return diff + extract_list
return neuron.get_state_symbols()
+
+
+ @classmethod
+ def is_declaring_expression_parameter(cls, expr: ASTExpression) -> bool:
+ if isinstance(expr, ASTSimpleExpression):
+ if expr.is_variable():
+ symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE)
+ if symbol and symbol.is_parameters():
+ return True
+ return False
+ @classmethod
+ def is_declaring_expression_state_varible(cls, expr: ASTExpression) -> bool:
+ if isinstance(expr, ASTSimpleExpression):
+ if expr.is_variable():
+ symbol = expr.get_scope().resolve_to_symbol(expr.get_variable().get_name(), SymbolKind.VARIABLE)
+ if symbol and symbol.is_state():
+ return True
+ return False
\ No newline at end of file