diff --git a/.github/workflows/continuous_benchmarking.yml b/.github/workflows/continuous_benchmarking.yml
new file mode 100644
index 000000000..980da3473
--- /dev/null
+++ b/.github/workflows/continuous_benchmarking.yml
@@ -0,0 +1,91 @@
+on:
+ pull_request_target:
+ types: [opened, reopened, edited, synchronize]
+
+jobs:
+ fork_pr_requires_review:
+ environment: ${{ (github.event.pull_request.head.repo.full_name == github.repository && 'internal') || 'external' }}
+ runs-on: ubuntu-latest
+ steps:
+ - run: true
+
+ benchmark_fork_pr_branch:
+ needs: fork_pr_requires_review
+ name: Continuous Benchmarking Fork PRs with Bencher
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ repository: ${{ github.event.pull_request.head.repo.full_name }}
+ ref: ${{ github.event.pull_request.head.sha }}
+ persist-credentials: false
+
+ - uses: bencherdev/bencher@main
+
+ # Setup Python version
+ - name: Setup Python 3.8
+ uses: actions/setup-python@v5
+ with:
+ python-version: 3.8
+
+ # Install dependencies
+ - name: Install apt dependencies
+ run: |
+ sudo apt-get update
+ sudo apt-get install libltdl7-dev libgsl0-dev libncurses5-dev libreadline6-dev pkg-config
+ sudo apt-get install python3-all-dev python3-matplotlib python3-numpy python3-scipy ipython3
+
+ # Install Python dependencies
+ - name: Python dependencies
+ run: |
+ python -m pip install --upgrade pip pytest jupyterlab matplotlib pycodestyle scipy pandas pytest-benchmark
+ python -m pip install -r requirements.txt
+
+ # Install NEST simulator
+ - name: NEST simulator
+ run: |
+ python -m pip install cython
+ echo "GITHUB_WORKSPACE = $GITHUB_WORKSPACE"
+ NEST_SIMULATOR=$(pwd)/nest-simulator
+ NEST_INSTALL=$(pwd)/nest_install
+ echo "NEST_SIMULATOR = $NEST_SIMULATOR"
+ echo "NEST_INSTALL = $NEST_INSTALL"
+
+ git clone --depth=1 https://github.com/nest/nest-simulator
+ mkdir nest_install
+ echo "NEST_INSTALL=$NEST_INSTALL" >> $GITHUB_ENV
+ cd nest_install
+ cmake -DCMAKE_INSTALL_PREFIX=$NEST_INSTALL $NEST_SIMULATOR
+ make && make install
+ cd ..
+
+ # Install NESTML (repeated)
+ - name: Install NESTML
+ run: |
+ export PYTHONPATH=${{ env.PYTHONPATH }}:${{ env.NEST_INSTALL }}/lib/python3.8/site-packages
+ #echo PYTHONPATH=`pwd` >> $GITHUB_ENV
+ echo "PYTHONPATH=$PYTHONPATH" >> $GITHUB_ENV
+ python setup.py install
+
+ - name: Track Fork PR Benchmarks with Bencher
+ env:
+ LD_LIBRARY_PATH: ${{ env.NEST_INSTALL }}/lib/nest
+ run: |
+ echo "NEST_INSTALL = $NEST_INSTALL"
+ LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${{ env.NEST_INSTALL }}/lib/nest bencher run \
+ --project nestml \
+ --token '${{ secrets.BENCHER_API_TOKEN }}' \
+ --branch '${{ github.event.number }}/merge' \
+ --branch-start-point '${{ github.base_ref }}' \
+ --branch-start-point-hash '${{ github.event.pull_request.base.sha }}' \
+ --branch-reset \
+ --github-actions "${{ secrets.GITHUB_TOKEN }}" \
+ --testbed ubuntu-latest \
+ --adapter python_pytest \
+ --file results.json \
+ --err \
+ 'LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${{ env.NEST_INSTALL }}/lib/nest python3 -m pytest --benchmark-json results.json -s $GITHUB_WORKSPACE/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py'
+
+ - name: Setup tmate session
+ if: ${{ failure() }}
+ uses: mxschmitt/action-tmate@v3
diff --git a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb
index 147f5b4a1..3922d0ac8 100644
--- a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb
+++ b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb
@@ -1347,7 +1347,7 @@
" NESTCodeGeneratorUtils.generate_code_for(nestml_neuron_model,\n",
" nestml_synapse_model,\n",
" codegen_opts=codegen_opts,\n",
- " logging_level=\"INFO\") # try \"INFO\" or \"DEBUG\" for more debug information"
+ " logging_level=\"WARNING\") # try \"INFO\" or \"DEBUG\" for more debug information"
]
},
{
diff --git a/pynestml/cocos/co_co_all_variables_defined.py b/pynestml/cocos/co_co_all_variables_defined.py
index e41b0727e..38cfa89ab 100644
--- a/pynestml/cocos/co_co_all_variables_defined.py
+++ b/pynestml/cocos/co_co_all_variables_defined.py
@@ -41,11 +41,10 @@ class CoCoAllVariablesDefined(CoCo):
"""
@classmethod
- def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
+ def check_co_co(cls, node: ASTModel):
"""
Checks if this coco applies for the handed over neuron. Models which contain undefined variables are not correct.
:param node: a single neuron instance.
- :param after_ast_rewrite: indicates whether this coco is checked after the code generator has done rewriting of the abstract syntax tree. If True, checks are not as rigorous. Use False where possible.
"""
# for each variable in all expressions, check if the variable has been defined previously
expression_collector_visitor = ASTExpressionCollectorVisitor()
@@ -62,32 +61,6 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# test if the symbol has been defined at least
if symbol is None:
- if after_ast_rewrite: # after ODE-toolbox transformations, convolutions are replaced by state variables, so cannot perform this check properly
- symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
- if symbol2 is not None:
- # an inline expression defining this variable name (ignoring differential order) exists
- if "__X__" in str(symbol2): # if this variable was the result of a convolution...
- continue
- else:
- # for kernels, also allow derivatives of that kernel to appear
-
- inline_expr_names = []
- inline_exprs = []
- for equations_block in node.get_equations_blocks():
- inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()])
- inline_exprs.extend(equations_block.get_inline_expressions())
-
- if var.get_name() in inline_expr_names:
- inline_expr_idx = inline_expr_names.index(var.get_name())
- inline_expr = inline_exprs[inline_expr_idx]
- from pynestml.utils.ast_utils import ASTUtils
- if ASTUtils.inline_aliases_convolution(inline_expr):
- symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE)
- if symbol2 is not None:
- # actually, no problem detected, skip error
- # XXX: TODO: check that differential order is less than or equal to that of the kernel
- continue
-
# check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV"
symbol2 = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE)
if symbol2 is not None:
@@ -106,9 +79,14 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False):
# in this case its ok if it is recursive or defined later on
continue
+ if symbol.is_predefined:
+ continue
+
+ if symbol.block_type == BlockType.LOCAL and symbol.get_referenced_object().get_source_position().before(var.get_source_position()):
+ continue
+
# check if it has been defined before usage, except for predefined symbols, input ports and variables added by the AST transformation functions
- if (not symbol.is_predefined) \
- and symbol.block_type != BlockType.INPUT \
+ if symbol.block_type != BlockType.INPUT \
and not symbol.get_referenced_object().get_source_position().is_added_source_position():
# except for parameters, those can be defined after
if ((not symbol.get_referenced_object().get_source_position().before(var.get_source_position()))
diff --git a/pynestml/cocos/co_co_function_unique.py b/pynestml/cocos/co_co_function_unique.py
index 15643c0ad..bf0f2be60 100644
--- a/pynestml/cocos/co_co_function_unique.py
+++ b/pynestml/cocos/co_co_function_unique.py
@@ -65,4 +65,5 @@ def check_co_co(cls, model: ASTModel):
log_level=LoggingLevel.ERROR,
message=message, code=code)
checked.append(funcA)
+
checked_funcs_names.append(func.get_name())
diff --git a/pynestml/cocos/co_co_illegal_expression.py b/pynestml/cocos/co_co_illegal_expression.py
index b78396e3b..c362d0dc5 100644
--- a/pynestml/cocos/co_co_illegal_expression.py
+++ b/pynestml/cocos/co_co_illegal_expression.py
@@ -18,13 +18,13 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.utils.logging_helper import LoggingHelper
from pynestml.utils.messages import Messages
diff --git a/pynestml/cocos/co_co_nest_random_functions_legally_used.py b/pynestml/cocos/co_co_nest_random_functions_legally_used.py
new file mode 100644
index 000000000..81e2fc464
--- /dev/null
+++ b/pynestml/cocos/co_co_nest_random_functions_legally_used.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+#
+# co_co_nest_random_functions_legally_used.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.meta_model.ast_node import ASTNode
+from pynestml.meta_model.ast_on_condition_block import ASTOnConditionBlock
+from pynestml.meta_model.ast_on_receive_block import ASTOnReceiveBlock
+from pynestml.meta_model.ast_update_block import ASTUpdateBlock
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.messages import Messages
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class CoCoNestRandomFunctionsLegallyUsed(CoCo):
+ """
+ This CoCo ensure that the random functions are used only in the ``update``, ``onReceive``, and ``onCondition`` blocks.
+ This CoCo is only checked for the NEST Simulator target.
+ """
+
+ @classmethod
+ def check_co_co(cls, node: ASTNode):
+ """
+ Checks the coco.
+ :param node: a single node (typically, a neuron or synapse)
+ """
+ visitor = CoCoNestRandomFunctionsLegallyUsedVisitor()
+ visitor.neuron = node
+ node.accept(visitor)
+
+
+class CoCoNestRandomFunctionsLegallyUsedVisitor(ASTVisitor):
+ def visit_function_call(self, node):
+ """
+ Visits a function call
+ :param node: a function call
+ """
+ function_name = node.get_name()
+ if function_name == PredefinedFunctions.RANDOM_NORMAL or function_name == PredefinedFunctions.RANDOM_UNIFORM \
+ or function_name == PredefinedFunctions.RANDOM_POISSON:
+ parent = node
+ while parent:
+ parent = parent.get_parent()
+
+ if isinstance(parent, ASTUpdateBlock) or isinstance(parent, ASTOnReceiveBlock) \
+ or isinstance(parent, ASTOnConditionBlock):
+ # the random function is correctly defined, hence return
+ return
+
+ if isinstance(parent, ASTModel):
+ # the random function is defined in other blocks (parameters, state, internals). Hence, an error.
+ code, message = Messages.get_random_functions_legally_used(function_name)
+ Logger.log_message(node=self.neuron, code=code, message=message, error_position=node.get_source_position(),
+ log_level=LoggingLevel.ERROR)
diff --git a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
index 18b862292..e318ae566 100644
--- a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
+++ b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py
@@ -22,11 +22,14 @@
from typing import List
from pynestml.cocos.co_co import CoCo
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_external_variable import ASTExternalVariable
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_kernel import ASTKernel
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_variable import ASTVariable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.symbol import SymbolKind
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -89,24 +92,44 @@ def visit_variable(self, node: ASTNode):
if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()):
code, message = Messages.get_no_variable_found(kernelName)
Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR)
+
continue
+
if not symbol.is_kernel():
continue
+
if node.get_complete_name() == kernelName:
- parent = node.get_parent()
- if parent is not None:
+ parent = node
+ correct = False
+ while parent is not None and not isinstance(parent, ASTModel):
+ parent = parent.get_parent()
+ assert parent is not None
+
+ if isinstance(parent, ASTDeclaration):
+ for lhs_var in parent.get_variables():
+ if kernelName == lhs_var.get_complete_name():
+ # kernel name appears on lhs of declaration, assume it is initial state
+ correct = True
+ parent = None # break out of outer loop
+ break
+
if isinstance(parent, ASTKernel):
- continue
- grandparent = parent.get_parent()
- if grandparent is not None and isinstance(grandparent, ASTFunctionCall):
- grandparent_func_name = grandparent.get_name()
- if grandparent_func_name == 'convolve':
- continue
- code, message = Messages.get_kernel_outside_convolve(kernelName)
- Logger.log_message(code=code,
- message=message,
- log_level=LoggingLevel.ERROR,
- error_position=node.get_source_position())
+ # kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'``
+ correct = True
+ break
+
+ if isinstance(parent, ASTFunctionCall):
+ func_name = parent.get_name()
+ if func_name == PredefinedFunctions.CONVOLVE:
+ # kernel name is used inside convolve call
+ correct = True
+
+ if not correct:
+ code, message = Messages.get_kernel_outside_convolve(kernelName)
+ Logger.log_message(code=code,
+ message=message,
+ log_level=LoggingLevel.ERROR,
+ error_position=node.get_source_position())
class KernelCollectingVisitor(ASTVisitor):
diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py
index 4ef08c0ec..51308f2cc 100644
--- a/pynestml/cocos/co_co_v_comp_exists.py
+++ b/pynestml/cocos/co_co_v_comp_exists.py
@@ -43,9 +43,6 @@ def check_co_co(cls, neuron: ASTModel):
Models which are supposed to be compartmental but do not contain
state variable called v_comp are not correct.
:param neuron: a single neuron instance.
- :param after_ast_rewrite: indicates whether this coco is checked
- after the code generator has done rewriting of the abstract syntax tree.
- If True, checks are not as rigorous. Use False where possible.
"""
from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator
diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py
index 01d008890..c90ffa2b1 100644
--- a/pynestml/cocos/co_cos_manager.py
+++ b/pynestml/cocos/co_cos_manager.py
@@ -23,11 +23,9 @@
from pynestml.cocos.co_co_all_variables_defined import CoCoAllVariablesDefined
from pynestml.cocos.co_co_inline_expression_not_assigned_to import CoCoInlineExpressionNotAssignedTo
-from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo
from pynestml.cocos.co_co_cm_channel_model import CoCoCmChannelModel
from pynestml.cocos.co_co_cm_continuous_input_model import CoCoCmContinuousInputModel
from pynestml.cocos.co_co_convolve_cond_correctly_built import CoCoConvolveCondCorrectlyBuilt
-from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter
from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo
from pynestml.cocos.co_co_integrate_odes_params_correct import CoCoIntegrateODEsParamsCorrect
from pynestml.cocos.co_co_correct_numerator_of_unit import CoCoCorrectNumeratorOfUnit
@@ -43,6 +41,7 @@
from pynestml.cocos.co_co_invariant_is_boolean import CoCoInvariantIsBoolean
from pynestml.cocos.co_co_kernel_type import CoCoKernelType
from pynestml.cocos.co_co_model_name_unique import CoCoModelNameUnique
+from pynestml.cocos.co_co_nest_random_functions_legally_used import CoCoNestRandomFunctionsLegallyUsed
from pynestml.cocos.co_co_no_kernels_except_in_convolve import CoCoNoKernelsExceptInConvolve
from pynestml.cocos.co_co_no_nest_name_space_collision import CoCoNoNestNameSpaceCollision
from pynestml.cocos.co_co_no_duplicate_compilation_unit_names import CoCoNoDuplicateCompilationUnitNames
@@ -69,6 +68,7 @@
from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified
from pynestml.meta_model.ast_model import ASTModel
from pynestml.frontend.frontend_configuration import FrontendConfiguration
+from pynestml.utils.logger import Logger
class CoCosManager:
@@ -123,12 +123,12 @@ def check_state_variables_initialized(cls, model: ASTModel):
CoCoStateVariablesInitialized.check_co_co(model)
@classmethod
- def check_variables_defined_before_usage(cls, model: ASTModel, after_ast_rewrite: bool) -> None:
+ def check_variables_defined_before_usage(cls, model: ASTModel) -> None:
"""
Checks that all variables are defined before being used.
:param model: a single model.
"""
- CoCoAllVariablesDefined.check_co_co(model, after_ast_rewrite)
+ CoCoAllVariablesDefined.check_co_co(model)
@classmethod
def check_v_comp_requirement(cls, neuron: ASTModel):
@@ -402,17 +402,27 @@ def check_input_port_size_type(cls, model: ASTModel):
CoCoVectorInputPortsCorrectSizeType.check_co_co(model)
@classmethod
- def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False):
+ def check_co_co_nest_random_functions_legally_used(cls, model: ASTModel):
+ """
+ Checks if the random number functions are used only in the update block.
+ :param model: a single model object.
+ """
+ CoCoNestRandomFunctionsLegallyUsed.check_co_co(model)
+
+ @classmethod
+ def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False):
"""
Checks all context conditions.
:param model: a single model object.
"""
+ Logger.set_current_node(model)
+
cls.check_each_block_defined_at_most_once(model)
cls.check_function_defined(model)
cls.check_variables_unique_in_scope(model)
cls.check_inline_expression_not_assigned_to(model)
cls.check_state_variables_initialized(model)
- cls.check_variables_defined_before_usage(model, after_ast_rewrite)
+ cls.check_variables_defined_before_usage(model)
if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL':
# XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead.
cls.check_v_comp_requirement(model)
@@ -452,3 +462,5 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo
cls.check_co_co_priorities_correctly_specified(model)
cls.check_resolution_func_legally_used(model)
cls.check_input_port_size_type(model)
+
+ Logger.set_current_node(None)
diff --git a/pynestml/codegeneration/builder.py b/pynestml/codegeneration/builder.py
index 2e6757c1a..a9f98bf58 100644
--- a/pynestml/codegeneration/builder.py
+++ b/pynestml/codegeneration/builder.py
@@ -20,12 +20,12 @@
# along with NEST. If not, see .
from __future__ import annotations
-import subprocess
-import os
from typing import Any, Mapping, Optional
from abc import ABCMeta, abstractmethod
+import os
+import subprocess
from pynestml.exceptions.invalid_target_exception import InvalidTargetException
from pynestml.frontend.frontend_configuration import FrontendConfiguration
diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py
index 0551e9a6e..cc3c91b93 100644
--- a/pynestml/codegeneration/nest_code_generator.py
+++ b/pynestml/codegeneration/nest_code_generator.py
@@ -28,6 +28,7 @@
import pynestml
from pynestml.cocos.co_co_nest_synapse_delay_not_assigned_to import CoCoNESTSynapseDelayNotAssignedTo
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.codegeneration.code_generator import CodeGenerator
from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils
from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper
@@ -172,21 +173,30 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None):
self.setup_printers()
def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: Sequence[ASTModel]):
- for synapse in synapses:
- synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), FrontendConfiguration.suffix)
+ for model in neurons + synapses:
+ # Check if the random number functions are used in the right blocks
+ CoCosManager.check_co_co_nest_random_functions_legally_used(model)
+
+ if Logger.has_errors(model):
+ raise Exception("Error(s) occurred during code generation")
+
+ if self.get_option("neuron_synapse_pairs"):
+ for model in synapses:
+ synapse_name_stripped = removesuffix(removesuffix(model.name.split("_with_")[0], "_"),
+ FrontendConfiguration.suffix)
+ # special case for NEST delay variable (state or parameter)
+ assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
+ assert ASTUtils.get_variable_by_name(model, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
- # special case for NEST delay variable (state or parameter)
- assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
- assert ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
+ # special case for NEST weight variable (state or parameter)
+ assert synapse_name_stripped in self.get_option("weight_variable").keys(), "Please specify a weight variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
+ assert ASTUtils.get_variable_by_name(model, self.get_option("weight_variable")[synapse_name_stripped]), "Weight variable '" + self.get_option("weight_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
- # special case for NEST weight variable (state or parameter)
- assert synapse_name_stripped in self.get_option("weight_variable").keys(), "Please specify a weight variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
- assert ASTUtils.get_variable_by_name(synapse, self.get_option("weight_variable")[synapse_name_stripped]), "Weight variable '" + self.get_option("weight_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)"
+ if self.option_exists("delay_variable") and synapse_name_stripped in self.get_option("delay_variable").keys():
+ delay_variable = self.get_option("delay_variable")[synapse_name_stripped]
+ CoCoNESTSynapseDelayNotAssignedTo.check_co_co(delay_variable, model)
- if self.option_exists("delay_variable") and synapse_name_stripped in self.get_option("delay_variable").keys():
- delay_variable = self.get_option("delay_variable")[synapse_name_stripped]
- CoCoNESTSynapseDelayNotAssignedTo.check_co_co(delay_variable, synapse)
- if Logger.has_errors(synapse):
+ if Logger.has_errors(model):
raise Exception("Error(s) occurred during code generation")
def setup_printers(self):
@@ -374,6 +384,9 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
if not used_in_eq:
self.non_equations_state_variables[neuron.get_name()].append(var)
+ # cache state variables before symbol table update for the sake of delay variables
+ state_vars_before_update = neuron.get_state_symbols()
+
ASTUtils.remove_initial_values_for_kernels(neuron)
kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron)
ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver])
@@ -388,7 +401,6 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
neuron = ASTUtils.add_declarations_to_internals(
neuron, self.analytic_solver[neuron.get_name()]["propagators"])
- state_vars_before_update = neuron.get_state_symbols()
self.update_symbol_table(neuron)
# Update the delay parameter parameters after symbol table update
@@ -898,8 +910,8 @@ def update_symbol_table(self, neuron) -> None:
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
- symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
+ CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())
def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]:
diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py
index 4711bc497..00f061775 100644
--- a/pynestml/codegeneration/nest_compartmental_code_generator.py
+++ b/pynestml/codegeneration/nest_compartmental_code_generator.py
@@ -18,14 +18,18 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-import shutil
+
from typing import Any, Dict, List, Mapping, Optional
import datetime
import os
from jinja2 import TemplateRuntimeError
+
+from odetoolbox import analysis
+
import pynestml
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.codegeneration.code_generator import CodeGenerator
from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper
from pynestml.codegeneration.nest_declarations_helper import NestDeclarationsHelper
@@ -53,9 +57,9 @@
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbol_table.symbol_table import SymbolTable
from pynestml.symbols.symbol import SymbolKind
+from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter
from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory
-from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer
from pynestml.utils.mechanism_processing import MechanismProcessing
from pynestml.utils.channel_processing import ChannelProcessing
from pynestml.utils.concentration_processing import ConcentrationProcessing
@@ -72,7 +76,6 @@
from pynestml.utils.synapse_processing import SynapseProcessing
from pynestml.visitors.ast_random_number_generator_visitor import ASTRandomNumberGeneratorVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
-from odetoolbox import analysis
class NESTCompartmentalCodeGenerator(CodeGenerator):
@@ -740,8 +743,8 @@ def update_symbol_table(self, neuron, kernel_buffers):
"""
SymbolTable.delete_model_scope(neuron.get_name())
symbol_table_visitor = ASTSymbolTableVisitor()
- symbol_table_visitor.after_ast_rewrite_ = True
neuron.accept(symbol_table_visitor)
+ CoCosManager.check_cocos(neuron, after_ast_rewrite=True)
SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope())
def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]:
diff --git a/pynestml/codegeneration/python_standalone_code_generator.py b/pynestml/codegeneration/python_standalone_code_generator.py
index f44123743..d6afaa095 100644
--- a/pynestml/codegeneration/python_standalone_code_generator.py
+++ b/pynestml/codegeneration/python_standalone_code_generator.py
@@ -111,7 +111,6 @@ def setup_printers(self):
# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
- print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=PythonSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer,
constant_printer=self._constant_printer,
diff --git a/pynestml/codegeneration/spinnaker_code_generator.py b/pynestml/codegeneration/spinnaker_code_generator.py
index 2a8fed7de..dce247e9c 100644
--- a/pynestml/codegeneration/spinnaker_code_generator.py
+++ b/pynestml/codegeneration/spinnaker_code_generator.py
@@ -137,7 +137,6 @@ def setup_printers(self):
# GSL printers
self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None)
- print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer))
self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None)
self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=SpinnakerPythonSimpleExpressionPrinter(
variable_printer=self._gsl_variable_printer,
@@ -216,6 +215,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
+ CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)
self.codegen_cpp.generate_code(cloned_models)
@@ -224,6 +224,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None:
for model in models:
cloned_model = model.clone()
cloned_model.accept(ASTSymbolTableVisitor())
+ CoCosManager.check_cocos(cloned_model)
cloned_models.append(cloned_model)
self.codegen_py.generate_code(cloned_models)
diff --git a/pynestml/frontend/frontend_configuration.py b/pynestml/frontend/frontend_configuration.py
index 173534c95..aae1fc29a 100644
--- a/pynestml/frontend/frontend_configuration.py
+++ b/pynestml/frontend/frontend_configuration.py
@@ -244,8 +244,8 @@ def handle_module_name(cls, module_name):
@classmethod
def handle_target_platform(cls, target_platform: Optional[str]):
- if target_platform is None or target_platform.upper() == 'NONE':
- target_platform = '' # make sure `target_platform` is always a string
+ if target_platform is None:
+ target_platform = "NONE" # make sure `target_platform` is always a string
from pynestml.frontend.pynestml_frontend import get_known_targets
diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py
index c257822de..ca3866619 100644
--- a/pynestml/frontend/pynestml_frontend.py
+++ b/pynestml/frontend/pynestml_frontend.py
@@ -41,6 +41,8 @@
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.utils.model_parser import ModelParser
+from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
def get_known_targets():
@@ -131,10 +133,10 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[
return SpiNNakerCodeGenerator(options)
if target_name.upper() == "NONE":
- # dummy/null target: user requested to not generate any code
+ # dummy/null target: user requested to not generate any code (for instance, when just doing validation of a model)
code, message = Messages.get_no_code_generated()
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
- return CodeGenerator("", options)
+ return CodeGenerator(options)
# cannot reach here due to earlier assert -- silence static checker warnings
assert "Unknown code generator requested: " + target_name
@@ -193,12 +195,17 @@ def generate_target(input_path: Union[str, Sequence[str]], target_platform: str,
Enable development mode: code generation is attempted even for models that contain errors, and extra information is rendered in the generated code.
codegen_opts : Optional[Mapping[str, Any]]
A dictionary containing additional options for the target code generator.
+
+ Return
+ ------
+ errors_occurred
+ Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models.
"""
configure_front_end(input_path, target_platform, target_path, install_path, logging_level,
module_name, store_log, suffix, dev, codegen_opts)
- if not process() == 0:
- raise Exception("Error(s) occurred while processing the model")
+
+ return process()
def configure_front_end(input_path: Union[str, Sequence[str]], target_platform: str, target_path=None,
@@ -373,34 +380,36 @@ def generate_nest_compartmental_target(input_path: Union[str, Sequence[str]], ta
def main() -> int:
- """
+ r"""
Entry point for the command-line application.
Returns
-------
- The process exit code: 0 for success, > 0 for failure
+ exit_code
+ The process exit code: 0 for success, > 0 for failure
"""
try:
FrontendConfiguration.parse_config(sys.argv[1:])
except InvalidPathException as e:
print(e)
+
return 1
+
# the default Python recursion limit is 1000, which might not be enough in practice when running an AST visitor on a deep tree, e.g. containing an automatically generated expression
sys.setrecursionlimit(10000)
+
# after all argument have been collected, start the actual processing
return int(process())
-def get_parsed_models():
+def get_parsed_models() -> List[ASTModel]:
r"""
Handle the parsing and validation of the NESTML files
Returns
-------
- models: Sequence[ASTModel]
+ models
List of correctly parsed models
- errors_occurred : bool
- Flag indicating whether errors occurred during processing
"""
# init log dir
create_report_dir()
@@ -417,36 +426,29 @@ def get_parsed_models():
for nestml_file in nestml_files:
parsed_unit = ModelParser.parse_file(nestml_file)
- if parsed_unit is None:
- # Parsing error in the NESTML model, return True
- return [], True
-
- compilation_units.append(parsed_unit)
+ if parsed_unit:
+ compilation_units.append(parsed_unit)
- if len(compilation_units) > 0:
- # generate a list of all models
- models: Sequence[ASTModel] = []
- for compilationUnit in compilation_units:
- models.extend(compilationUnit.get_model_list())
+ # generate a list of all models
+ models: Sequence[ASTModel] = []
+ for compilation_unit in compilation_units:
+ CoCosManager.check_model_names_unique(compilation_unit)
+ models.extend(compilation_unit.get_model_list())
- # check that no models with duplicate names have been defined
- CoCosManager.check_no_duplicate_compilation_unit_names(models)
+ # check that no models with duplicate names have been defined
+ CoCosManager.check_no_duplicate_compilation_unit_names(models)
- # now exclude those which are broken, i.e. have errors.
- for model in models:
- if Logger.has_errors(model):
- code, message = Messages.get_model_contains_errors(model.get_name())
- Logger.log_message(node=model, code=code, message=message,
- error_position=model.get_source_position(),
- log_level=LoggingLevel.WARNING)
- return [model], True
+ for model in models:
+ model.accept(ASTParentVisitor())
+ model.accept(ASTSymbolTableVisitor())
- return models, False
+ return models
def transform_models(transformers, models):
for transformer in transformers:
models = transformer.transform(models)
+
return models
@@ -454,44 +456,65 @@ def generate_code(code_generators, models):
code_generators.generate_code(models)
-def process():
+def process() -> bool:
r"""
The main toolchain workflow entry point. For all models: parse, validate, transform, generate code and build.
- Returns
- -------
- errors_occurred : bool
- Flag indicating whether errors occurred during processing
+ Return
+ ------
+ errors_occurred
+ Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models.
"""
- # initialize and set options for transformers, code generator and builder
- codegen_and_builder_opts = FrontendConfiguration.get_codegen_opts()
-
- transformers, codegen_and_builder_opts = transformers_from_target_name(FrontendConfiguration.get_target_platform(),
- options=codegen_and_builder_opts)
+ # initialise model transformers
+ transformers, unused_opts_transformer = transformers_from_target_name(FrontendConfiguration.get_target_platform(),
+ options=FrontendConfiguration.get_codegen_opts())
+ # initialise code generator
code_generator = code_generator_from_target_name(FrontendConfiguration.get_target_platform())
- codegen_and_builder_opts = code_generator.set_options(codegen_and_builder_opts)
+ unused_opts_codegen = code_generator.set_options(FrontendConfiguration.get_codegen_opts())
+
+ # initialise builder
+ _builder, unused_opts_builder = builder_from_target_name(FrontendConfiguration.get_target_platform(),
+ options=FrontendConfiguration.get_codegen_opts())
+
+ # check for unused codegen options
+ for opt_key in FrontendConfiguration.get_codegen_opts().keys():
+ if opt_key in unused_opts_transformer.keys() and opt_key in unused_opts_codegen.keys() and opt_key in unused_opts_builder.keys():
+ raise CodeGeneratorOptionsException("The code generator option \"" + opt_key + "\" does not exist.")
+
+ models = get_parsed_models()
- _builder, codegen_and_builder_opts = builder_from_target_name(FrontendConfiguration.get_target_platform(), options=codegen_and_builder_opts)
+ # validation -- check cocos for models that do not have errors already
+ excluded_models = []
+ for model in models:
+ if Logger.has_errors(model.name):
+ code, message = Messages.get_model_contains_errors(model.get_name())
+ Logger.log_message(node=model, code=code, message=message,
+ error_position=model.get_source_position(),
+ log_level=LoggingLevel.WARNING)
+ excluded_models.append(model)
+ else:
+ CoCosManager.check_cocos(model)
- if len(codegen_and_builder_opts) > 0:
- raise CodeGeneratorOptionsException("The code generator option(s) \"" + ", ".join(codegen_and_builder_opts.keys()) + "\" do not exist.")
+ # exclude models that have errors
+ models = list(set(models) - set(excluded_models))
- models, errors_occurred = get_parsed_models()
+ # transformation(s)
+ models = transform_models(transformers, models)
- if not errors_occurred:
- models = transform_models(transformers, models)
- generate_code(code_generator, models)
+ # generate code
+ generate_code(code_generator, models)
- # perform build
- if _builder is not None:
- _builder.build()
+ # perform build
+ if _builder is not None:
+ _builder.build()
if FrontendConfiguration.store_log:
store_log_to_file()
- return errors_occurred
+ # return a boolean indicating whether errors occurred
+ return len(Logger.get_all_messages_of_level(LoggingLevel.ERROR)) > 0
def init_predefined():
diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py
index 834e56897..c4b7374bf 100644
--- a/pynestml/meta_model/ast_model.py
+++ b/pynestml/meta_model/ast_model.py
@@ -459,23 +459,27 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) -
Adds the handed over declaration the internals block
:param declaration: a single declaration
"""
- assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now"
from pynestml.utils.ast_utils import ASTUtils
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+
+ assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now"
+
if not self.get_internals_blocks():
ASTUtils.create_internal_block(self)
+
n_declarations = len(self.get_internals_blocks()[0].get_declarations())
if n_declarations == 0:
index = 0
else:
index = 1 + (index % len(self.get_internals_blocks()[0].get_declarations()))
+
self.get_internals_blocks()[0].get_declarations().insert(index, declaration)
declaration.update_scope(self.get_internals_blocks()[0].get_scope())
- from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
- from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.INTERNALS)
- declaration.accept(symtable_vistor)
- self.get_internals_blocks()[0].accept(ASTParentVisitor())
+ self.accept(ASTParentVisitor())
+ self.accept(symtable_vistor)
symtable_vistor.block_type_stack.pop()
def add_to_state_block(self, declaration: ASTDeclaration) -> None:
@@ -483,24 +487,26 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None:
Adds the handed over declaration to an arbitrary state block. A state block will be created if none exists.
:param declaration: a single declaration.
"""
- assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now"
+ from pynestml.symbols.symbol import SymbolKind
from pynestml.utils.ast_utils import ASTUtils
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+
+ assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now"
+
if not self.get_state_blocks():
ASTUtils.create_state_block(self)
+
self.get_state_blocks()[0].get_declarations().append(declaration)
declaration.update_scope(self.get_state_blocks()[0].get_scope())
- from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
- from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
symtable_vistor = ASTSymbolTableVisitor()
symtable_vistor.block_type_stack.push(BlockType.STATE)
- declaration.accept(symtable_vistor)
- self.get_state_blocks()[0].accept(ASTParentVisitor())
+ self.accept(ASTParentVisitor())
+ self.accept(symtable_vistor)
symtable_vistor.block_type_stack.pop()
- from pynestml.symbols.symbol import SymbolKind
- assert declaration.get_variables()[0].get_scope().resolve_to_symbol(
- declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
- assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(),
- SymbolKind.VARIABLE) is not None
+
+ assert declaration.get_variables()[0].get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
+ assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None
def print_comment(self, prefix: str = "") -> str:
"""
@@ -566,7 +572,6 @@ def get_spike_input_port_names(self) -> List[str]:
"""
Returns a list of all spike input ports defined in the model.
"""
- print("get_spike_input_port_names = " + str([port.get_symbol_name() for port in self.get_spike_input_ports()]))
return [port.get_symbol_name() for port in self.get_spike_input_ports()]
def get_continuous_input_ports(self) -> List[VariableSymbol]:
diff --git a/pynestml/symbols/symbol.py b/pynestml/symbols/symbol.py
index 1e294566b..c73435c6d 100644
--- a/pynestml/symbols/symbol.py
+++ b/pynestml/symbols/symbol.py
@@ -18,8 +18,8 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from abc import ABCMeta, abstractmethod
+from abc import ABCMeta, abstractmethod
from enum import Enum
diff --git a/pynestml/symbols/type_symbol.py b/pynestml/symbols/type_symbol.py
index 7047cdbca..a3eb28a12 100644
--- a/pynestml/symbols/type_symbol.py
+++ b/pynestml/symbols/type_symbol.py
@@ -18,11 +18,11 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+
from abc import ABCMeta, abstractmethod
from pynestml.symbols.symbol import Symbol
from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.messages import Messages
class TypeSymbol(Symbol):
@@ -198,6 +198,7 @@ def is_castable_to(self, _other_type):
def binary_operation_not_defined_error(self, _operator, _other):
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
+ from pynestml.utils.messages import Messages
result = ErrorTypeSymbol()
code, message = Messages.get_binary_operation_not_defined(
lhs=self.print_nestml_type(), operator=_operator, rhs=_other.print_nestml_type())
@@ -208,6 +209,7 @@ def binary_operation_not_defined_error(self, _operator, _other):
def unary_operation_not_defined_error(self, _operator):
from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
result = ErrorTypeSymbol()
+ from pynestml.utils.messages import Messages
code, message = Messages.get_unary_operation_not_defined(_operator,
self.print_symbol())
Logger.log_message(code=code, message=message, error_position=self.referenced_object.get_source_position(),
@@ -226,6 +228,7 @@ def inverse_of_unit(cls, other):
return result
def warn_implicit_cast_from_to(self, _from, _to):
+ from pynestml.utils.messages import Messages
code, message = Messages.get_implicit_cast_rhs_to_lhs(_to.print_symbol(), _from.print_symbol())
Logger.log_message(code=code, message=message,
error_position=self.get_referenced_object().get_source_position(),
diff --git a/pynestml/symbols/unit_type_symbol.py b/pynestml/symbols/unit_type_symbol.py
index 37c43b035..1f9977de0 100644
--- a/pynestml/symbols/unit_type_symbol.py
+++ b/pynestml/symbols/unit_type_symbol.py
@@ -19,6 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+from typing import Optional
from pynestml.symbols.type_symbol import TypeSymbol
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
@@ -131,12 +132,12 @@ def __sub__(self, other):
def add_or_sub_another_unit(self, other):
if self.equals(other):
return other
- else:
- return self.attempt_magnitude_cast(other)
+
+ return self.attempt_magnitude_cast(other)
def attempt_magnitude_cast(self, other):
if self.differs_only_in_magnitude(other):
- factor = UnitTypeSymbol.get_conversion_factor(self.astropy_unit, other.astropy_unit)
+ factor = UnitTypeSymbol.get_conversion_factor(other.astropy_unit, self.astropy_unit)
other.referenced_object.set_implicit_conversion_factor(factor)
code, message = Messages.get_implicit_magnitude_conversion(self, other, factor)
Logger.log_message(code=code, message=message,
@@ -144,18 +145,20 @@ def attempt_magnitude_cast(self, other):
log_level=LoggingLevel.INFO)
return self
- else:
- return self.binary_operation_not_defined_error('+/-', other)
- # TODO: change order of parameters to conform with the from_to scheme.
- # TODO: Also rename to reflect that, i.e. get_conversion_factor_from_to
+ return self.binary_operation_not_defined_error('+/-', other)
+
@classmethod
- def get_conversion_factor(cls, to, _from):
+ def get_conversion_factor(cls, _from, to) -> Optional[float]:
"""
- Calculates the conversion factor from _convertee_unit to target_unit.
- Behaviour is only well-defined if both units have the same physical base type
+ Calculates the conversion factor from _convertee_unit to target_unit. Behaviour is only well-defined if both units have the same physical base type.
"""
- factor = (_from / to).si.scale
+ try:
+ factor = (_from / to).si.scale
+ except BaseException:
+ # this can fail in case of e.g. trying to convert from "1/s" to "2/s"
+ return None
+
return factor
def is_castable_to(self, _other_type):
diff --git a/pynestml/transformers/assign_implicit_conversion_factors_transformer.py b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py
new file mode 100644
index 000000000..f44ee12d5
--- /dev/null
+++ b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py
@@ -0,0 +1,335 @@
+# -*- coding: utf-8 -*-
+#
+# assign_implicit_conversion_factors_transformer.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from typing import Sequence, Union
+
+from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
+from pynestml.meta_model.ast_declaration import ASTDeclaration
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
+from pynestml.meta_model.ast_node import ASTNode
+from pynestml.meta_model.ast_small_stmt import ASTSmallStmt
+from pynestml.meta_model.ast_stmt import ASTStmt
+from pynestml.symbols.error_type_symbol import ErrorTypeSymbol
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.symbol import SymbolKind
+from pynestml.symbols.template_type_symbol import TemplateTypeSymbol
+from pynestml.symbols.variadic_type_symbol import VariadicTypeSymbol
+from pynestml.transformers.transformer import Transformer
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.ast_utils import ASTUtils
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.logging_helper import LoggingHelper
+from pynestml.utils.messages import Messages
+from pynestml.utils.type_caster import TypeCaster
+from pynestml.visitors.ast_visitor import ASTVisitor
+
+
+class AssignImplicitConversionFactorsTransformer(Transformer):
+ r"""
+ Assign implicit conversion factors in expressions.
+ """
+
+ def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]:
+ single = False
+ if isinstance(models, ASTNode):
+ single = True
+ models = [models]
+
+ for model in models:
+ model.accept(AssignImplicitConversionFactorVisitor())
+ self.__assign_return_types(model)
+
+ if single:
+ return models[0]
+ return models
+
+ def __assign_return_types(self, _node):
+ for userDefinedFunction in _node.get_functions():
+ symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(),
+ SymbolKind.FUNCTION)
+ # first ensure that the block contains at least one statement
+ if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0:
+ # now check that the last statement is a return
+ self.__check_return_recursively(userDefinedFunction,
+ symbol.get_return_type(),
+ userDefinedFunction.get_block().get_stmts(),
+ False)
+ # now if it does not have a statement, but uses a return type, it is an error
+ elif symbol is not None and userDefinedFunction.has_return_type() and \
+ not symbol.get_return_type().equals(PredefinedTypes.get_void_type()):
+ code, message = Messages.get_no_return()
+ Logger.log_message(node=_node, code=code, message=message,
+ error_position=userDefinedFunction.get_source_position(),
+ log_level=LoggingLevel.ERROR)
+
+ def __check_return_recursively(self, processed_function, type_symbol=None, stmts=None, ret_defined: bool = False) -> None:
+ """
+ For a handed over statement, it checks if the statement is a return statement and if it is typed according to the handed over type symbol.
+ :param type_symbol: a single type symbol
+ :type type_symbol: type_symbol
+ :param stmts: a list of statements, either simple or compound
+ :type stmts: list(ASTSmallStmt,ASTCompoundStmt)
+ :param ret_defined: indicates whether a ret has already been defined after this block of stmt, thus is not
+ necessary. Implies that the return has been defined in the higher level block
+ """
+ # in order to ensure that in the sub-blocks, a return is not necessary, we check if the last one in this
+ # block is a return statement, thus it is not required to have a return in the sub-blocks, but optional
+ last_statement = stmts[len(stmts) - 1]
+ ret_defined = False or ret_defined
+ if (len(stmts) > 0 and isinstance(last_statement, ASTStmt)
+ and last_statement.is_small_stmt()
+ and last_statement.small_stmt.is_return_stmt()):
+ ret_defined = True
+
+ # now check that returns are there if necessary and correctly typed
+ for c_stmt in stmts:
+ if c_stmt.is_small_stmt():
+ stmt = c_stmt.small_stmt
+ else:
+ stmt = c_stmt.compound_stmt
+
+ # if it is a small statement, check if it is a return statement
+ if isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt():
+ # first check if the return is the last one in this block of statements
+ if stmts.index(c_stmt) != (len(stmts) - 1):
+ code, message = Messages.get_not_last_statement('Return')
+ Logger.log_message(error_position=stmt.get_source_position(),
+ code=code, message=message,
+ log_level=LoggingLevel.WARNING)
+
+ # now check that it corresponds to the declared type
+ if stmt.get_return_stmt().has_expression() and type_symbol is PredefinedTypes.get_void_type():
+ code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(),
+ stmt.get_return_stmt().get_expression().type)
+ Logger.log_message(error_position=stmt.get_source_position(),
+ message=message, code=code, log_level=LoggingLevel.ERROR)
+
+ # if it is not void check if the type corresponds to the one stated
+ if not stmt.get_return_stmt().has_expression() and \
+ not type_symbol.equals(PredefinedTypes.get_void_type()):
+ code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(),
+ type_symbol)
+ Logger.log_message(error_position=stmt.get_source_position(),
+ message=message, code=code, log_level=LoggingLevel.ERROR)
+
+ if stmt.get_return_stmt().has_expression():
+ type_of_return = stmt.get_return_stmt().get_expression().type
+ if isinstance(type_of_return, ErrorTypeSymbol):
+ code, message = Messages.get_type_could_not_be_derived(processed_function.get_name())
+ Logger.log_message(error_position=stmt.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
+ elif not type_of_return.equals(type_symbol):
+ TypeCaster.try_to_recover_or_error(type_symbol, type_of_return,
+ stmt.get_return_stmt().get_expression())
+ elif isinstance(stmt, ASTCompoundStmt):
+ # otherwise it is a compound stmt, thus check recursively
+ if stmt.is_if_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol,
+ stmt.get_if_stmt().get_if_clause().get_block().get_stmts(),
+ ret_defined)
+ for else_ifs in stmt.get_if_stmt().get_elif_clauses():
+ self.__check_return_recursively(processed_function,
+ type_symbol, else_ifs.get_block().get_stmts(), ret_defined)
+ if stmt.get_if_stmt().has_else_clause():
+ self.__check_return_recursively(processed_function,
+ type_symbol,
+ stmt.get_if_stmt().get_else_clause().get_block().get_stmts(),
+ ret_defined)
+ elif stmt.is_while_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol, stmt.get_while_stmt().get_block().get_stmts(),
+ ret_defined)
+ elif stmt.is_for_stmt():
+ self.__check_return_recursively(processed_function,
+ type_symbol, stmt.get_for_stmt().get_block().get_stmts(),
+ ret_defined)
+ # now, if a return statement has not been defined in the corresponding higher level block, we have to ensure that it is defined here
+ elif not ret_defined and stmts.index(c_stmt) == (len(stmts) - 1):
+ if not (isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt()):
+ code, message = Messages.get_no_return()
+ Logger.log_message(error_position=stmt.get_source_position(), log_level=LoggingLevel.ERROR,
+ code=code, message=message)
+
+
+class AssignImplicitConversionFactorVisitor(ASTVisitor):
+ """
+ This visitor checks that all expression correspond to the expected type.
+ """
+
+ def visit_declaration(self, node):
+ """
+ Visits a single declaration and asserts that type of lhs is equal to type of rhs.
+ :param node: a single declaration.
+ :type node: ASTDeclaration
+ """
+ assert isinstance(node, ASTDeclaration)
+ if node.has_expression():
+ if node.get_expression().get_source_position().equals(ASTSourceLocation.get_added_source_position()):
+ # no type checks are executed for added nodes, since we assume correctness
+ return
+ lhs_type = node.get_data_type().get_type_symbol()
+ rhs_type = node.get_expression().type
+ if isinstance(rhs_type, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+ if self.__types_do_not_match(lhs_type, rhs_type):
+ TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())
+
+ def visit_inline_expression(self, node):
+ """
+ Visits a single inline expression and asserts that type of lhs is equal to type of rhs.
+ """
+ assert isinstance(node, ASTInlineExpression)
+ lhs_type = node.get_data_type().get_type_symbol()
+ rhs_type = node.get_expression().type
+ if isinstance(rhs_type, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ if self.__types_do_not_match(lhs_type, rhs_type):
+ TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression())
+
+ def visit_assignment(self, node):
+ """
+ Visits a single expression and assures that type(lhs) == type(rhs).
+ :param node: a single assignment.
+ :type node: ASTAssignment
+ """
+ from pynestml.meta_model.ast_assignment import ASTAssignment
+ assert isinstance(node, ASTAssignment)
+
+ if node.get_source_position().equals(ASTSourceLocation.get_added_source_position()):
+ # no type checks are executed for added nodes, since we assume correctness
+ return
+ if node.is_direct_assignment: # case a = b is simple
+ self.handle_simple_assignment(node)
+ else:
+ self.handle_compound_assignment(node) # e.g. a *= b
+
+ def handle_compound_assignment(self, node):
+ rhs_expr = node.get_expression()
+ lhs_variable_symbol = node.get_variable().resolve_in_own_scope()
+ rhs_type_symbol = rhs_expr.type
+
+ if lhs_variable_symbol is None:
+ code, message = Messages.get_equation_var_not_in_state_block(node.get_variable().get_complete_name())
+ Logger.log_message(code=code, message=message, error_position=node.get_source_position(),
+ log_level=LoggingLevel.ERROR)
+ return
+
+ if isinstance(rhs_type_symbol, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ lhs_type_symbol = lhs_variable_symbol.get_type_symbol()
+
+ if node.is_compound_product:
+ if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol,
+ node.get_expression())
+ return
+ return
+
+ if node.is_compound_quotient:
+ if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol,
+ node.get_expression())
+ return
+ return
+
+ assert node.is_compound_sum or node.is_compound_minus
+ if self.__types_do_not_match(lhs_type_symbol, rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_type_symbol, rhs_type_symbol,
+ node.get_expression())
+
+ @staticmethod
+ def __types_do_not_match(lhs_type_symbol, rhs_type_symbol):
+ if lhs_type_symbol is None:
+ return True
+
+ return not lhs_type_symbol.equals(rhs_type_symbol)
+
+ def handle_simple_assignment(self, node):
+ from pynestml.symbols.symbol import SymbolKind
+ lhs_variable_symbol = node.get_scope().resolve_to_symbol(node.get_variable().get_complete_name(),
+ SymbolKind.VARIABLE)
+
+ rhs_type_symbol = node.get_expression().type
+ if isinstance(rhs_type_symbol, ErrorTypeSymbol):
+ LoggingHelper.drop_missing_type_error(node)
+ return
+
+ if lhs_variable_symbol is not None and self.__types_do_not_match(lhs_variable_symbol.get_type_symbol(),
+ rhs_type_symbol):
+ TypeCaster.try_to_recover_or_error(lhs_variable_symbol.get_type_symbol(), rhs_type_symbol,
+ node.get_expression())
+
+ def visit_function_call(self, node):
+ """
+ Check consistency for a single function call: check if the called function has been declared, whether the number and types of arguments correspond to the declaration, etc.
+
+ :param node: a single function call.
+ :type node: ASTFunctionCall
+ """
+ func_name = node.get_name()
+
+ if func_name == 'convolve':
+ return
+
+ symbol = node.get_scope().resolve_to_symbol(node.get_name(), SymbolKind.FUNCTION)
+
+ if symbol is None and ASTUtils.is_function_delay_variable(node):
+ return
+
+ # first check if the function has been declared
+ if symbol is None:
+ code, message = Messages.get_function_not_declared(node.get_name())
+ Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
+ code=code, message=message)
+ return
+
+ # check if the number of arguments is the same as in the symbol; accept anything for variadic types
+ is_variadic: bool = len(symbol.get_parameter_types()) == 1 and isinstance(symbol.get_parameter_types()[0], VariadicTypeSymbol)
+ if (not is_variadic) and len(node.get_args()) != len(symbol.get_parameter_types()):
+ code, message = Messages.get_wrong_number_of_args(str(node), len(symbol.get_parameter_types()),
+ len(node.get_args()))
+ Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
+ error_position=node.get_source_position())
+ return
+
+ # finally check if the call is correctly typed
+ expected_types = symbol.get_parameter_types()
+ actual_args = node.get_args()
+ actual_types = [arg.type for arg in actual_args]
+ for actual_arg, actual_type, expected_type in zip(actual_args, actual_types, expected_types):
+ if isinstance(actual_type, ErrorTypeSymbol):
+ code, message = Messages.get_type_could_not_be_derived(actual_arg)
+ Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
+ error_position=actual_arg.get_source_position())
+ return
+
+ if isinstance(expected_type, VariadicTypeSymbol):
+ # variadic type symbol accepts anything
+ return
+
+ if not actual_type.equals(expected_type) and not isinstance(expected_type, TemplateTypeSymbol):
+ TypeCaster.try_to_recover_or_error(expected_type, actual_type, actual_arg)
diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py
index 68cc70a62..a06260824 100644
--- a/pynestml/transformers/synapse_post_neuron_transformer.py
+++ b/pynestml/transformers/synapse_post_neuron_transformer.py
@@ -23,6 +23,7 @@
from typing import Any, Sequence, Mapping, Optional, Union
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
@@ -563,11 +564,6 @@ def mark_post_port(_expr=None):
# replace occurrences of the variables in expressions in the original synapse with calls to the corresponding neuron getters
#
- # make sure the moved symbols can be resolved in the scope of the neuron (that's where ``ASTExternalVariable._altscope`` will be pointing to)
- ast_symbol_table_visitor = ASTSymbolTableVisitor()
- ast_symbol_table_visitor.after_ast_rewrite_ = True
- new_neuron.accept(ast_symbol_table_visitor)
-
Logger.log_message(
None, -1, "In synapse: replacing variables with suffixed external variable references", None, LoggingLevel.INFO)
for state_var in syn_to_neuron_state_vars:
@@ -609,7 +605,6 @@ def mark_post_port(_expr=None):
new_neuron.accept(ASTParentVisitor())
new_synapse.accept(ASTParentVisitor())
ast_symbol_table_visitor = ASTSymbolTableVisitor()
- ast_symbol_table_visitor.after_ast_rewrite_ = True
new_neuron.accept(ast_symbol_table_visitor)
new_synapse.accept(ast_symbol_table_visitor)
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index d3d6f6ef5..a3983694d 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -28,7 +28,6 @@
from pynestml.codegeneration.printers.ast_printer import ASTPrinter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
-from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.meta_model.ast_assignment import ASTAssignment
@@ -66,7 +65,6 @@
from pynestml.utils.messages import Messages
from pynestml.utils.string_utils import removesuffix
from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor
-from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_visitor import ASTVisitor
@@ -1766,10 +1764,12 @@ def remove_initial_values_for_kernels(cls, model: ASTModel) -> None:
@classmethod
def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None:
"""
- Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model
- before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox.
+ Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox.
"""
from pynestml.utils.model_parser import ModelParser
+ from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
+
assert len(model.get_equations_blocks()) == 1, "Only one equation block should be present"
if not model.get_state_blocks():
@@ -1782,10 +1782,6 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict
if cls.is_ode_variable(var.get_name(), model):
assert cls.variable_in_solver(cls.to_ode_toolbox_processed_name(var_name), solver_dicts)
- # replace the left-hand side variable name by the ode-toolbox format
- var.set_name(cls.to_ode_toolbox_processed_name(var.get_complete_name()))
- var.set_differential_order(0)
-
# replace the defining expression by the ode-toolbox result
iv_expr = cls.get_initial_value_from_ode_toolbox_result(
cls.to_ode_toolbox_processed_name(var_name), solver_dicts)
@@ -1794,6 +1790,9 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict
iv_expr.update_scope(state_block.get_scope())
iv_decl.set_expression(iv_expr)
+ model.accept(ASTParentVisitor())
+ model.accept(ASTSymbolTableVisitor())
+
@classmethod
def integrate_odes_args_strs_from_function_call(cls, function_call: ASTFunctionCall):
arg_names = []
@@ -2296,6 +2295,7 @@ def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block:
r"""
Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`.
"""
+ from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
def replace_function_call_through_var(_expr=None):
if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve":
@@ -2326,6 +2326,7 @@ def func(x):
return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True
equations_block.accept(ASTHigherOrderVisitor(func))
+ equations_block.accept(ASTSymbolTableVisitor())
@classmethod
def update_blocktype_for_common_parameters(cls, node):
diff --git a/pynestml/utils/logger.py b/pynestml/utils/logger.py
index 06e95b804..8404f1245 100644
--- a/pynestml/utils/logger.py
+++ b/pynestml/utils/logger.py
@@ -19,7 +19,7 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from typing import List, Mapping, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple, Union
from collections import OrderedDict
from enum import Enum
@@ -75,6 +75,7 @@ class Logger:
def init_logger(cls, logging_level: LoggingLevel):
"""
Initializes the logger.
+
:param logging_level: the logging level as required
:type logging_level: LoggingLevel
"""
@@ -82,7 +83,6 @@ def init_logger(cls, logging_level: LoggingLevel):
cls.curr_message = 0
cls.log = {}
cls.log_frozen = False
- return
@classmethod
def freeze_log(cls, do_freeze: bool = True):
@@ -95,6 +95,7 @@ def freeze_log(cls, do_freeze: bool = True):
def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns the overall log of messages. The structure of the log is: (NODE, LEVEL, MESSAGE)
+
:return: mapping from id to ASTNode, log level and message.
"""
return cls.log
@@ -103,6 +104,7 @@ def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]:
def set_log(cls, log, counter):
"""
Restores log from the 'log' variable
+
:param log: the log
:param counter: the counter
"""
@@ -113,20 +115,19 @@ def set_log(cls, log, counter):
def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: str = None, error_position: ASTSourceLocation = None, log_level: LoggingLevel = None):
"""
Logs the handed over message on the handed over node. If the current logging is appropriate, the message is also printed.
+
:param node: the node in which the error occurred
:param code: a single error code
- :type code: ErrorCode
:param error_position: the position on which the error occurred.
- :type error_position: SourcePosition
:param message: a message.
- :type message: str
:param log_level: the corresponding log level.
- :type log_level: LoggingLevel
"""
if cls.log_frozen:
return
+
if cls.curr_message is None:
cls.init_logger(LoggingLevel.INFO)
+
from pynestml.meta_model.ast_node import ASTNode
from pynestml.utils.ast_source_location import ASTSourceLocation
assert (node is None or isinstance(node, ASTNode)), \
@@ -134,15 +135,23 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st
assert (error_position is None or isinstance(error_position, ASTSourceLocation)), \
'(PyNestML.Logger) Wrong type of error position provided (%s)!' % type(error_position)
from pynestml.meta_model.ast_model import ASTModel
+
if isinstance(node, ASTModel):
cls.log[cls.curr_message] = (
node.get_artifact_name(), node, log_level, code, error_position, message)
- elif cls.current_node is not None:
- cls.log[cls.curr_message] = (cls.current_node.get_artifact_name(), cls.current_node,
+ else:
+ if cls.current_node is not None:
+ artifact_name = cls.current_node.get_artifact_name()
+ else:
+ artifact_name = ""
+
+ cls.log[cls.curr_message] = (artifact_name, cls.current_node,
log_level, code, error_position, message)
+
cls.curr_message += 1
if cls.no_print:
return
+
if cls.logging_level.value <= log_level.value:
if isinstance(node, ASTInlineExpression):
node_name = node.variable_name
@@ -163,10 +172,9 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st
def string_to_level(cls, string: str) -> LoggingLevel:
"""
Returns the logging level corresponding to the handed over string. If no such exits, returns None.
+
:param string: a single string representing the level.
- :type string: str
:return: a single logging level.
- :rtype: LoggingLevel
"""
if string == 'DEBUG':
return LoggingLevel.DEBUG
@@ -183,7 +191,7 @@ def string_to_level(cls, string: str) -> LoggingLevel:
if string == 'NO' or string == 'NONE':
return LoggingLevel.NO
- raise Exception('Tried to convert unknown string \"' + string + '\" to logging level')
+ raise Exception("Tried to convert unknown string '" + string + "' to logging level")
@classmethod
def level_to_string(cls, level: LoggingLevel) -> str:
@@ -207,7 +215,7 @@ def level_to_string(cls, level: LoggingLevel) -> str:
if level == LoggingLevel.NO:
return 'NO'
- raise Exception('Tried to convert unknown logging level \"' + str(level) + '\" to string')
+ raise Exception("Tried to convert unknown logging level '" + str(level) + "' to string")
@classmethod
def set_logging_level(cls, level: LoggingLevel) -> None:
@@ -218,79 +226,89 @@ def set_logging_level(cls, level: LoggingLevel) -> None:
"""
if cls.log_frozen:
return
+
cls.logging_level = level
@classmethod
def set_current_node(cls, node: Optional[ASTNode]) -> None:
"""
- Sets the handed over node as the currently processed one. This enables a retrieval of messages for a
- specific node.
- :param node: a single node instance
+ Sets the handed over node as the currently processed one. This enables a retrieval of messages for a specific node.
+
+ :param node: a single node instance
"""
cls.current_node = node
@classmethod
- def get_all_messages_of_level_and_or_node(cls, node: ASTNode, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
+ def get_all_messages_of_level_and_or_node(cls, node: Union[ASTNode, str], level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
- Returns all messages which have a certain logging level, or have been reported for a certain node, or
- both.
+ Returns all messages which have a certain logging level, or have been reported for a certain node, or both.
+
:param node: a single node instance
:param level: a logging level
- :type level: LoggingLevel
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if level is None and node is None:
return cls.get_log()
+
+ if isinstance(node, str):
+ # search by artifact name
+ node_artifact_name = node
+ node = None
+ else:
+ # search by artifact class object
+ node_artifact_name = None
+
ret = list()
for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values():
- if (level == logLevel if level is not None else True) and (
- node if node is not None else True) and (
- node.get_artifact_name() == artifactName if node is not None else True):
+ if (level == logLevel if level is not None else True) and (node if node is not None else True) and (node_artifact_name == artifactName if node is not None else True):
ret.append((node, logLevel, message))
+
return ret
@classmethod
def get_all_messages_of_level(cls, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns all messages which have a certain logging level.
+
:param level: a logging level
- :type level: LoggingLevel
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if level is None:
return cls.get_log()
+
ret = list()
for (artifactName, node, logLevel, code, errorPosition, message) in cls.log.values():
if level == logLevel:
ret.append((node, logLevel, message))
+
return ret
@classmethod
def get_all_messages_of_node(cls, node: ASTNode) -> List[Tuple[ASTNode, LoggingLevel, str]]:
"""
Returns all messages which have been reported for a certain node.
+
:param node: a single node instance
:return: a list of messages with their levels.
- :rtype: list((str,Logging_Level)
"""
if node is None:
return cls.get_log()
+
ret = list()
for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values():
if (node_i == node if node is not None else True) and \
(node.get_artifact_name() == artifactName if node is not None else True):
ret.append((node, logLevel, message))
+
return ret
@classmethod
def has_errors(cls, node: ASTNode) -> bool:
"""
Indicates whether the handed over node, thus the corresponding model, has errors.
+
:param node: a single node instance.
:return: True if errors detected, otherwise False
- :rtype: bool
"""
return len(cls.get_all_messages_of_level_and_or_node(node, LoggingLevel.ERROR)) > 0
@@ -311,6 +329,7 @@ def get_json_format(cls) -> str:
(node.get_name() if node is not None else 'GLOBAL') + '", ' + \
'"severity":"' \
+ str(logLevel.name) + '", '
+
if code is not None:
ret += '"code":"' + \
code.name + \
@@ -323,10 +342,12 @@ def get_json_format(cls) -> str:
'", ' + \
'"message":"' + str(message).replace('"', "'") + '"}'
ret += ','
+
if len(cls.log.keys()) == 0:
parsed = json.loads('[]', object_pairs_hook=OrderedDict)
else:
ret = ret[:-1] # delete the last ","
ret += ']'
parsed = json.loads(ret, object_pairs_hook=OrderedDict)
+
return json.dumps(parsed, indent=2, sort_keys=False)
diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py
index 456ece178..ea645a02c 100644
--- a/pynestml/utils/mechs_info_enricher.py
+++ b/pynestml/utils/mechs_info_enricher.py
@@ -22,13 +22,14 @@
from collections import defaultdict
from pynestml.meta_model.ast_model import ASTModel
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.symbol import SymbolKind
+from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor
from pynestml.utils.ast_utils import ASTUtils
-from pynestml.visitors.ast_visitor import ASTVisitor
from pynestml.utils.model_parser import ModelParser
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.symbol import SymbolKind
+from pynestml.visitors.ast_visitor import ASTVisitor
class MechsInfoEnricher:
@@ -57,33 +58,6 @@ def transform_ode_solutions(cls, neuron, mechs_info):
solution_transformed["states"] = defaultdict()
solution_transformed["propagators"] = defaultdict()
- for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items():
- variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name,
- SymbolKind.VARIABLE)
-
- expression = ModelParser.parse_expression(rhs_str)
- # pretend that update expressions are in "equations" block,
- # which should always be present, as synapses have been
- # defined to get here
- expression.update_scope(neuron.get_equations_blocks()[0].get_scope())
- expression.accept(ASTSymbolTableVisitor())
-
- update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][
- variable_name]
- update_expr_ast = ModelParser.parse_expression(
- update_expr_str)
- # pretend that update expressions are in "equations" block,
- # which should always be present, as differential equations
- # must have been defined to get here
- update_expr_ast.update_scope(
- neuron.get_equations_blocks()[0].get_scope())
- update_expr_ast.accept(ASTSymbolTableVisitor())
-
- solution_transformed["states"][variable_name] = {
- "ASTVariable": variable,
- "init_expression": expression,
- "update_expression": update_expr_ast,
- }
for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items():
prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name,
SymbolKind.VARIABLE)
@@ -118,6 +92,36 @@ def transform_ode_solutions(cls, neuron, mechs_info):
PredefinedFunctions.TIME_RESOLUTION:
mechanism_info["time_resolution_var"] = variable
+ for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items():
+ variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name,
+ SymbolKind.VARIABLE)
+
+ expression = ModelParser.parse_expression(rhs_str)
+ # pretend that update expressions are in "equations" block,
+ # which should always be present, as synapses have been
+ # defined to get here
+ expression.update_scope(neuron.get_equations_blocks()[0].get_scope())
+ expression.accept(ASTSymbolTableVisitor())
+
+ update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][
+ variable_name]
+ update_expr_ast = ModelParser.parse_expression(
+ update_expr_str)
+ # pretend that update expressions are in "equations" block,
+ # which should always be present, as differential equations
+ # must have been defined to get here
+ update_expr_ast.update_scope(
+ neuron.get_scope())
+ update_expr_ast.accept(ASTParentVisitor())
+ update_expr_ast.accept(ASTSymbolTableVisitor())
+ neuron.accept(ASTSymbolTableVisitor())
+
+ solution_transformed["states"][variable_name] = {
+ "ASTVariable": variable,
+ "init_expression": expression,
+ "update_expression": update_expr_ast,
+ }
+
mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed)
neuron.accept(ASTParentVisitor())
diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py
index 69b32a8f4..90efd06b7 100644
--- a/pynestml/utils/messages.py
+++ b/pynestml/utils/messages.py
@@ -18,11 +18,15 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from enum import Enum
+
+from __future__ import annotations
+
from typing import Tuple
-from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from collections.abc import Iterable
+from enum import Enum
+
+from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_function import ASTFunction
@@ -132,6 +136,7 @@ class MessageCode(Enum):
VOID_FUNCTION_ON_RHS = 110
NON_CONSTANT_EXPONENT = 111
EXPONENT_MUST_BE_INTEGER = 112
+ RANDOM_FUNCTIONS_LEGALLY_USED = 113
class Messages:
@@ -158,8 +163,8 @@ def get_input_path_not_found(cls, path):
return MessageCode.INPUT_PATH_NOT_FOUND, message
@classmethod
- def get_unknown_target(cls, target):
- message = 'Unknown target ("%s")' % (target)
+ def get_unknown_target_platform(cls, target: str):
+ message = "Unknown target: '" + target + "'"
return MessageCode.UNKNOWN_TARGET, message
@classmethod
@@ -313,22 +318,13 @@ def get_different_type_rhs_lhs(
return MessageCode.CAST_NOT_POSSIBLE, message
@classmethod
- def get_type_different_from_expected(cls, expected_type, got_type):
+ def get_type_different_from_expected(cls, expected_type, got_type) -> Tuple[MessageCode, str]:
"""
Returns a message indicating that the received type is different from the expected one.
:param expected_type: the expected type
- :type expected_type: TypeSymbol
:param got_type: the actual type
- :type got_type: type_symbol
:return: a message
- :rtype: (MessageCode,str)
"""
- from pynestml.symbols.type_symbol import TypeSymbol
- assert (expected_type is not None and isinstance(expected_type, TypeSymbol)), \
- '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(
- expected_type)
- assert (got_type is not None and isinstance(got_type, TypeSymbol)), \
- '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(got_type)
message = 'Actual type different from expected. Expected: \'%s\', got: \'%s\'!' % (
expected_type.print_symbol(), got_type.print_symbol())
return MessageCode.TYPE_DIFFERENT_FROM_EXPECTED, message
@@ -430,11 +426,10 @@ def get_module_generated(cls, path: str) -> Tuple[MessageCode, str]:
return MessageCode.MODULE_SUCCESSFULLY_GENERATED, message
@classmethod
- def get_variable_used_before_declaration(cls, variable_name):
+ def get_variable_used_before_declaration(cls, variable_name: str):
"""
Returns a message indicating that a variable is used before declaration.
:param variable_name: a variable name
- :type variable_name: str
:return: a message
:rtype: (MessageCode,str)
"""
@@ -701,7 +696,7 @@ def get_model_redeclared(cls, name: str) -> Tuple[MessageCode, str]:
'(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name)
assert (name is not None and isinstance(name, str)), \
'(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name)
- message = 'model \'%s\' redeclared!' % name
+ message = 'Model \'%s\' redeclared!' % name
return MessageCode.MODEL_REDECLARED, message
@classmethod
@@ -1375,3 +1370,8 @@ def get_non_constant_exponent(cls) -> Tuple[MessageCode, str]:
message = "Cannot calculate value of exponent. Must be a constant value!"
return MessageCode.NON_CONSTANT_EXPONENT, message
+
+ @classmethod
+ def get_random_functions_legally_used(cls, name):
+ message = "The function '" + name + "' can only be used in the update, onReceive, or onCondition blocks."
+ return MessageCode.RANDOM_FUNCTIONS_LEGALLY_USED, message
diff --git a/pynestml/utils/model_parser.py b/pynestml/utils/model_parser.py
index 7fabf361e..62a8669bb 100644
--- a/pynestml/utils/model_parser.py
+++ b/pynestml/utils/model_parser.py
@@ -24,6 +24,7 @@
from antlr4 import CommonTokenStream, FileStream, InputStream
from antlr4.error.ErrorStrategy import BailErrorStrategy, DefaultErrorStrategy
from antlr4.error.ErrorListener import ConsoleErrorListener
+from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.generated.PyNestMLLexer import PyNestMLLexer
from pynestml.generated.PyNestMLParser import PyNestMLParser
@@ -65,6 +66,7 @@
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.meta_model.ast_while_stmt import ASTWhileStmt
from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.transformers.assign_implicit_conversion_factors_transformer import AssignImplicitConversionFactorsTransformer
from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.error_listener import NestMLErrorListener
from pynestml.utils.logger import Logger, LoggingLevel
@@ -142,10 +144,14 @@ def parse_file(cls, file_path=None):
for model in ast.get_model_list():
model.accept(ASTSymbolTableVisitor())
SymbolTable.add_model_scope(model.get_name(), model.get_scope())
+ Logger.set_current_node(model)
+ AssignImplicitConversionFactorsTransformer().transform(model)
+ Logger.set_current_node(None)
# store source paths
for model in ast.get_model_list():
model.file_path = file_path
+
ast.file_path = file_path
return ast
diff --git a/pynestml/utils/type_caster.py b/pynestml/utils/type_caster.py
index 34e4e6ccc..4ce2624dd 100644
--- a/pynestml/utils/type_caster.py
+++ b/pynestml/utils/type_caster.py
@@ -28,12 +28,11 @@ class TypeCaster:
@staticmethod
def do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression):
"""
- determine conversion factor from rhs to lhs, register it with the relevant expression
+ Determine conversion factor from rhs to lhs, register it with the relevant expression
"""
_containing_expression.set_implicit_conversion_factor(
- UnitTypeSymbol.get_conversion_factor(_lhs_type_symbol.astropy_unit,
- _rhs_type_symbol.astropy_unit))
- _containing_expression.type = _lhs_type_symbol
+ UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit,
+ _lhs_type_symbol.astropy_unit))
code, message = Messages.get_implicit_magnitude_conversion(_lhs_type_symbol, _rhs_type_symbol,
_containing_expression.get_implicit_conversion_factor())
Logger.log_message(code=code, message=message,
@@ -45,18 +44,26 @@ def try_to_recover_or_error(_lhs_type_symbol, _rhs_type_symbol, _containing_expr
if _rhs_type_symbol.is_castable_to(_lhs_type_symbol):
if isinstance(_lhs_type_symbol, UnitTypeSymbol) \
and isinstance(_rhs_type_symbol, UnitTypeSymbol):
- conversion_factor = UnitTypeSymbol.get_conversion_factor(
- _lhs_type_symbol.astropy_unit, _rhs_type_symbol.astropy_unit)
+ conversion_factor = UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, _lhs_type_symbol.astropy_unit)
+
+ if conversion_factor is None:
+ # error during conversion
+ code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
+ Logger.log_message(error_position=_containing_expression.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
+ return
+
if not conversion_factor == 1.:
# the units are mutually convertible, but require a factor unequal to 1 (e.g. mV and A*Ohm)
- TypeCaster.do_magnitude_conversion_rhs_to_lhs(
- _rhs_type_symbol, _lhs_type_symbol, _containing_expression)
+ TypeCaster.do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression)
+
# the units are mutually convertible (e.g. V and A*Ohm)
code, message = Messages.get_implicit_cast_rhs_to_lhs(_rhs_type_symbol.print_symbol(),
_lhs_type_symbol.print_symbol())
Logger.log_message(error_position=_containing_expression.get_source_position(),
code=code, message=message, log_level=LoggingLevel.INFO)
- else:
- code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
- Logger.log_message(error_position=_containing_expression.get_source_position(),
- code=code, message=message, log_level=LoggingLevel.ERROR)
+ return
+
+ code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol)
+ Logger.log_message(error_position=_containing_expression.get_source_position(),
+ code=code, message=message, log_level=LoggingLevel.ERROR)
diff --git a/pynestml/visitors/ast_builder_visitor.py b/pynestml/visitors/ast_builder_visitor.py
index 0e766d530..bfc4dd902 100644
--- a/pynestml/visitors/ast_builder_visitor.py
+++ b/pynestml/visitors/ast_builder_visitor.py
@@ -52,16 +52,17 @@ def visitNestMLCompilationUnit(self, ctx):
models = list()
for child in ctx.model():
models.append(self.visit(child))
+
# extract the name of the artifact from the context
if hasattr(ctx.start.source[1], 'fileName'):
artifact_name = ntpath.basename(ctx.start.source[1].fileName)
else:
artifact_name = 'parsed_from_string'
+
compilation_unit = ASTNodeFactory.create_ast_nestml_compilation_unit(list_of_models=models,
source_position=create_source_pos(ctx),
artifact_name=artifact_name)
- # first ensure certain properties of the model
- CoCosManager.check_model_names_unique(compilation_unit)
+
return compilation_unit
# Visit a parse tree produced by PyNESTMLParser#datatype.
@@ -387,15 +388,6 @@ def visitDeclaration(self, ctx):
expression = self.visit(ctx.rhs) if ctx.rhs is not None else None
invariant = self.visit(ctx.invariant) if ctx.invariant is not None else None
- # print("Visiting variable \"" + str(str(ctx.NAME())) + "\"...")
- # # check if this variable was decorated as homogeneous
- # import pynestml.generated.PyNestMLLexer
- # is_homogeneous = any([isinstance(ch, pynestml.generated.PyNestMLParser.PyNestMLParser.AnyDecoratorContext) \
- # and len(ch.getTokens(pynestml.generated.PyNestMLLexer.PyNestMLLexer.DECORATOR_HOMOGENEOUS)) > 0 \
- # for ch in ctx.parentCtx.children])
- # if is_homogeneous:
- # print("\t----> is homogeneous")
-
declaration = ASTNodeFactory.create_ast_declaration(is_recordable=is_recordable,
variables=variables,
data_type=data_type,
diff --git a/pynestml/visitors/ast_function_call_visitor.py b/pynestml/visitors/ast_function_call_visitor.py
index 7d7bf75c4..e4ec8650e 100644
--- a/pynestml/visitors/ast_function_call_visitor.py
+++ b/pynestml/visitors/ast_function_call_visitor.py
@@ -94,7 +94,6 @@ def visit_simple_expression(self, node: ASTSimpleExpression) -> None:
# return type of the convolve function is the type of the second parameter multiplied by the unit of time (s)
if function_name == PredefinedFunctions.CONVOLVE:
- # Deviations from the assumptions made here are handled in the convolveCoco
buffer_parameter = node.get_function_call().get_args()[1]
if buffer_parameter.get_variable() is not None:
diff --git a/pynestml/visitors/ast_symbol_table_visitor.py b/pynestml/visitors/ast_symbol_table_visitor.py
index 011182543..bc85d4cdd 100644
--- a/pynestml/visitors/ast_symbol_table_visitor.py
+++ b/pynestml/visitors/ast_symbol_table_visitor.py
@@ -19,7 +19,6 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
-from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_model_body import ASTModelBody
from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator
@@ -53,7 +52,6 @@ def __init__(self):
self.symbol_stack = Stack()
self.scope_stack = Stack()
self.block_type_stack = Stack()
- self.after_ast_rewrite_ = False
def visit_model(self, node: ASTModel) -> None:
"""
@@ -79,10 +77,6 @@ def visit_model(self, node: ASTModel) -> None:
node.get_scope().add_symbol(types[symbol])
def endvisit_model(self, node: ASTModel):
- # before following checks occur, we need to ensure several simple properties
- CoCosManager.post_symbol_table_builder_checks(
- node, after_ast_rewrite=self.after_ast_rewrite_)
-
# update the equations
for equation_block in node.get_equations_blocks():
ASTUtils.assign_ode_to_variables(equation_block)
@@ -287,8 +281,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
namespace_decorators = {}
for d in node.get_decorators():
if isinstance(d, ASTNamespaceDecorator):
- namespace_decorators[str(d.get_namespace())] = str(
- d.get_name())
+ namespace_decorators[str(d.get_namespace())] = str(d.get_name())
else:
decorators.append(d)
@@ -296,6 +289,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
block_type = None
if not self.block_type_stack.is_empty():
block_type = self.block_type_stack.top()
+
for var in node.get_variables(): # for all variables declared create a new symbol
var.update_scope(node.get_scope())
@@ -324,11 +318,14 @@ def visit_declaration(self, node: ASTDeclaration) -> None:
symbol.set_comment(node.get_comment())
node.get_scope().add_symbol(symbol)
var.set_type_symbol(type_symbol)
+
# the data type
node.get_data_type().update_scope(node.get_scope())
+
# the rhs update
if node.has_expression():
node.get_expression().update_scope(node.get_scope())
+
# the invariant update
if node.has_invariant():
node.get_invariant().update_scope(node.get_scope())
diff --git a/tests/cocos_test.py b/tests/cocos_test.py
deleted file mode 100644
index f557faaf0..000000000
--- a/tests/cocos_test.py
+++ /dev/null
@@ -1,698 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# cocos_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-from __future__ import print_function
-
-import os
-import unittest
-
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.logger import LoggingLevel, Logger
-from pynestml.utils.model_parser import ModelParser
-
-
-class CoCosTest(unittest.TestCase):
-
- def setUp(self):
- Logger.init_logger(LoggingLevel.INFO)
- SymbolTable.initialize_symbol_table(
- ASTSourceLocation(
- start_line=0,
- start_column=0,
- end_line=0,
- end_column=0))
- PredefinedUnits.register_units()
- PredefinedTypes.register_types()
- PredefinedVariables.register_variables()
- PredefinedFunctions.register_functions()
-
- def test_invalid_element_defined_after_usage(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableDefinedAfterUsage.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_element_defined_after_usage(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableDefinedAfterUsage.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_element_in_same_line(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoElementInSameLine.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_element_in_same_line(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoElementInSameLine.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_integrate_odes_called_if_equations_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_integrate_odes_called_if_equations_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_element_not_defined_in_scope(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableNotDefined.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 5)
-
- def test_valid_element_not_defined_in_scope(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableNotDefined.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)),
- 0)
-
- def test_variable_with_same_name_as_unit(self):
- Logger.set_logging_level(LoggingLevel.NO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableWithSameNameAsUnit.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)),
- 3)
-
- def test_invalid_variable_redeclaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVariableRedeclared.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_variable_redeclaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVariableRedeclared.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_each_block_unique(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoEachBlockUnique.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_each_block_unique(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoEachBlockUnique.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_function_unique_and_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoFunctionNotUnique.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5)
-
- def test_valid_function_unique_and_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoFunctionNotUnique.nestml'))
- self.assertEqual(
- len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expressions_have_rhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInlineExpressionHasNoRhs.nestml'))
- assert model is None
-
- def test_valid_inline_expressions_have_rhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInlineExpressionHasNoRhs.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expression_has_several_lhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInlineExpressionWithSeveralLhs.nestml'))
- assert model is None
-
- def test_valid_inline_expression_has_several_lhs(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInlineExpressionWithSeveralLhs.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_no_values_assigned_to_input_ports(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoValueAssignedToInputPort.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_no_values_assigned_to_input_ports(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoValueAssignedToInputPort.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_order_of_equations_correct(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoNoOrderOfEquations.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_order_of_equations_correct(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoNoOrderOfEquations.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_numerator_of_unit_one(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoUnitNumeratorNotOne.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 2)
-
- def test_valid_numerator_of_unit_one(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoUnitNumeratorNotOne.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_names_of_neurons_unique(self):
- Logger.init_logger(LoggingLevel.INFO)
- ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoMultipleNeuronsWithEqualName.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 1)
-
- def test_valid_names_of_neurons_unique(self):
- Logger.init_logger(LoggingLevel.INFO)
- ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoMultipleNeuronsWithEqualName.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 0)
-
- def test_invalid_no_nest_collision(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoNestNamespaceCollision.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_no_nest_collision(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoNestNamespaceCollision.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_redundant_input_port_keywords_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInputPortWithRedundantTypes.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_redundant_input_port_keywords_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInputPortWithRedundantTypes.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_parameters_assigned_only_in_parameters_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoParameterAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_parameters_assigned_only_in_parameters_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoParameterAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_inline_expressions_assigned_only_in_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoAssignmentToInlineExpression.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_invalid_internals_assigned_only_in_internals_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInternalAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_internals_assigned_only_in_internals_block(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInternalAssignedOutsideBlock.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_function_with_wrong_arg_number_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_function_with_wrong_arg_number_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_init_values_have_rhs_and_ode(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInitValuesWithoutOde.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2)
-
- def test_valid_init_values_have_rhs_and_ode(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInitValuesWithoutOde.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2)
-
- def test_invalid_incorrect_return_stmt_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIncorrectReturnStatement.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4)
-
- def test_valid_incorrect_return_stmt_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIncorrectReturnStatement.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_ode_vars_outside_init_block_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOdeVarNotInInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_ode_vars_outside_init_block_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOdeVarNotInInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_convolve_correctly_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoConvolveNotCorrectlyProvided.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 3)
-
- def test_valid_convolve_correctly_defined(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoConvolveNotCorrectlyProvided.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_in_non_vector_declaration_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorInNonVectorDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_in_non_vector_declaration_detected(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorInNonVectorDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorParameterDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_parameter_declaration(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorParameterDeclaration.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_type(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorParameterType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_vector_parameter_type(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorParameterType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_vector_parameter_size(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorDeclarationSize.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_vector_parameter_size(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorDeclarationSize.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_convolve_correctly_parameterized(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoConvolveNotCorrectlyParametrized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_convolve_correctly_parameterized(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoConvolveNotCorrectlyParametrized.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 0)
-
- def test_invalid_invariant_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoInvariantNotBool.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_invariant_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoInvariantNotBool.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoIllegalExpression.nestml'))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 6)
-
- def test_valid_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoIllegalExpression.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_compound_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5)
-
- def test_valid_compound_expression_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_ode_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOdeIncorrectlyTyped.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_valid_ode_correctly_typed(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOdeCorrectlyTyped.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_output_block_defined_if_emit_call(self):
- """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOutputPortDefinedIfEmitCall.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_invalid_output_port_defined_if_emit_call(self):
- """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoOutputPortDefinedIfEmitCall-2.nestml'))
- self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)) > 0)
-
- def test_valid_output_port_defined_if_emit_call(self):
- """test that no error is raised when the output block is missing, but not emit_spike() functions are called"""
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoOutputPortDefinedIfEmitCall.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_valid_coco_kernel_type(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoKernelType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_coco_kernel_type(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoKernelType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_invalid_coco_kernel_type_initial_values(self):
- """
- Test the functionality of CoCoKernelType.
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoKernelTypeInitialValues.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4)
-
- def test_valid_coco_state_variables_initialized(self):
- """
- Test that the CoCo condition is applicable for all the variables in the state block initialized with a value
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoStateVariablesInitialized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_coco_state_variables_initialized(self):
- """
- Test that the CoCo condition is applicable for all the variables in the state block not initialized
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoStateVariablesInitialized.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_invalid_co_co_priorities_correctly_specified(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoPrioritiesCorrectlySpecified.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
-
- def test_valid_co_co_priorities_correctly_specified(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoPrioritiesCorrectlySpecified.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_co_co_resolution_legally_used(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoResolutionLegallyUsed.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2)
-
- def test_valid_co_co_resolution_legally_used(self):
- """
- """
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoResolutionLegallyUsed.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_valid_co_co_vector_input_port(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')),
- 'CoCoVectorInputPortSizeAndType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
-
- def test_invalid_co_co_vector_input_port(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')),
- 'CoCoVectorInputPortSizeAndType.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1)
diff --git a/tests/function_parameter_templating_test.py b/tests/function_parameter_templating_test.py
deleted file mode 100644
index e3cb89e41..000000000
--- a/tests/function_parameter_templating_test.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# function_parameter_templating_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-import os
-import unittest
-
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.model_parser import ModelParser
-
-# minor setup steps required
-SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
-PredefinedUnits.register_units()
-PredefinedTypes.register_types()
-PredefinedVariables.register_variables()
-PredefinedFunctions.register_functions()
-
-
-class FunctionParameterTemplatingTest(unittest.TestCase):
- """
- This test is used to test the correct derivation of types when functions use templated type parameters.
- """
-
- def test(self):
- Logger.init_logger(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__),
- "resources", "FunctionParameterTemplatingTest.nestml"))))
- self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0],
- LoggingLevel.ERROR)), 7)
-
-
-if __name__ == '__main__':
- unittest.main()
diff --git a/tests/nest_compartmental_tests/test__cocos.py b/tests/nest_compartmental_tests/test__cocos.py
index dc4daa28c..7ee55f8a1 100644
--- a/tests/nest_compartmental_tests/test__cocos.py
+++ b/tests/nest_compartmental_tests/test__cocos.py
@@ -21,41 +21,39 @@
from __future__ import print_function
+from typing import Optional
+
import os
import pytest
-from pynestml.frontend.frontend_configuration import FrontendConfiguration
-
-from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.meta_model.ast_model import ASTModel
from pynestml.symbol_table.symbol_table import SymbolTable
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.symbols.predefined_types import PredefinedTypes
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.utils.ast_source_location import ASTSourceLocation
from pynestml.utils.logger import LoggingLevel, Logger
from pynestml.utils.model_parser import ModelParser
-@pytest.fixture
-def setUp():
- Logger.init_logger(LoggingLevel.INFO)
- SymbolTable.initialize_symbol_table(
- ASTSourceLocation(
- start_line=0,
- start_column=0,
- end_line=0,
- end_column=0))
- PredefinedUnits.register_units()
- PredefinedTypes.register_types()
- PredefinedVariables.register_variables()
- PredefinedFunctions.register_functions()
- FrontendConfiguration.target_platform = "NEST_COMPARTMENTAL"
-
-
class TestCoCos:
- def test_invalid_cm_variables_declared(self, setUp):
- model = ModelParser.parse_file(
+ @pytest.fixture(scope="module", autouse=True)
+ def setUp(self):
+ SymbolTable.initialize_symbol_table(
+ ASTSourceLocation(
+ start_line=0,
+ start_column=0,
+ end_line=0,
+ end_column=0))
+ PredefinedUnits.register_units()
+ PredefinedTypes.register_types()
+ PredefinedVariables.register_variables()
+ PredefinedFunctions.register_functions()
+
+ def test_invalid_cm_variables_declared(self):
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -63,11 +61,10 @@ def test_invalid_cm_variables_declared(self, setUp):
'invalid')),
'CoCoCmVariablesDeclared.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 5
+ model, LoggingLevel.ERROR)) == 6
- def test_valid_cm_variables_declared(self, setUp):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
+ def test_valid_cm_variables_declared(self):
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -75,12 +72,12 @@ def test_valid_cm_variables_declared(self, setUp):
'valid')),
'CoCoCmVariablesDeclared.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 0
+ model, LoggingLevel.ERROR)) == 0
# it is currently not enforced for the non-cm parameter block, but cm
# needs that
- def test_invalid_cm_variable_has_rhs(self, setUp):
- model = ModelParser.parse_file(
+ def test_invalid_cm_variable_has_rhs(self):
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -88,11 +85,11 @@ def test_invalid_cm_variable_has_rhs(self, setUp):
'invalid')),
'CoCoCmVariableHasRhs.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 2
+ model, LoggingLevel.ERROR)) == 2
- def test_valid_cm_variable_has_rhs(self, setUp):
+ def test_valid_cm_variable_has_rhs(self):
Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -100,12 +97,12 @@ def test_valid_cm_variable_has_rhs(self, setUp):
'valid')),
'CoCoCmVariableHasRhs.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 0
+ model, LoggingLevel.ERROR)) == 0
# it is currently not enforced for the non-cm parameter block, but cm
# needs that
- def test_invalid_cm_v_comp_exists(self, setUp):
- model = ModelParser.parse_file(
+ def test_invalid_cm_v_comp_exists(self):
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -113,11 +110,11 @@ def test_invalid_cm_v_comp_exists(self, setUp):
'invalid')),
'CoCoCmVcompExists.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 4
+ model, LoggingLevel.ERROR)) == 4
- def test_valid_cm_v_comp_exists(self, setUp):
+ def test_valid_cm_v_comp_exists(self):
Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
+ model = self._parse_and_validate_model(
os.path.join(
os.path.realpath(
os.path.join(
@@ -125,4 +122,23 @@ def test_valid_cm_v_comp_exists(self, setUp):
'valid')),
'CoCoCmVcompExists.nestml'))
assert len(Logger.get_all_messages_of_level_and_or_node(
- model.get_model_list()[0], LoggingLevel.ERROR)) == 0
+ model, LoggingLevel.ERROR)) == 0
+
+ def _parse_and_validate_model(self, fname: str) -> Optional[str]:
+ from pynestml.frontend.pynestml_frontend import generate_target
+
+ Logger.init_logger(LoggingLevel.DEBUG)
+
+ try:
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+ except BaseException:
+ return None
+
+ ast_compilation_unit = ModelParser.parse_file(fname)
+ if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0:
+ return None
+
+ model: ASTModel = ast_compilation_unit.get_model_list()[0]
+ model_name = model.get_name()
+
+ return model_name
diff --git a/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py b/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py
new file mode 100644
index 000000000..70e94706b
--- /dev/null
+++ b/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py
@@ -0,0 +1,311 @@
+# -*- coding: utf-8 -*-
+#
+# test_nest_continuous_benchmarking.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import numpy as np
+import os
+import pytest
+
+import nest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_target
+
+try:
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.ticker
+ import matplotlib.pyplot as plt
+ TEST_PLOTS = True
+except Exception:
+ TEST_PLOTS = False
+
+sim_mdl = True
+sim_ref = True
+
+
+class TestNESTContinuousBenchmarking:
+
+ neuron_model_name = "iaf_psc_exp_neuron_nestml__with_stdp_nn_symm_synapse_nestml"
+ ref_neuron_model_name = "iaf_psc_exp_neuron_nestml_non_jit"
+
+ synapse_model_name = "stdp_nn_symm_synapse_nestml__with_iaf_psc_exp_neuron_nestml"
+ ref_synapse_model_name = "stdp_nn_symm_synapse"
+
+ @pytest.fixture(scope="module", autouse=True)
+ def setUp(self):
+ """Generate the neuron model code"""
+
+ # generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode
+ files = [os.path.join("models", "neurons", "iaf_psc_exp_neuron.nestml"),
+ os.path.join("models", "synapses", "stdp_nn_symm_synapse.nestml")]
+ input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(
+ os.pardir, os.pardir, s))) for s in files]
+ generate_nest_target(input_path=input_path,
+ target_path="/tmp/nestml-jit",
+ logging_level="INFO",
+ module_name="nestml_jit_module",
+ suffix="_nestml",
+ codegen_opts={"neuron_parent_class": "StructuralPlasticityNode",
+ "neuron_parent_class_include": "structural_plasticity_node.h",
+ "neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
+ "synapse": "stdp_nn_symm_synapse",
+ "post_ports": ["post_spikes"]}],
+ "delay_variable": {"stdp_nn_symm_synapse": "d"},
+ "weight_variable": {"stdp_nn_symm_synapse": "w"}})
+
+ # generate the "non-jit" model, that relies on ArchivingNode
+ generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__),
+ os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))),
+ target_path="/tmp/nestml-non-jit",
+ logging_level="INFO",
+ module_name="nestml_non_jit_module",
+ suffix="_nestml_non_jit",
+ codegen_opts={"neuron_parent_class": "ArchivingNode",
+ "neuron_parent_class_include": "archiving_node.h"})
+
+ @pytest.mark.benchmark
+ def test_stdp_nn_synapse(self, benchmark):
+ return benchmark(self._test_stdp_nn_synapse)
+
+ def _test_stdp_nn_synapse(self):
+
+ fname_snip = ""
+
+ pre_spike_times = [1., 11., 21.] # [ms]
+ post_spike_times = [6., 16., 26.] # [ms]
+
+ post_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms]
+ pre_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms]
+
+ post_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms]
+ pre_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms]
+
+ pre_spike_times = np.array([2., 4., 7., 8., 12., 13., 19., 23., 24., 28., 29., 30., 33., 34.,
+ 35., 36., 38., 40., 42., 46., 51., 53., 54., 55., 56., 59., 63., 64.,
+ 65., 66., 68., 72., 73., 76., 79., 80., 83., 84., 86., 87., 90., 95.,
+ 99., 100., 103., 104., 105., 111., 112., 126., 131., 133., 134., 139., 147., 150.,
+ 152., 155., 172., 175., 176., 181., 196., 197., 199., 202., 213., 215., 217., 265.])
+ post_spike_times = np.array([4., 5., 6., 7., 10., 11., 12., 16., 17., 18., 19., 20., 22., 23.,
+ 25., 27., 29., 30., 31., 32., 34., 36., 37., 38., 39., 42., 44., 46.,
+ 48., 49., 50., 54., 56., 57., 59., 60., 61., 62., 67., 74., 76., 79.,
+ 80., 81., 83., 88., 93., 94., 97., 99., 100., 105., 111., 113., 114., 115.,
+ 116., 119., 123., 130., 132., 134., 135., 145., 152., 155., 158., 166., 172., 174.,
+ 188., 194., 202., 245., 249., 289., 454.])
+
+ self.run_synapse_test(neuron_model_name=self.neuron_model_name,
+ ref_neuron_model_name=self.ref_neuron_model_name,
+ synapse_model_name=self.synapse_model_name,
+ ref_synapse_model_name=self.ref_synapse_model_name,
+ resolution=1., # [ms]
+ delay=1., # [ms]
+ pre_spike_times=pre_spike_times,
+ post_spike_times=post_spike_times,
+ fname_snip=fname_snip)
+
+ def run_synapse_test(self, neuron_model_name,
+ ref_neuron_model_name,
+ synapse_model_name,
+ ref_synapse_model_name,
+ resolution=1., # [ms]
+ delay=1., # [ms]
+ sim_time=None, # if None, computed from pre and post spike times
+ pre_spike_times=None,
+ post_spike_times=None,
+ fname_snip=""):
+
+ if pre_spike_times is None:
+ pre_spike_times = []
+
+ if post_spike_times is None:
+ post_spike_times = []
+
+ if sim_time is None:
+ sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 5 * delay
+
+ nest.ResetKernel()
+ nest.set_verbosity("M_ALL")
+ nest.SetKernelStatus({'resolution': resolution})
+
+ if sim_mdl:
+ try:
+ nest.Install("nestml_jit_module")
+ except Exception:
+ # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
+ pass
+
+ if sim_ref:
+ try:
+ nest.Install("nestml_non_jit_module")
+ except Exception:
+ # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions
+ pass
+
+ print("Pre spike times: " + str(pre_spike_times))
+ print("Post spike times: " + str(post_spike_times))
+
+ wr = nest.Create('weight_recorder')
+ wr_ref = nest.Create('weight_recorder')
+ if sim_mdl:
+ nest.CopyModel(synapse_model_name, "stdp_nestml_rec",
+ {"weight_recorder": wr[0], "w": 1., "d": 1., "receptor_type": 0})
+ if sim_ref:
+ nest.CopyModel(ref_synapse_model_name, "stdp_ref_rec",
+ {"weight_recorder": wr_ref[0], "weight": 1., "delay": 1., "receptor_type": 0})
+
+ # create spike_generators with these times
+ pre_sg = nest.Create("spike_generator",
+ params={"spike_times": pre_spike_times})
+ post_sg = nest.Create("spike_generator",
+ params={"spike_times": post_spike_times,
+ 'allow_offgrid_times': True})
+
+ # create parrot neurons and connect spike_generators
+ if sim_mdl:
+ pre_neuron = nest.Create("parrot_neuron")
+ post_neuron = nest.Create(neuron_model_name)
+
+ if sim_ref:
+ pre_neuron_ref = nest.Create("parrot_neuron")
+ post_neuron_ref = nest.Create(ref_neuron_model_name)
+
+ if sim_mdl:
+ spikedet_pre = nest.Create("spike_recorder")
+ spikedet_post = nest.Create("spike_recorder")
+ mm = nest.Create("multimeter", params={"record_from": ["V_m", "post_trace__for_stdp_nn_symm_synapse_nestml"]})
+ if sim_ref:
+ spikedet_pre_ref = nest.Create("spike_recorder")
+ spikedet_post_ref = nest.Create("spike_recorder")
+ mm_ref = nest.Create("multimeter", params={"record_from": ["V_m"]})
+
+ if sim_mdl:
+ nest.Connect(pre_sg, pre_neuron, "one_to_one", syn_spec={"delay": 1.})
+ nest.Connect(post_sg, post_neuron, "one_to_one", syn_spec={"delay": 1., "weight": 9999.})
+ nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={'synapse_model': 'stdp_nestml_rec'})
+ nest.Connect(mm, post_neuron)
+ nest.Connect(pre_neuron, spikedet_pre)
+ nest.Connect(post_neuron, spikedet_post)
+ if sim_ref:
+ nest.Connect(pre_sg, pre_neuron_ref, "one_to_one", syn_spec={"delay": 1.})
+ nest.Connect(post_sg, post_neuron_ref, "one_to_one", syn_spec={"delay": 1., "weight": 9999.})
+ nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all",
+ syn_spec={'synapse_model': ref_synapse_model_name})
+ nest.Connect(mm_ref, post_neuron_ref)
+ nest.Connect(pre_neuron_ref, spikedet_pre_ref)
+ nest.Connect(post_neuron_ref, spikedet_post_ref)
+
+ # get STDP synapse and weight before protocol
+ if sim_mdl:
+ syn = nest.GetConnections(source=pre_neuron, synapse_model="stdp_nestml_rec")
+ if sim_ref:
+ syn_ref = nest.GetConnections(source=pre_neuron_ref, synapse_model=ref_synapse_model_name)
+
+ n_steps = int(np.ceil(sim_time / resolution)) + 1
+ t = 0.
+ t_hist = []
+ if sim_mdl:
+ w_hist = []
+ if sim_ref:
+ w_hist_ref = []
+ while t <= sim_time:
+ nest.Simulate(resolution)
+ t += resolution
+ t_hist.append(t)
+ if sim_ref:
+ w_hist_ref.append(nest.GetStatus(syn_ref)[0]['weight'])
+ if sim_mdl:
+ w_hist.append(nest.GetStatus(syn)[0]['w'])
+
+ # plot
+ if TEST_PLOTS:
+ fig, ax = plt.subplots(nrows=3)
+ ax1, ax2, ax3 = ax
+
+ if sim_mdl:
+ pre_spike_times_ = nest.GetStatus(spikedet_pre, "events")[0]["times"]
+ print("Actual pre spike times: " + str(pre_spike_times_))
+ if sim_ref:
+ pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"]
+ print("Actual pre ref spike times: " + str(pre_ref_spike_times_))
+
+ if sim_mdl:
+ n_spikes = len(pre_spike_times_)
+ for i in range(n_spikes):
+ if i == 0:
+ _lbl = "nestml"
+ else:
+ _lbl = None
+ ax1.plot(2 * [pre_spike_times_[i] + delay], [0, 1], linewidth=2, color="blue", alpha=.4, label=_lbl)
+
+ if sim_mdl:
+ post_spike_times_ = nest.GetStatus(spikedet_post, "events")[0]["times"]
+ print("Actual post spike times: " + str(post_spike_times_))
+ if sim_ref:
+ post_ref_spike_times_ = nest.GetStatus(spikedet_post_ref, "events")[0]["times"]
+ print("Actual post ref spike times: " + str(post_ref_spike_times_))
+
+ if sim_ref:
+ n_spikes = len(pre_ref_spike_times_)
+ for i in range(n_spikes):
+ if i == 0:
+ _lbl = "nest ref"
+ else:
+ _lbl = None
+ ax1.plot(2 * [pre_ref_spike_times_[i] + delay], [0, 1],
+ linewidth=2, color="cyan", label=_lbl, alpha=.4)
+ ax1.set_ylabel("Pre spikes")
+
+ if sim_mdl:
+ n_spikes = len(post_spike_times_)
+ for i in range(n_spikes):
+ if i == 0:
+ _lbl = "nestml"
+ else:
+ _lbl = None
+ ax2.plot(2 * [post_spike_times_[i]], [0, 1], linewidth=2, color="black", alpha=.4, label=_lbl)
+ if sim_ref:
+ n_spikes = len(post_ref_spike_times_)
+ for i in range(n_spikes):
+ if i == 0:
+ _lbl = "nest ref"
+ else:
+ _lbl = None
+ ax2.plot(2 * [post_ref_spike_times_[i]], [0, 1], linewidth=2, color="red", alpha=.4, label=_lbl)
+ if sim_mdl:
+ ax2.plot(nest.GetStatus(mm, "events")[0]["times"], nest.GetStatus(mm, "events")[
+ 0]["post_trace__for_stdp_nn_symm_synapse_nestml"], label="nestml post tr")
+ ax2.set_ylabel("Post spikes")
+
+ if sim_mdl:
+ ax3.plot(t_hist, w_hist, marker="o", label="nestml")
+ if sim_ref:
+ ax3.plot(t_hist, w_hist_ref, linestyle="--", marker="x", label="ref")
+
+ ax3.set_xlabel("Time [ms]")
+ ax3.set_ylabel("w")
+ for _ax in ax:
+ _ax.grid(which="major", axis="both")
+ _ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.arange(0, np.ceil(sim_time))))
+ _ax.set_xlim(0., sim_time)
+ _ax.legend()
+ fig.savefig("/tmp/stdp_synapse_test" + fname_snip + ".png", dpi=300)
+
+ # verify
+ MAX_ABS_ERROR = 1E-6
+ assert np.all(np.abs(np.array(w_hist) - np.array(w_hist_ref)) < MAX_ABS_ERROR)
diff --git a/tests/nest_tests/nest_delay_based_variables_test.py b/tests/nest_tests/nest_delay_based_variables_test.py
index 51f863e19..a11c280f2 100644
--- a/tests/nest_tests/nest_delay_based_variables_test.py
+++ b/tests/nest_tests/nest_delay_based_variables_test.py
@@ -19,13 +19,12 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see .
+from typing import List
+
import numpy as np
import os
-from typing import List
import pytest
-import nest
-
try:
import matplotlib
import matplotlib.pyplot as plt
@@ -34,15 +33,12 @@
except BaseException:
TEST_PLOTS = False
+import nest
+
from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target
-target_path = "target_delay"
-logging_level = "DEBUG"
-suffix = "_nestml"
-
-
def plot_fig(times, recordable_events_delay: dict, recordable_events: dict, filename: str):
fig, axes = plt.subplots(len(recordable_events), 1, figsize=(7, 9), sharex=True)
for i, recordable_name in enumerate(recordable_events_delay.keys()):
@@ -86,6 +82,9 @@ def run_simulation(neuron_model_name: str, module_name: str, recordables: List[s
("DelayDifferentialEquationsWithNumericSolver.nestml", "dde_numeric_nestml", ["x", "z"]),
("DelayDifferentialEquationsWithMixedSolver.nestml", "dde_mixed_nestml", ["x", "z"])])
def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, recordables: List[str]):
+ target_path = "target_delay"
+ logging_level = "DEBUG"
+ suffix = "_nestml"
input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", file_name)))
module_name = neuron_model_name + "_module"
print("Module name: ", module_name)
@@ -112,16 +111,3 @@ def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, record
if neuron_model_name == "dde_analytic_nestml":
np.testing.assert_allclose(recordable_events_delay[recordables[1]][int(delay):],
recordable_events[recordables[1]][:-int(delay)])
-
- @pytest.fixture(scope="function", autouse=True)
- def cleanup(self):
- # Run the test
- yield
-
- # clean up
- import shutil
- if self.target_path:
- try:
- shutil.rmtree(self.target_path)
- except Exception:
- pass
diff --git a/tests/nest_tests/nest_random_functions_test.py b/tests/nest_tests/nest_random_functions_test.py
new file mode 100644
index 000000000..3d24ed963
--- /dev/null
+++ b/tests/nest_tests/nest_random_functions_test.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+#
+# nest_random_functions_test.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+import os
+
+import pytest
+
+from pynestml.frontend.pynestml_frontend import generate_nest_target
+
+
+class TestNestRandomFunctions:
+ """
+ Tests that, for the NEST target, random number functions are called only in ``update``, ``onReceive``, and ``onCondition`` block
+ """
+
+ @pytest.mark.xfail(strict=True, raises=Exception)
+ def test_nest_random_function_neuron_illegal(self):
+ input_path = os.path.realpath(os.path.join(os.path.dirname(__file__),
+ "resources", "random_functions_illegal_neuron.nestml"))
+ generate_nest_target(input_path=input_path,
+ target_path="target",
+ logging_level="INFO",
+ suffix="_nestml")
+
+ @pytest.mark.xfail(strict=True, raises=Exception)
+ def test_nest_random_function_synapse_illegal(self):
+ input_path = [
+ os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "models", "neurons",
+ "iaf_psc_exp_neuron.nestml")),
+ os.path.realpath(os.path.join(os.path.dirname(__file__),
+ "resources", "random_functions_illegal_synapse.nestml"))]
+
+ generate_nest_target(input_path=input_path,
+ target_path="target",
+ logging_level="INFO",
+ suffix="_nestml",
+ codegen_opts={"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron",
+ "synapse": "random_functions_illegal_synapse",
+ "post_ports": ["post_spikes"]}],
+ "weight_variable": {"stdp_synapse": "w"}})
diff --git a/tests/nest_tests/resources/integrate_odes_test_params.nestml b/tests/nest_tests/resources/integrate_odes_test_params.nestml
index d07fe8fd4..d6430e537 100644
--- a/tests/nest_tests/resources/integrate_odes_test_params.nestml
+++ b/tests/nest_tests/resources/integrate_odes_test_params.nestml
@@ -8,7 +8,6 @@ model integrate_odes_test:
update:
integrate_odes(2 * test_1)
- integrate_odes(test_3)
integrate_odes(100 ms)
integrate_odes(test_1)
integrate_odes(test_2)
diff --git a/tests/nest_tests/resources/integrate_odes_test_params2.nestml b/tests/nest_tests/resources/integrate_odes_test_params2.nestml
new file mode 100644
index 000000000..616401e48
--- /dev/null
+++ b/tests/nest_tests/resources/integrate_odes_test_params2.nestml
@@ -0,0 +1,10 @@
+"""
+Model for testing the integrate_odes() function.
+"""
+model integrate_odes_test:
+ state:
+ test_1 real = 0.
+ test_2 real = 0.
+
+ update:
+ integrate_odes(test_3)
diff --git a/tests/nest_tests/resources/random_functions_illegal_neuron.nestml b/tests/nest_tests/resources/random_functions_illegal_neuron.nestml
new file mode 100644
index 000000000..9954459d3
--- /dev/null
+++ b/tests/nest_tests/resources/random_functions_illegal_neuron.nestml
@@ -0,0 +1,46 @@
+"""
+random_functions_illegal_neuron.nestml
+######################################
+
+
+Copyright statement
++++++++++++++++++++
+
+This file is part of NEST.
+
+Copyright (C) 2004 The NEST Initiative
+
+NEST is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 2 of the License, or
+(at your option) any later version.
+
+NEST is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with NEST. If not, see .
+"""
+
+model random_functions_illegal_neuron:
+ state:
+ noise real = random_normal(0,sigma_noise)
+ v mV = -15 mV
+
+ parameters:
+ rate ms**-1 = 15.5 s**-1
+ sigma_noise real = 16.
+ u real = random_uniform(0,1)
+
+ internals:
+ poisson_input integer = random_poisson(rate * resolution() * 1E-3)
+
+ update:
+ if u < 0.5:
+ noise = 0.
+ else:
+ noise = random_normal(0,sigma_noise)
+
+ v += (poisson_input + noise) * mV
diff --git a/tests/nest_tests/resources/random_functions_illegal_synapse.nestml b/tests/nest_tests/resources/random_functions_illegal_synapse.nestml
new file mode 100644
index 000000000..473791ec7
--- /dev/null
+++ b/tests/nest_tests/resources/random_functions_illegal_synapse.nestml
@@ -0,0 +1,73 @@
+"""
+random_functions_illegal_synapse.nestml
+#######################################
+
+
+Copyright statement
++++++++++++++++++++
+
+This file is part of NEST.
+
+Copyright (C) 2004 The NEST Initiative
+
+NEST is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 2 of the License, or
+(at your option) any later version.
+
+NEST is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with NEST. If not, see .
+"""
+
+model random_functions_illegal_synapse:
+ state:
+ w real = 1 # Synaptic weight
+ pre_trace real = 0.
+ post_trace real = 0.
+
+ parameters:
+ d ms = 1 ms # Synaptic transmission delay
+ lambda real = .01
+ tau_tr_pre ms = random_normal(110 ms, 55 ms)
+ tau_tr_post ms = random_normal(5 ms, 2.5 ms)
+ alpha real = 1
+ mu_plus real = 1
+ mu_minus real = 1
+ Wmax real = 100.
+ Wmin real = 0.
+
+ equations:
+ pre_trace' = -pre_trace / tau_tr_pre
+ post_trace' = -post_trace / tau_tr_post
+
+ input:
+ pre_spikes <- spike
+ post_spikes <- spike
+
+ output:
+ spike
+
+ onReceive(post_spikes):
+ post_trace += 1
+
+ # potentiate synapse
+ w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace ))
+ w = min(Wmax, w_)
+
+ onReceive(pre_spikes):
+ pre_trace += 1
+
+ # depress synapse
+ w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace ))
+ w = max(Wmin, w_)
+
+ # deliver spike to postsynaptic partner
+ emit_spike(w, d)
+
+ update:
+ integrate_odes()
diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py
index 99b94c6ca..6ddb699b4 100644
--- a/tests/nest_tests/test_integrate_odes.py
+++ b/tests/nest_tests/test_integrate_odes.py
@@ -27,16 +27,9 @@
import nest
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.codegeneration.nest_tools import NESTTools
-from pynestml.frontend.pynestml_frontend import generate_nest_target
+from pynestml.frontend.pynestml_frontend import generate_nest_target, generate_target
from pynestml.utils.logger import LoggingLevel, Logger
-from pynestml.utils.model_parser import ModelParser
try:
import matplotlib
@@ -227,12 +220,15 @@ def test_integrate_odes_nonlinear(self):
def test_integrate_odes_params(self):
r"""Test the integrate_odes() function, in particular with respect to the parameter types."""
- Logger.init_logger(LoggingLevel.INFO)
- SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
- PredefinedUnits.register_units()
- PredefinedTypes.register_types()
- PredefinedVariables.register_variables()
- PredefinedFunctions.register_functions()
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))))
- assert len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)) == 6
+ fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
+
+ def test_integrate_odes_params2(self):
+ r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables."""
+
+ fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2
diff --git a/tests/test_cocos.py b/tests/test_cocos.py
new file mode 100644
index 000000000..81f519eaf
--- /dev/null
+++ b/tests/test_cocos.py
@@ -0,0 +1,403 @@
+# -*- coding: utf-8 -*-
+#
+# test_cocos.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+from __future__ import print_function
+
+from typing import Optional
+
+import os
+import pytest
+
+from pynestml.meta_model.ast_model import ASTModel
+from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.logger import LoggingLevel, Logger
+from pynestml.utils.model_parser import ModelParser
+
+
+class TestCoCos:
+
+ @pytest.fixture(scope="module", autouse=True)
+ def setUp(self):
+ SymbolTable.initialize_symbol_table(
+ ASTSourceLocation(
+ start_line=0,
+ start_column=0,
+ end_line=0,
+ end_column=0))
+ PredefinedUnits.register_units()
+ PredefinedTypes.register_types()
+ PredefinedVariables.register_variables()
+ PredefinedFunctions.register_functions()
+
+ def test_invalid_element_defined_after_usage(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableDefinedAfterUsage.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_element_defined_after_usage(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableDefinedAfterUsage.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_element_in_same_line(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoElementInSameLine.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_element_in_same_line(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoElementInSameLine.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_integrate_odes_called_if_equations_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_integrate_odes_called_if_equations_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_element_not_defined_in_scope(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableNotDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 6
+
+ def test_valid_element_not_defined_in_scope(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableNotDefined.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_variable_with_same_name_as_unit(self):
+ Logger.set_logging_level(LoggingLevel.NO)
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableWithSameNameAsUnit.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 3
+
+ def test_invalid_variable_redeclaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableRedeclared.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_variable_redeclaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableRedeclared.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_each_block_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoEachBlockUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_each_block_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoEachBlockUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_function_unique_and_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionNotUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8
+
+ def test_valid_function_unique_and_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionNotUnique.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expressions_have_rhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionHasNoRhs.nestml'))
+ assert model is None
+
+ def test_valid_inline_expressions_have_rhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionHasNoRhs.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expression_has_several_lhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionWithSeveralLhs.nestml'))
+ assert model is None
+
+ def test_valid_inline_expression_has_several_lhs(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionWithSeveralLhs.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_no_values_assigned_to_input_ports(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoValueAssignedToInputPort.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_no_values_assigned_to_input_ports(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoValueAssignedToInputPort.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_order_of_equations_correct(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNoOrderOfEquations.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_order_of_equations_correct(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNoOrderOfEquations.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_numerator_of_unit_one(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoUnitNumeratorNotOne.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_numerator_of_unit_one(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoUnitNumeratorNotOne.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_names_of_neurons_unique(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoMultipleNeuronsWithEqualName.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 3
+
+ def test_valid_names_of_neurons_unique(self):
+ self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoMultipleNeuronsWithEqualName.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_no_nest_collision(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNestNamespaceCollision.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_no_nest_collision(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNestNamespaceCollision.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_redundant_input_port_keywords_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInputPortWithRedundantTypes.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_redundant_input_port_keywords_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInputPortWithRedundantTypes.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_parameters_assigned_only_in_parameters_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoParameterAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_parameters_assigned_only_in_parameters_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoParameterAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_inline_expressions_assigned_only_in_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoAssignmentToInlineExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_invalid_internals_assigned_only_in_internals_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInternalAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_internals_assigned_only_in_internals_block(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInternalAssignedOutsideBlock.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_function_with_wrong_arg_number_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_function_with_wrong_arg_number_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_init_values_have_rhs_and_ode(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInitValuesWithoutOde.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2
+
+ def test_valid_init_values_have_rhs_and_ode(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInitValuesWithoutOde.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2
+
+ def test_invalid_incorrect_return_stmt_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIncorrectReturnStatement.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8
+
+ def test_valid_incorrect_return_stmt_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIncorrectReturnStatement.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_ode_vars_outside_init_block_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeVarNotInInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_ode_vars_outside_init_block_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeVarNotInInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_convolve_correctly_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyProvided.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_convolve_correctly_defined(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyProvided.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_in_non_vector_declaration_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInNonVectorDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_in_non_vector_declaration_detected(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInNonVectorDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_parameter_declaration(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterDeclaration.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_type(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_vector_parameter_type(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_vector_parameter_size(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorDeclarationSize.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_vector_parameter_size(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorDeclarationSize.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_convolve_correctly_parameterized(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyParametrized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_convolve_correctly_parameterized(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyParametrized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_invariant_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInvariantNotBool.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_invariant_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInvariantNotBool.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIllegalExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIllegalExpression.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_compound_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 10
+
+ def test_valid_compound_expression_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_ode_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeIncorrectlyTyped.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_valid_ode_correctly_typed(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeCorrectlyTyped.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_output_block_defined_if_emit_call(self):
+ """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_invalid_output_port_defined_if_emit_call(self):
+ """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall-2.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0
+
+ def test_valid_output_port_defined_if_emit_call(self):
+ """test that no error is raised when the output block is missing, but not emit_spike() functions are called"""
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOutputPortDefinedIfEmitCall.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_valid_coco_kernel_type(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoKernelType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_coco_kernel_type(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_invalid_coco_kernel_type_initial_values(self):
+ """
+ Test the functionality of CoCoKernelType.
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelTypeInitialValues.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 4
+
+ def test_valid_coco_state_variables_initialized(self):
+ """
+ Test that the CoCo condition is applicable for all the variables in the state block initialized with a value
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoStateVariablesInitialized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_coco_state_variables_initialized(self):
+ """
+ Test that the CoCo condition is applicable for all the variables in the state block not initialized
+ """
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoStateVariablesInitialized.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_invalid_co_co_priorities_correctly_specified(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoPrioritiesCorrectlySpecified.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def test_valid_co_co_priorities_correctly_specified(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoPrioritiesCorrectlySpecified.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_co_co_resolution_legally_used(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoResolutionLegallyUsed.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2
+
+ def test_valid_co_co_resolution_legally_used(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoResolutionLegallyUsed.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_valid_co_co_vector_input_port(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInputPortSizeAndType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0
+
+ def test_invalid_co_co_vector_input_port(self):
+ model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInputPortSizeAndType.nestml'))
+ assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1
+
+ def _parse_and_validate_model(self, fname: str) -> Optional[str]:
+ from pynestml.frontend.pynestml_frontend import generate_target
+
+ Logger.init_logger(LoggingLevel.DEBUG)
+
+ try:
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+ except BaseException:
+ return None
+
+ ast_compilation_unit = ModelParser.parse_file(fname)
+ if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0:
+ return None
+
+ model: ASTModel = ast_compilation_unit.get_model_list()[0]
+ model_name = model.get_name()
+
+ return model_name
diff --git a/tests/test_function_parameter_templating.py b/tests/test_function_parameter_templating.py
new file mode 100644
index 000000000..b93e06780
--- /dev/null
+++ b/tests/test_function_parameter_templating.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+#
+# test_function_parameter_templating.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+
+from pynestml.utils.logger import Logger, LoggingLevel
+from pynestml.frontend.pynestml_frontend import generate_target
+
+
+class TestFunctionParameterTemplating:
+ """
+ This test is used to test the correct derivation of types when functions use templated type parameters.
+ """
+
+ def test(self):
+ fname = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "FunctionParameterTemplatingTest.nestml")))
+ generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG")
+ assert len(Logger.get_all_messages_of_level_and_or_node("templated_function_parameters_type_test", LoggingLevel.ERROR)) == 5
diff --git a/tests/test_unit_system.py b/tests/test_unit_system.py
new file mode 100644
index 000000000..2cad0b98d
--- /dev/null
+++ b/tests/test_unit_system.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+#
+# test_unit_system.py
+#
+# This file is part of NEST.
+#
+# Copyright (C) 2004 The NEST Initiative
+#
+# NEST is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 2 of the License, or
+# (at your option) any later version.
+#
+# NEST is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with NEST. If not, see .
+
+import os
+import pytest
+
+from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
+from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter
+from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter
+from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter
+from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
+from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter
+from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter
+from pynestml.frontend.pynestml_frontend import generate_target
+from pynestml.symbol_table.symbol_table import SymbolTable
+from pynestml.symbols.predefined_functions import PredefinedFunctions
+from pynestml.symbols.predefined_types import PredefinedTypes
+from pynestml.symbols.predefined_units import PredefinedUnits
+from pynestml.symbols.predefined_variables import PredefinedVariables
+from pynestml.utils.ast_source_location import ASTSourceLocation
+from pynestml.utils.logger import Logger, LoggingLevel
+from pynestml.utils.model_parser import ModelParser
+
+
+class TestUnitSystem:
+ r"""
+ Test class for units system.
+ """
+
+ @pytest.fixture(scope="class", autouse=True)
+ def setUp(self, request):
+ Logger.set_logging_level(LoggingLevel.INFO)
+
+ SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
+
+ PredefinedUnits.register_units()
+ PredefinedTypes.register_types()
+ PredefinedVariables.register_variables()
+ PredefinedFunctions.register_functions()
+
+ Logger.init_logger(LoggingLevel.INFO)
+
+ variable_printer = NestMLVariablePrinter(None)
+ function_call_printer = NESTCppFunctionCallPrinter(None)
+ cpp_variable_printer = CppVariablePrinter(None)
+ self.printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer,
+ ConstantPrinter(),
+ function_call_printer))
+ cpp_variable_printer._expression_printer = self.printer
+ variable_printer._expression_printer = self.printer
+ function_call_printer._expression_printer = self.printer
+
+ request.cls.printer = self.printer
+
+ def get_first_statement_in_update_block(self, model):
+ if model.get_model_list()[0].get_update_blocks()[0]:
+ return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0]
+
+ return None
+
+ def get_first_declaration_in_state_block(self, model):
+ assert len(model.get_model_list()[0].get_state_blocks()) == 1
+
+ return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0]
+
+ def get_first_declared_function(self, model):
+ return model.get_model_list()[0].get_functions()[0]
+
+ def print_rhs_of_first_assignment_in_update_block(self, model):
+ assignment = self.get_first_statement_in_update_block(model).small_stmt.get_assignment()
+ expression = assignment.get_expression()
+
+ return self.printer.print(expression)
+
+ def print_first_function_call_in_update_block(self, model):
+ function_call = self.get_first_statement_in_update_block(model).small_stmt.get_function_call()
+
+ return self.printer.print(function_call)
+
+ def print_rhs_of_first_declaration_in_state_block(self, model):
+ declaration = self.get_first_declaration_in_state_block(model)
+ expression = declaration.get_expression()
+
+ return self.printer.print(expression)
+
+ def print_first_return_statement_in_first_declared_function(self, model):
+ func = self.get_first_declared_function(model)
+ return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression()
+ return self.printer.print(return_expression)
+
+ def test_expression_after_magnitude_conversion_in_direct_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V))'
+
+ def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))'
+
+ def test_expression_after_magnitude_conversion_in_compound_assignment(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_rhs_expression == '(0.001 * (1200 * mV))'
+
+ def test_expression_after_magnitude_conversion_in_declaration(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml'))
+ printed_rhs_expression = self.print_rhs_of_first_declaration_in_state_block(model)
+
+ assert printed_rhs_expression == '(1000.0 * (10 * V))'
+
+ def test_expression_after_type_conversion_in_declaration(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnits.nestml'))
+ declaration = self.get_first_declaration_in_state_block(model)
+ from astropy import units as u
+
+ assert declaration.get_expression().type.unit.unit == u.mV
+
+ def test_declaration_with_same_variable_name_as_unit(self):
+ Logger.init_logger(LoggingLevel.DEBUG)
+
+ generate_target(input_path=os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithSameVariableNameAsUnit.nestml'), target_platform="NONE", logging_level="DEBUG")
+
+ assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.ERROR)) == 0
+ assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.WARNING)) == 3
+
+ def test_expression_after_magnitude_conversion_in_standalone_function_call(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionCallWithDifferentButCompatibleUnits.nestml'))
+ printed_function_call = self.print_first_function_call_in_update_block(model)
+
+ assert printed_function_call == 'foo((1000.0 * (10 * V)))'
+
+ def test_expression_after_magnitude_conversion_in_rhs_function_call(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml'))
+ printed_function_call = self.print_rhs_of_first_assignment_in_update_block(model)
+
+ assert printed_function_call == 'foo((1000.0 * (10 * V)))'
+
+ def test_return_stmt_after_magnitude_conversion_in_function_body(self):
+ model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml'))
+ printed_return_stmt = self.print_first_return_statement_in_first_declared_function(model)
+
+ assert printed_return_stmt == '(0.001 * (bar))'
diff --git a/tests/unit_system_test.py b/tests/unit_system_test.py
deleted file mode 100644
index 1f7817b91..000000000
--- a/tests/unit_system_test.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# unit_system_test.py
-#
-# This file is part of NEST.
-#
-# Copyright (C) 2004 The NEST Initiative
-#
-# NEST is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 2 of the License, or
-# (at your option) any later version.
-#
-# NEST is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with NEST. If not, see .
-
-import os
-import unittest
-from pynestml.codegeneration.printers.constant_printer import ConstantPrinter
-
-from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter
-from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter
-from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter
-from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter
-from pynestml.symbol_table.symbol_table import SymbolTable
-from pynestml.symbols.predefined_functions import PredefinedFunctions
-from pynestml.symbols.predefined_types import PredefinedTypes
-from pynestml.symbols.predefined_units import PredefinedUnits
-from pynestml.symbols.predefined_variables import PredefinedVariables
-from pynestml.utils.ast_source_location import ASTSourceLocation
-from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
-from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter
-from pynestml.codegeneration.printers.cpp_function_call_printer import CppFunctionCallPrinter
-from pynestml.utils.logger import Logger, LoggingLevel
-from pynestml.utils.model_parser import ModelParser
-
-
-SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0))
-
-PredefinedUnits.register_units()
-PredefinedTypes.register_types()
-PredefinedVariables.register_variables()
-PredefinedFunctions.register_functions()
-
-Logger.init_logger(LoggingLevel.INFO)
-
-type_symbol_printer = CppTypeSymbolPrinter()
-variable_printer = NestMLVariablePrinter(None)
-function_call_printer = NESTCppFunctionCallPrinter(None)
-cpp_variable_printer = CppVariablePrinter(None)
-printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer,
- ConstantPrinter(),
- function_call_printer))
-cpp_variable_printer._expression_printer = printer
-variable_printer._expression_printer = printer
-function_call_printer._expression_printer = printer
-
-
-def get_first_statement_in_update_block(model):
- if model.get_model_list()[0].get_update_blocks()[0]:
- return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0]
- return None
-
-
-def get_first_declaration_in_state_block(model):
- assert len(model.get_model_list()[0].get_state_blocks()) == 1
- return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0]
-
-
-def get_first_declared_function(model):
- return model.get_model_list()[0].get_functions()[0]
-
-
-def print_rhs_of_first_assignment_in_update_block(model):
- assignment = get_first_statement_in_update_block(model).small_stmt.get_assignment()
- expression = assignment.get_expression()
- return printer.print(expression)
-
-
-def print_first_function_call_in_update_block(model):
- function_call = get_first_statement_in_update_block(model).small_stmt.get_function_call()
- return printer.print(function_call)
-
-
-def print_rhs_of_first_declaration_in_state_block(model):
- declaration = get_first_declaration_in_state_block(model)
- expression = declaration.get_expression()
- return printer.print(expression)
-
-
-def print_first_return_statement_in_first_declared_function(model):
- func = get_first_declared_function(model)
- return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression()
- return printer.print(return_expression)
-
-
-class UnitSystemTest(unittest.TestCase):
- """
- Test class for everything Unit related.
- """
-
- def setUp(self):
- Logger.set_logging_level(LoggingLevel.INFO)
-
- def test_expression_after_magnitude_conversion_in_direct_assignment(self):
- Logger.set_logging_level(LoggingLevel.INFO)
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DirectAssignmentWithDifferentButCompatibleUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
-
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))')
-
- def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
-
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))')
-
- def test_expression_after_magnitude_conversion_in_compound_assignment(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml'))
- printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model)
- self.assertEqual(printed_rhs_expression, '(0.001 * (1200 * mV))')
-
- def test_expression_after_magnitude_conversion_in_declaration(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml'))
- printed_rhs_expression = print_rhs_of_first_declaration_in_state_block(model)
- self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))')
-
- def test_expression_after_type_conversion_in_declaration(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithDifferentButCompatibleUnits.nestml'))
- declaration = get_first_declaration_in_state_block(model)
- from astropy import units as u
- self.assertTrue(declaration.get_expression().type.unit.unit == u.mV)
-
- def test_declaration_with_same_variable_name_as_unit(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'DeclarationWithSameVariableNameAsUnit.nestml'))
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0)
- self.assertEqual(len(
- Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 3)
-
- def test_expression_after_magnitude_conversion_in_standalone_function_call(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'FunctionCallWithDifferentButCompatibleUnits.nestml'))
- printed_function_call = print_first_function_call_in_update_block(model)
- self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))')
-
- def test_expression_after_magnitude_conversion_in_rhs_function_call(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml'))
- printed_function_call = print_rhs_of_first_assignment_in_update_block(model)
- self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))')
-
- def test_return_stmt_after_magnitude_conversion_in_function_body(self):
- model = ModelParser.parse_file(
- os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')),
- 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml'))
- printed_return_stmt = print_first_return_statement_in_first_declared_function(model)
- self.assertEqual(printed_return_stmt, '(0.001 * (bar))')