diff --git a/.github/workflows/continuous_benchmarking.yml b/.github/workflows/continuous_benchmarking.yml new file mode 100644 index 000000000..980da3473 --- /dev/null +++ b/.github/workflows/continuous_benchmarking.yml @@ -0,0 +1,91 @@ +on: + pull_request_target: + types: [opened, reopened, edited, synchronize] + +jobs: + fork_pr_requires_review: + environment: ${{ (github.event.pull_request.head.repo.full_name == github.repository && 'internal') || 'external' }} + runs-on: ubuntu-latest + steps: + - run: true + + benchmark_fork_pr_branch: + needs: fork_pr_requires_review + name: Continuous Benchmarking Fork PRs with Bencher + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.sha }} + persist-credentials: false + + - uses: bencherdev/bencher@main + + # Setup Python version + - name: Setup Python 3.8 + uses: actions/setup-python@v5 + with: + python-version: 3.8 + + # Install dependencies + - name: Install apt dependencies + run: | + sudo apt-get update + sudo apt-get install libltdl7-dev libgsl0-dev libncurses5-dev libreadline6-dev pkg-config + sudo apt-get install python3-all-dev python3-matplotlib python3-numpy python3-scipy ipython3 + + # Install Python dependencies + - name: Python dependencies + run: | + python -m pip install --upgrade pip pytest jupyterlab matplotlib pycodestyle scipy pandas pytest-benchmark + python -m pip install -r requirements.txt + + # Install NEST simulator + - name: NEST simulator + run: | + python -m pip install cython + echo "GITHUB_WORKSPACE = $GITHUB_WORKSPACE" + NEST_SIMULATOR=$(pwd)/nest-simulator + NEST_INSTALL=$(pwd)/nest_install + echo "NEST_SIMULATOR = $NEST_SIMULATOR" + echo "NEST_INSTALL = $NEST_INSTALL" + + git clone --depth=1 https://github.com/nest/nest-simulator + mkdir nest_install + echo "NEST_INSTALL=$NEST_INSTALL" >> $GITHUB_ENV + cd nest_install + cmake -DCMAKE_INSTALL_PREFIX=$NEST_INSTALL $NEST_SIMULATOR + make && make install + cd .. + + # Install NESTML (repeated) + - name: Install NESTML + run: | + export PYTHONPATH=${{ env.PYTHONPATH }}:${{ env.NEST_INSTALL }}/lib/python3.8/site-packages + #echo PYTHONPATH=`pwd` >> $GITHUB_ENV + echo "PYTHONPATH=$PYTHONPATH" >> $GITHUB_ENV + python setup.py install + + - name: Track Fork PR Benchmarks with Bencher + env: + LD_LIBRARY_PATH: ${{ env.NEST_INSTALL }}/lib/nest + run: | + echo "NEST_INSTALL = $NEST_INSTALL" + LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${{ env.NEST_INSTALL }}/lib/nest bencher run \ + --project nestml \ + --token '${{ secrets.BENCHER_API_TOKEN }}' \ + --branch '${{ github.event.number }}/merge' \ + --branch-start-point '${{ github.base_ref }}' \ + --branch-start-point-hash '${{ github.event.pull_request.base.sha }}' \ + --branch-reset \ + --github-actions "${{ secrets.GITHUB_TOKEN }}" \ + --testbed ubuntu-latest \ + --adapter python_pytest \ + --file results.json \ + --err \ + 'LD_LIBRARY_PATH=$LD_LIBRARY_PATH:${{ env.NEST_INSTALL }}/lib/nest python3 -m pytest --benchmark-json results.json -s $GITHUB_WORKSPACE/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py' + + - name: Setup tmate session + if: ${{ failure() }} + uses: mxschmitt/action-tmate@v3 diff --git a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb index 147f5b4a1..3922d0ac8 100644 --- a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb +++ b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb @@ -1347,7 +1347,7 @@ " NESTCodeGeneratorUtils.generate_code_for(nestml_neuron_model,\n", " nestml_synapse_model,\n", " codegen_opts=codegen_opts,\n", - " logging_level=\"INFO\") # try \"INFO\" or \"DEBUG\" for more debug information" + " logging_level=\"WARNING\") # try \"INFO\" or \"DEBUG\" for more debug information" ] }, { diff --git a/pynestml/cocos/co_co_all_variables_defined.py b/pynestml/cocos/co_co_all_variables_defined.py index e41b0727e..38cfa89ab 100644 --- a/pynestml/cocos/co_co_all_variables_defined.py +++ b/pynestml/cocos/co_co_all_variables_defined.py @@ -41,11 +41,10 @@ class CoCoAllVariablesDefined(CoCo): """ @classmethod - def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): + def check_co_co(cls, node: ASTModel): """ Checks if this coco applies for the handed over neuron. Models which contain undefined variables are not correct. :param node: a single neuron instance. - :param after_ast_rewrite: indicates whether this coco is checked after the code generator has done rewriting of the abstract syntax tree. If True, checks are not as rigorous. Use False where possible. """ # for each variable in all expressions, check if the variable has been defined previously expression_collector_visitor = ASTExpressionCollectorVisitor() @@ -62,32 +61,6 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): # test if the symbol has been defined at least if symbol is None: - if after_ast_rewrite: # after ODE-toolbox transformations, convolutions are replaced by state variables, so cannot perform this check properly - symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE) - if symbol2 is not None: - # an inline expression defining this variable name (ignoring differential order) exists - if "__X__" in str(symbol2): # if this variable was the result of a convolution... - continue - else: - # for kernels, also allow derivatives of that kernel to appear - - inline_expr_names = [] - inline_exprs = [] - for equations_block in node.get_equations_blocks(): - inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()]) - inline_exprs.extend(equations_block.get_inline_expressions()) - - if var.get_name() in inline_expr_names: - inline_expr_idx = inline_expr_names.index(var.get_name()) - inline_expr = inline_exprs[inline_expr_idx] - from pynestml.utils.ast_utils import ASTUtils - if ASTUtils.inline_aliases_convolution(inline_expr): - symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE) - if symbol2 is not None: - # actually, no problem detected, skip error - # XXX: TODO: check that differential order is less than or equal to that of the kernel - continue - # check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV" symbol2 = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE) if symbol2 is not None: @@ -106,9 +79,14 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): # in this case its ok if it is recursive or defined later on continue + if symbol.is_predefined: + continue + + if symbol.block_type == BlockType.LOCAL and symbol.get_referenced_object().get_source_position().before(var.get_source_position()): + continue + # check if it has been defined before usage, except for predefined symbols, input ports and variables added by the AST transformation functions - if (not symbol.is_predefined) \ - and symbol.block_type != BlockType.INPUT \ + if symbol.block_type != BlockType.INPUT \ and not symbol.get_referenced_object().get_source_position().is_added_source_position(): # except for parameters, those can be defined after if ((not symbol.get_referenced_object().get_source_position().before(var.get_source_position())) diff --git a/pynestml/cocos/co_co_function_unique.py b/pynestml/cocos/co_co_function_unique.py index 15643c0ad..bf0f2be60 100644 --- a/pynestml/cocos/co_co_function_unique.py +++ b/pynestml/cocos/co_co_function_unique.py @@ -65,4 +65,5 @@ def check_co_co(cls, model: ASTModel): log_level=LoggingLevel.ERROR, message=message, code=code) checked.append(funcA) + checked_funcs_names.append(func.get_name()) diff --git a/pynestml/cocos/co_co_illegal_expression.py b/pynestml/cocos/co_co_illegal_expression.py index b78396e3b..c362d0dc5 100644 --- a/pynestml/cocos/co_co_illegal_expression.py +++ b/pynestml/cocos/co_co_illegal_expression.py @@ -18,13 +18,13 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.meta_model.ast_declaration import ASTDeclaration from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.symbols.error_type_symbol import ErrorTypeSymbol from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.logger import LoggingLevel, Logger from pynestml.utils.logging_helper import LoggingHelper from pynestml.utils.messages import Messages diff --git a/pynestml/cocos/co_co_nest_random_functions_legally_used.py b/pynestml/cocos/co_co_nest_random_functions_legally_used.py new file mode 100644 index 000000000..81e2fc464 --- /dev/null +++ b/pynestml/cocos/co_co_nest_random_functions_legally_used.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- +# +# co_co_nest_random_functions_legally_used.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_on_condition_block import ASTOnConditionBlock +from pynestml.meta_model.ast_on_receive_block import ASTOnReceiveBlock +from pynestml.meta_model.ast_update_block import ASTUpdateBlock +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.utils.logger import LoggingLevel, Logger +from pynestml.utils.messages import Messages +from pynestml.visitors.ast_visitor import ASTVisitor + + +class CoCoNestRandomFunctionsLegallyUsed(CoCo): + """ + This CoCo ensure that the random functions are used only in the ``update``, ``onReceive``, and ``onCondition`` blocks. + This CoCo is only checked for the NEST Simulator target. + """ + + @classmethod + def check_co_co(cls, node: ASTNode): + """ + Checks the coco. + :param node: a single node (typically, a neuron or synapse) + """ + visitor = CoCoNestRandomFunctionsLegallyUsedVisitor() + visitor.neuron = node + node.accept(visitor) + + +class CoCoNestRandomFunctionsLegallyUsedVisitor(ASTVisitor): + def visit_function_call(self, node): + """ + Visits a function call + :param node: a function call + """ + function_name = node.get_name() + if function_name == PredefinedFunctions.RANDOM_NORMAL or function_name == PredefinedFunctions.RANDOM_UNIFORM \ + or function_name == PredefinedFunctions.RANDOM_POISSON: + parent = node + while parent: + parent = parent.get_parent() + + if isinstance(parent, ASTUpdateBlock) or isinstance(parent, ASTOnReceiveBlock) \ + or isinstance(parent, ASTOnConditionBlock): + # the random function is correctly defined, hence return + return + + if isinstance(parent, ASTModel): + # the random function is defined in other blocks (parameters, state, internals). Hence, an error. + code, message = Messages.get_random_functions_legally_used(function_name) + Logger.log_message(node=self.neuron, code=code, message=message, error_position=node.get_source_position(), + log_level=LoggingLevel.ERROR) diff --git a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py index 18b862292..e318ae566 100644 --- a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py +++ b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py @@ -22,11 +22,14 @@ from typing import List from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_external_variable import ASTExternalVariable from pynestml.meta_model.ast_function_call import ASTFunctionCall from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_variable import ASTVariable +from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.symbol import SymbolKind from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -89,24 +92,44 @@ def visit_variable(self, node: ASTNode): if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()): code, message = Messages.get_no_variable_found(kernelName) Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR) + continue + if not symbol.is_kernel(): continue + if node.get_complete_name() == kernelName: - parent = node.get_parent() - if parent is not None: + parent = node + correct = False + while parent is not None and not isinstance(parent, ASTModel): + parent = parent.get_parent() + assert parent is not None + + if isinstance(parent, ASTDeclaration): + for lhs_var in parent.get_variables(): + if kernelName == lhs_var.get_complete_name(): + # kernel name appears on lhs of declaration, assume it is initial state + correct = True + parent = None # break out of outer loop + break + if isinstance(parent, ASTKernel): - continue - grandparent = parent.get_parent() - if grandparent is not None and isinstance(grandparent, ASTFunctionCall): - grandparent_func_name = grandparent.get_name() - if grandparent_func_name == 'convolve': - continue - code, message = Messages.get_kernel_outside_convolve(kernelName) - Logger.log_message(code=code, - message=message, - log_level=LoggingLevel.ERROR, - error_position=node.get_source_position()) + # kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'`` + correct = True + break + + if isinstance(parent, ASTFunctionCall): + func_name = parent.get_name() + if func_name == PredefinedFunctions.CONVOLVE: + # kernel name is used inside convolve call + correct = True + + if not correct: + code, message = Messages.get_kernel_outside_convolve(kernelName) + Logger.log_message(code=code, + message=message, + log_level=LoggingLevel.ERROR, + error_position=node.get_source_position()) class KernelCollectingVisitor(ASTVisitor): diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py index 4ef08c0ec..51308f2cc 100644 --- a/pynestml/cocos/co_co_v_comp_exists.py +++ b/pynestml/cocos/co_co_v_comp_exists.py @@ -43,9 +43,6 @@ def check_co_co(cls, neuron: ASTModel): Models which are supposed to be compartmental but do not contain state variable called v_comp are not correct. :param neuron: a single neuron instance. - :param after_ast_rewrite: indicates whether this coco is checked - after the code generator has done rewriting of the abstract syntax tree. - If True, checks are not as rigorous. Use False where possible. """ from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py index 01d008890..c90ffa2b1 100644 --- a/pynestml/cocos/co_cos_manager.py +++ b/pynestml/cocos/co_cos_manager.py @@ -23,11 +23,9 @@ from pynestml.cocos.co_co_all_variables_defined import CoCoAllVariablesDefined from pynestml.cocos.co_co_inline_expression_not_assigned_to import CoCoInlineExpressionNotAssignedTo -from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo from pynestml.cocos.co_co_cm_channel_model import CoCoCmChannelModel from pynestml.cocos.co_co_cm_continuous_input_model import CoCoCmContinuousInputModel from pynestml.cocos.co_co_convolve_cond_correctly_built import CoCoConvolveCondCorrectlyBuilt -from pynestml.cocos.co_co_convolve_has_correct_parameter import CoCoConvolveHasCorrectParameter from pynestml.cocos.co_co_input_port_not_assigned_to import CoCoInputPortNotAssignedTo from pynestml.cocos.co_co_integrate_odes_params_correct import CoCoIntegrateODEsParamsCorrect from pynestml.cocos.co_co_correct_numerator_of_unit import CoCoCorrectNumeratorOfUnit @@ -43,6 +41,7 @@ from pynestml.cocos.co_co_invariant_is_boolean import CoCoInvariantIsBoolean from pynestml.cocos.co_co_kernel_type import CoCoKernelType from pynestml.cocos.co_co_model_name_unique import CoCoModelNameUnique +from pynestml.cocos.co_co_nest_random_functions_legally_used import CoCoNestRandomFunctionsLegallyUsed from pynestml.cocos.co_co_no_kernels_except_in_convolve import CoCoNoKernelsExceptInConvolve from pynestml.cocos.co_co_no_nest_name_space_collision import CoCoNoNestNameSpaceCollision from pynestml.cocos.co_co_no_duplicate_compilation_unit_names import CoCoNoDuplicateCompilationUnitNames @@ -69,6 +68,7 @@ from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified from pynestml.meta_model.ast_model import ASTModel from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.utils.logger import Logger class CoCosManager: @@ -123,12 +123,12 @@ def check_state_variables_initialized(cls, model: ASTModel): CoCoStateVariablesInitialized.check_co_co(model) @classmethod - def check_variables_defined_before_usage(cls, model: ASTModel, after_ast_rewrite: bool) -> None: + def check_variables_defined_before_usage(cls, model: ASTModel) -> None: """ Checks that all variables are defined before being used. :param model: a single model. """ - CoCoAllVariablesDefined.check_co_co(model, after_ast_rewrite) + CoCoAllVariablesDefined.check_co_co(model) @classmethod def check_v_comp_requirement(cls, neuron: ASTModel): @@ -402,17 +402,27 @@ def check_input_port_size_type(cls, model: ASTModel): CoCoVectorInputPortsCorrectSizeType.check_co_co(model) @classmethod - def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False): + def check_co_co_nest_random_functions_legally_used(cls, model: ASTModel): + """ + Checks if the random number functions are used only in the update block. + :param model: a single model object. + """ + CoCoNestRandomFunctionsLegallyUsed.check_co_co(model) + + @classmethod + def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False): """ Checks all context conditions. :param model: a single model object. """ + Logger.set_current_node(model) + cls.check_each_block_defined_at_most_once(model) cls.check_function_defined(model) cls.check_variables_unique_in_scope(model) cls.check_inline_expression_not_assigned_to(model) cls.check_state_variables_initialized(model) - cls.check_variables_defined_before_usage(model, after_ast_rewrite) + cls.check_variables_defined_before_usage(model) if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL': # XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead. cls.check_v_comp_requirement(model) @@ -452,3 +462,5 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo cls.check_co_co_priorities_correctly_specified(model) cls.check_resolution_func_legally_used(model) cls.check_input_port_size_type(model) + + Logger.set_current_node(None) diff --git a/pynestml/codegeneration/builder.py b/pynestml/codegeneration/builder.py index 2e6757c1a..a9f98bf58 100644 --- a/pynestml/codegeneration/builder.py +++ b/pynestml/codegeneration/builder.py @@ -20,12 +20,12 @@ # along with NEST. If not, see . from __future__ import annotations -import subprocess -import os from typing import Any, Mapping, Optional from abc import ABCMeta, abstractmethod +import os +import subprocess from pynestml.exceptions.invalid_target_exception import InvalidTargetException from pynestml.frontend.frontend_configuration import FrontendConfiguration diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 0551e9a6e..cc3c91b93 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -28,6 +28,7 @@ import pynestml from pynestml.cocos.co_co_nest_synapse_delay_not_assigned_to import CoCoNESTSynapseDelayNotAssignedTo +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.codegeneration.code_generator import CodeGenerator from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper @@ -172,21 +173,30 @@ def __init__(self, options: Optional[Mapping[str, Any]] = None): self.setup_printers() def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: Sequence[ASTModel]): - for synapse in synapses: - synapse_name_stripped = removesuffix(removesuffix(synapse.name.split("_with_")[0], "_"), FrontendConfiguration.suffix) + for model in neurons + synapses: + # Check if the random number functions are used in the right blocks + CoCosManager.check_co_co_nest_random_functions_legally_used(model) + + if Logger.has_errors(model): + raise Exception("Error(s) occurred during code generation") + + if self.get_option("neuron_synapse_pairs"): + for model in synapses: + synapse_name_stripped = removesuffix(removesuffix(model.name.split("_with_")[0], "_"), + FrontendConfiguration.suffix) + # special case for NEST delay variable (state or parameter) + assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" + assert ASTUtils.get_variable_by_name(model, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" - # special case for NEST delay variable (state or parameter) - assert synapse_name_stripped in self.get_option("delay_variable").keys(), "Please specify a delay variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" - assert ASTUtils.get_variable_by_name(synapse, self.get_option("delay_variable")[synapse_name_stripped]), "Delay variable '" + self.get_option("delay_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" + # special case for NEST weight variable (state or parameter) + assert synapse_name_stripped in self.get_option("weight_variable").keys(), "Please specify a weight variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" + assert ASTUtils.get_variable_by_name(model, self.get_option("weight_variable")[synapse_name_stripped]), "Weight variable '" + self.get_option("weight_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" - # special case for NEST weight variable (state or parameter) - assert synapse_name_stripped in self.get_option("weight_variable").keys(), "Please specify a weight variable for synapse '" + synapse_name_stripped + "' in the code generator options (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" - assert ASTUtils.get_variable_by_name(synapse, self.get_option("weight_variable")[synapse_name_stripped]), "Weight variable '" + self.get_option("weight_variable")[synapse_name_stripped] + "' not found in synapse '" + synapse_name_stripped + "' (see https://nestml.readthedocs.io/en/latest/running/running_nest.html#dendritic-delay-and-synaptic-weight)" + if self.option_exists("delay_variable") and synapse_name_stripped in self.get_option("delay_variable").keys(): + delay_variable = self.get_option("delay_variable")[synapse_name_stripped] + CoCoNESTSynapseDelayNotAssignedTo.check_co_co(delay_variable, model) - if self.option_exists("delay_variable") and synapse_name_stripped in self.get_option("delay_variable").keys(): - delay_variable = self.get_option("delay_variable")[synapse_name_stripped] - CoCoNESTSynapseDelayNotAssignedTo.check_co_co(delay_variable, synapse) - if Logger.has_errors(synapse): + if Logger.has_errors(model): raise Exception("Error(s) occurred during code generation") def setup_printers(self): @@ -374,6 +384,9 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di if not used_in_eq: self.non_equations_state_variables[neuron.get_name()].append(var) + # cache state variables before symbol table update for the sake of delay variables + state_vars_before_update = neuron.get_state_symbols() + ASTUtils.remove_initial_values_for_kernels(neuron) kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron) ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver]) @@ -388,7 +401,6 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di neuron = ASTUtils.add_declarations_to_internals( neuron, self.analytic_solver[neuron.get_name()]["propagators"]) - state_vars_before_update = neuron.get_state_symbols() self.update_symbol_table(neuron) # Update the delay parameter parameters after symbol table update @@ -898,8 +910,8 @@ def update_symbol_table(self, neuron) -> None: """ SymbolTable.delete_model_scope(neuron.get_name()) symbol_table_visitor = ASTSymbolTableVisitor() - symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 4711bc497..00f061775 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -18,14 +18,18 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -import shutil + from typing import Any, Dict, List, Mapping, Optional import datetime import os from jinja2 import TemplateRuntimeError + +from odetoolbox import analysis + import pynestml +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.codegeneration.code_generator import CodeGenerator from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper from pynestml.codegeneration.nest_declarations_helper import NestDeclarationsHelper @@ -53,9 +57,9 @@ from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.symbol import SymbolKind +from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory -from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer from pynestml.utils.mechanism_processing import MechanismProcessing from pynestml.utils.channel_processing import ChannelProcessing from pynestml.utils.concentration_processing import ConcentrationProcessing @@ -72,7 +76,6 @@ from pynestml.utils.synapse_processing import SynapseProcessing from pynestml.visitors.ast_random_number_generator_visitor import ASTRandomNumberGeneratorVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor -from odetoolbox import analysis class NESTCompartmentalCodeGenerator(CodeGenerator): @@ -740,8 +743,8 @@ def update_symbol_table(self, neuron, kernel_buffers): """ SymbolTable.delete_model_scope(neuron.get_name()) symbol_table_visitor = ASTSymbolTableVisitor() - symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]: diff --git a/pynestml/codegeneration/python_standalone_code_generator.py b/pynestml/codegeneration/python_standalone_code_generator.py index f44123743..d6afaa095 100644 --- a/pynestml/codegeneration/python_standalone_code_generator.py +++ b/pynestml/codegeneration/python_standalone_code_generator.py @@ -111,7 +111,6 @@ def setup_printers(self): # GSL printers self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None) - print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer)) self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None) self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=PythonSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer, constant_printer=self._constant_printer, diff --git a/pynestml/codegeneration/spinnaker_code_generator.py b/pynestml/codegeneration/spinnaker_code_generator.py index 2a8fed7de..dce247e9c 100644 --- a/pynestml/codegeneration/spinnaker_code_generator.py +++ b/pynestml/codegeneration/spinnaker_code_generator.py @@ -137,7 +137,6 @@ def setup_printers(self): # GSL printers self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None) - print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer)) self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None) self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=SpinnakerPythonSimpleExpressionPrinter( variable_printer=self._gsl_variable_printer, @@ -216,6 +215,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None: for model in models: cloned_model = model.clone() cloned_model.accept(ASTSymbolTableVisitor()) + CoCosManager.check_cocos(cloned_model) cloned_models.append(cloned_model) self.codegen_cpp.generate_code(cloned_models) @@ -224,6 +224,7 @@ def generate_code(self, models: Sequence[ASTModel]) -> None: for model in models: cloned_model = model.clone() cloned_model.accept(ASTSymbolTableVisitor()) + CoCosManager.check_cocos(cloned_model) cloned_models.append(cloned_model) self.codegen_py.generate_code(cloned_models) diff --git a/pynestml/frontend/frontend_configuration.py b/pynestml/frontend/frontend_configuration.py index 173534c95..aae1fc29a 100644 --- a/pynestml/frontend/frontend_configuration.py +++ b/pynestml/frontend/frontend_configuration.py @@ -244,8 +244,8 @@ def handle_module_name(cls, module_name): @classmethod def handle_target_platform(cls, target_platform: Optional[str]): - if target_platform is None or target_platform.upper() == 'NONE': - target_platform = '' # make sure `target_platform` is always a string + if target_platform is None: + target_platform = "NONE" # make sure `target_platform` is always a string from pynestml.frontend.pynestml_frontend import get_known_targets diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index c257822de..ca3866619 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -41,6 +41,8 @@ from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor def get_known_targets(): @@ -131,10 +133,10 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[ return SpiNNakerCodeGenerator(options) if target_name.upper() == "NONE": - # dummy/null target: user requested to not generate any code + # dummy/null target: user requested to not generate any code (for instance, when just doing validation of a model) code, message = Messages.get_no_code_generated() Logger.log_message(None, code, message, None, LoggingLevel.INFO) - return CodeGenerator("", options) + return CodeGenerator(options) # cannot reach here due to earlier assert -- silence static checker warnings assert "Unknown code generator requested: " + target_name @@ -193,12 +195,17 @@ def generate_target(input_path: Union[str, Sequence[str]], target_platform: str, Enable development mode: code generation is attempted even for models that contain errors, and extra information is rendered in the generated code. codegen_opts : Optional[Mapping[str, Any]] A dictionary containing additional options for the target code generator. + + Return + ------ + errors_occurred + Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models. """ configure_front_end(input_path, target_platform, target_path, install_path, logging_level, module_name, store_log, suffix, dev, codegen_opts) - if not process() == 0: - raise Exception("Error(s) occurred while processing the model") + + return process() def configure_front_end(input_path: Union[str, Sequence[str]], target_platform: str, target_path=None, @@ -373,34 +380,36 @@ def generate_nest_compartmental_target(input_path: Union[str, Sequence[str]], ta def main() -> int: - """ + r""" Entry point for the command-line application. Returns ------- - The process exit code: 0 for success, > 0 for failure + exit_code + The process exit code: 0 for success, > 0 for failure """ try: FrontendConfiguration.parse_config(sys.argv[1:]) except InvalidPathException as e: print(e) + return 1 + # the default Python recursion limit is 1000, which might not be enough in practice when running an AST visitor on a deep tree, e.g. containing an automatically generated expression sys.setrecursionlimit(10000) + # after all argument have been collected, start the actual processing return int(process()) -def get_parsed_models(): +def get_parsed_models() -> List[ASTModel]: r""" Handle the parsing and validation of the NESTML files Returns ------- - models: Sequence[ASTModel] + models List of correctly parsed models - errors_occurred : bool - Flag indicating whether errors occurred during processing """ # init log dir create_report_dir() @@ -417,36 +426,29 @@ def get_parsed_models(): for nestml_file in nestml_files: parsed_unit = ModelParser.parse_file(nestml_file) - if parsed_unit is None: - # Parsing error in the NESTML model, return True - return [], True - - compilation_units.append(parsed_unit) + if parsed_unit: + compilation_units.append(parsed_unit) - if len(compilation_units) > 0: - # generate a list of all models - models: Sequence[ASTModel] = [] - for compilationUnit in compilation_units: - models.extend(compilationUnit.get_model_list()) + # generate a list of all models + models: Sequence[ASTModel] = [] + for compilation_unit in compilation_units: + CoCosManager.check_model_names_unique(compilation_unit) + models.extend(compilation_unit.get_model_list()) - # check that no models with duplicate names have been defined - CoCosManager.check_no_duplicate_compilation_unit_names(models) + # check that no models with duplicate names have been defined + CoCosManager.check_no_duplicate_compilation_unit_names(models) - # now exclude those which are broken, i.e. have errors. - for model in models: - if Logger.has_errors(model): - code, message = Messages.get_model_contains_errors(model.get_name()) - Logger.log_message(node=model, code=code, message=message, - error_position=model.get_source_position(), - log_level=LoggingLevel.WARNING) - return [model], True + for model in models: + model.accept(ASTParentVisitor()) + model.accept(ASTSymbolTableVisitor()) - return models, False + return models def transform_models(transformers, models): for transformer in transformers: models = transformer.transform(models) + return models @@ -454,44 +456,65 @@ def generate_code(code_generators, models): code_generators.generate_code(models) -def process(): +def process() -> bool: r""" The main toolchain workflow entry point. For all models: parse, validate, transform, generate code and build. - Returns - ------- - errors_occurred : bool - Flag indicating whether errors occurred during processing + Return + ------ + errors_occurred + Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models. """ - # initialize and set options for transformers, code generator and builder - codegen_and_builder_opts = FrontendConfiguration.get_codegen_opts() - - transformers, codegen_and_builder_opts = transformers_from_target_name(FrontendConfiguration.get_target_platform(), - options=codegen_and_builder_opts) + # initialise model transformers + transformers, unused_opts_transformer = transformers_from_target_name(FrontendConfiguration.get_target_platform(), + options=FrontendConfiguration.get_codegen_opts()) + # initialise code generator code_generator = code_generator_from_target_name(FrontendConfiguration.get_target_platform()) - codegen_and_builder_opts = code_generator.set_options(codegen_and_builder_opts) + unused_opts_codegen = code_generator.set_options(FrontendConfiguration.get_codegen_opts()) + + # initialise builder + _builder, unused_opts_builder = builder_from_target_name(FrontendConfiguration.get_target_platform(), + options=FrontendConfiguration.get_codegen_opts()) + + # check for unused codegen options + for opt_key in FrontendConfiguration.get_codegen_opts().keys(): + if opt_key in unused_opts_transformer.keys() and opt_key in unused_opts_codegen.keys() and opt_key in unused_opts_builder.keys(): + raise CodeGeneratorOptionsException("The code generator option \"" + opt_key + "\" does not exist.") + + models = get_parsed_models() - _builder, codegen_and_builder_opts = builder_from_target_name(FrontendConfiguration.get_target_platform(), options=codegen_and_builder_opts) + # validation -- check cocos for models that do not have errors already + excluded_models = [] + for model in models: + if Logger.has_errors(model.name): + code, message = Messages.get_model_contains_errors(model.get_name()) + Logger.log_message(node=model, code=code, message=message, + error_position=model.get_source_position(), + log_level=LoggingLevel.WARNING) + excluded_models.append(model) + else: + CoCosManager.check_cocos(model) - if len(codegen_and_builder_opts) > 0: - raise CodeGeneratorOptionsException("The code generator option(s) \"" + ", ".join(codegen_and_builder_opts.keys()) + "\" do not exist.") + # exclude models that have errors + models = list(set(models) - set(excluded_models)) - models, errors_occurred = get_parsed_models() + # transformation(s) + models = transform_models(transformers, models) - if not errors_occurred: - models = transform_models(transformers, models) - generate_code(code_generator, models) + # generate code + generate_code(code_generator, models) - # perform build - if _builder is not None: - _builder.build() + # perform build + if _builder is not None: + _builder.build() if FrontendConfiguration.store_log: store_log_to_file() - return errors_occurred + # return a boolean indicating whether errors occurred + return len(Logger.get_all_messages_of_level(LoggingLevel.ERROR)) > 0 def init_predefined(): diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py index 834e56897..c4b7374bf 100644 --- a/pynestml/meta_model/ast_model.py +++ b/pynestml/meta_model/ast_model.py @@ -459,23 +459,27 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) - Adds the handed over declaration the internals block :param declaration: a single declaration """ - assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now" from pynestml.utils.ast_utils import ASTUtils + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + + assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now" + if not self.get_internals_blocks(): ASTUtils.create_internal_block(self) + n_declarations = len(self.get_internals_blocks()[0].get_declarations()) if n_declarations == 0: index = 0 else: index = 1 + (index % len(self.get_internals_blocks()[0].get_declarations())) + self.get_internals_blocks()[0].get_declarations().insert(index, declaration) declaration.update_scope(self.get_internals_blocks()[0].get_scope()) - from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.INTERNALS) - declaration.accept(symtable_vistor) - self.get_internals_blocks()[0].accept(ASTParentVisitor()) + self.accept(ASTParentVisitor()) + self.accept(symtable_vistor) symtable_vistor.block_type_stack.pop() def add_to_state_block(self, declaration: ASTDeclaration) -> None: @@ -483,24 +487,26 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None: Adds the handed over declaration to an arbitrary state block. A state block will be created if none exists. :param declaration: a single declaration. """ - assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now" + from pynestml.symbols.symbol import SymbolKind from pynestml.utils.ast_utils import ASTUtils + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + + assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now" + if not self.get_state_blocks(): ASTUtils.create_state_block(self) + self.get_state_blocks()[0].get_declarations().append(declaration) declaration.update_scope(self.get_state_blocks()[0].get_scope()) - from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.STATE) - declaration.accept(symtable_vistor) - self.get_state_blocks()[0].accept(ASTParentVisitor()) + self.accept(ASTParentVisitor()) + self.accept(symtable_vistor) symtable_vistor.block_type_stack.pop() - from pynestml.symbols.symbol import SymbolKind - assert declaration.get_variables()[0].get_scope().resolve_to_symbol( - declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None - assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), - SymbolKind.VARIABLE) is not None + + assert declaration.get_variables()[0].get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None + assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None def print_comment(self, prefix: str = "") -> str: """ @@ -566,7 +572,6 @@ def get_spike_input_port_names(self) -> List[str]: """ Returns a list of all spike input ports defined in the model. """ - print("get_spike_input_port_names = " + str([port.get_symbol_name() for port in self.get_spike_input_ports()])) return [port.get_symbol_name() for port in self.get_spike_input_ports()] def get_continuous_input_ports(self) -> List[VariableSymbol]: diff --git a/pynestml/symbols/symbol.py b/pynestml/symbols/symbol.py index 1e294566b..c73435c6d 100644 --- a/pynestml/symbols/symbol.py +++ b/pynestml/symbols/symbol.py @@ -18,8 +18,8 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from abc import ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod from enum import Enum diff --git a/pynestml/symbols/type_symbol.py b/pynestml/symbols/type_symbol.py index 7047cdbca..a3eb28a12 100644 --- a/pynestml/symbols/type_symbol.py +++ b/pynestml/symbols/type_symbol.py @@ -18,11 +18,11 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + from abc import ABCMeta, abstractmethod from pynestml.symbols.symbol import Symbol from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.messages import Messages class TypeSymbol(Symbol): @@ -198,6 +198,7 @@ def is_castable_to(self, _other_type): def binary_operation_not_defined_error(self, _operator, _other): from pynestml.symbols.error_type_symbol import ErrorTypeSymbol + from pynestml.utils.messages import Messages result = ErrorTypeSymbol() code, message = Messages.get_binary_operation_not_defined( lhs=self.print_nestml_type(), operator=_operator, rhs=_other.print_nestml_type()) @@ -208,6 +209,7 @@ def binary_operation_not_defined_error(self, _operator, _other): def unary_operation_not_defined_error(self, _operator): from pynestml.symbols.error_type_symbol import ErrorTypeSymbol result = ErrorTypeSymbol() + from pynestml.utils.messages import Messages code, message = Messages.get_unary_operation_not_defined(_operator, self.print_symbol()) Logger.log_message(code=code, message=message, error_position=self.referenced_object.get_source_position(), @@ -226,6 +228,7 @@ def inverse_of_unit(cls, other): return result def warn_implicit_cast_from_to(self, _from, _to): + from pynestml.utils.messages import Messages code, message = Messages.get_implicit_cast_rhs_to_lhs(_to.print_symbol(), _from.print_symbol()) Logger.log_message(code=code, message=message, error_position=self.get_referenced_object().get_source_position(), diff --git a/pynestml/symbols/unit_type_symbol.py b/pynestml/symbols/unit_type_symbol.py index 37c43b035..1f9977de0 100644 --- a/pynestml/symbols/unit_type_symbol.py +++ b/pynestml/symbols/unit_type_symbol.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import Optional from pynestml.symbols.type_symbol import TypeSymbol from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -131,12 +132,12 @@ def __sub__(self, other): def add_or_sub_another_unit(self, other): if self.equals(other): return other - else: - return self.attempt_magnitude_cast(other) + + return self.attempt_magnitude_cast(other) def attempt_magnitude_cast(self, other): if self.differs_only_in_magnitude(other): - factor = UnitTypeSymbol.get_conversion_factor(self.astropy_unit, other.astropy_unit) + factor = UnitTypeSymbol.get_conversion_factor(other.astropy_unit, self.astropy_unit) other.referenced_object.set_implicit_conversion_factor(factor) code, message = Messages.get_implicit_magnitude_conversion(self, other, factor) Logger.log_message(code=code, message=message, @@ -144,18 +145,20 @@ def attempt_magnitude_cast(self, other): log_level=LoggingLevel.INFO) return self - else: - return self.binary_operation_not_defined_error('+/-', other) - # TODO: change order of parameters to conform with the from_to scheme. - # TODO: Also rename to reflect that, i.e. get_conversion_factor_from_to + return self.binary_operation_not_defined_error('+/-', other) + @classmethod - def get_conversion_factor(cls, to, _from): + def get_conversion_factor(cls, _from, to) -> Optional[float]: """ - Calculates the conversion factor from _convertee_unit to target_unit. - Behaviour is only well-defined if both units have the same physical base type + Calculates the conversion factor from _convertee_unit to target_unit. Behaviour is only well-defined if both units have the same physical base type. """ - factor = (_from / to).si.scale + try: + factor = (_from / to).si.scale + except BaseException: + # this can fail in case of e.g. trying to convert from "1/s" to "2/s" + return None + return factor def is_castable_to(self, _other_type): diff --git a/pynestml/transformers/assign_implicit_conversion_factors_transformer.py b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py new file mode 100644 index 000000000..f44ee12d5 --- /dev/null +++ b/pynestml/transformers/assign_implicit_conversion_factors_transformer.py @@ -0,0 +1,335 @@ +# -*- coding: utf-8 -*- +# +# assign_implicit_conversion_factors_transformer.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from typing import Sequence, Union + +from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_small_stmt import ASTSmallStmt +from pynestml.meta_model.ast_stmt import ASTStmt +from pynestml.symbols.error_type_symbol import ErrorTypeSymbol +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.symbol import SymbolKind +from pynestml.symbols.template_type_symbol import TemplateTypeSymbol +from pynestml.symbols.variadic_type_symbol import VariadicTypeSymbol +from pynestml.transformers.transformer import Transformer +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import LoggingLevel, Logger +from pynestml.utils.logging_helper import LoggingHelper +from pynestml.utils.messages import Messages +from pynestml.utils.type_caster import TypeCaster +from pynestml.visitors.ast_visitor import ASTVisitor + + +class AssignImplicitConversionFactorsTransformer(Transformer): + r""" + Assign implicit conversion factors in expressions. + """ + + def transform(self, models: Union[ASTNode, Sequence[ASTNode]]) -> Union[ASTNode, Sequence[ASTNode]]: + single = False + if isinstance(models, ASTNode): + single = True + models = [models] + + for model in models: + model.accept(AssignImplicitConversionFactorVisitor()) + self.__assign_return_types(model) + + if single: + return models[0] + return models + + def __assign_return_types(self, _node): + for userDefinedFunction in _node.get_functions(): + symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(), + SymbolKind.FUNCTION) + # first ensure that the block contains at least one statement + if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0: + # now check that the last statement is a return + self.__check_return_recursively(userDefinedFunction, + symbol.get_return_type(), + userDefinedFunction.get_block().get_stmts(), + False) + # now if it does not have a statement, but uses a return type, it is an error + elif symbol is not None and userDefinedFunction.has_return_type() and \ + not symbol.get_return_type().equals(PredefinedTypes.get_void_type()): + code, message = Messages.get_no_return() + Logger.log_message(node=_node, code=code, message=message, + error_position=userDefinedFunction.get_source_position(), + log_level=LoggingLevel.ERROR) + + def __check_return_recursively(self, processed_function, type_symbol=None, stmts=None, ret_defined: bool = False) -> None: + """ + For a handed over statement, it checks if the statement is a return statement and if it is typed according to the handed over type symbol. + :param type_symbol: a single type symbol + :type type_symbol: type_symbol + :param stmts: a list of statements, either simple or compound + :type stmts: list(ASTSmallStmt,ASTCompoundStmt) + :param ret_defined: indicates whether a ret has already been defined after this block of stmt, thus is not + necessary. Implies that the return has been defined in the higher level block + """ + # in order to ensure that in the sub-blocks, a return is not necessary, we check if the last one in this + # block is a return statement, thus it is not required to have a return in the sub-blocks, but optional + last_statement = stmts[len(stmts) - 1] + ret_defined = False or ret_defined + if (len(stmts) > 0 and isinstance(last_statement, ASTStmt) + and last_statement.is_small_stmt() + and last_statement.small_stmt.is_return_stmt()): + ret_defined = True + + # now check that returns are there if necessary and correctly typed + for c_stmt in stmts: + if c_stmt.is_small_stmt(): + stmt = c_stmt.small_stmt + else: + stmt = c_stmt.compound_stmt + + # if it is a small statement, check if it is a return statement + if isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt(): + # first check if the return is the last one in this block of statements + if stmts.index(c_stmt) != (len(stmts) - 1): + code, message = Messages.get_not_last_statement('Return') + Logger.log_message(error_position=stmt.get_source_position(), + code=code, message=message, + log_level=LoggingLevel.WARNING) + + # now check that it corresponds to the declared type + if stmt.get_return_stmt().has_expression() and type_symbol is PredefinedTypes.get_void_type(): + code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(), + stmt.get_return_stmt().get_expression().type) + Logger.log_message(error_position=stmt.get_source_position(), + message=message, code=code, log_level=LoggingLevel.ERROR) + + # if it is not void check if the type corresponds to the one stated + if not stmt.get_return_stmt().has_expression() and \ + not type_symbol.equals(PredefinedTypes.get_void_type()): + code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(), + type_symbol) + Logger.log_message(error_position=stmt.get_source_position(), + message=message, code=code, log_level=LoggingLevel.ERROR) + + if stmt.get_return_stmt().has_expression(): + type_of_return = stmt.get_return_stmt().get_expression().type + if isinstance(type_of_return, ErrorTypeSymbol): + code, message = Messages.get_type_could_not_be_derived(processed_function.get_name()) + Logger.log_message(error_position=stmt.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) + elif not type_of_return.equals(type_symbol): + TypeCaster.try_to_recover_or_error(type_symbol, type_of_return, + stmt.get_return_stmt().get_expression()) + elif isinstance(stmt, ASTCompoundStmt): + # otherwise it is a compound stmt, thus check recursively + if stmt.is_if_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, + stmt.get_if_stmt().get_if_clause().get_block().get_stmts(), + ret_defined) + for else_ifs in stmt.get_if_stmt().get_elif_clauses(): + self.__check_return_recursively(processed_function, + type_symbol, else_ifs.get_block().get_stmts(), ret_defined) + if stmt.get_if_stmt().has_else_clause(): + self.__check_return_recursively(processed_function, + type_symbol, + stmt.get_if_stmt().get_else_clause().get_block().get_stmts(), + ret_defined) + elif stmt.is_while_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, stmt.get_while_stmt().get_block().get_stmts(), + ret_defined) + elif stmt.is_for_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, stmt.get_for_stmt().get_block().get_stmts(), + ret_defined) + # now, if a return statement has not been defined in the corresponding higher level block, we have to ensure that it is defined here + elif not ret_defined and stmts.index(c_stmt) == (len(stmts) - 1): + if not (isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt()): + code, message = Messages.get_no_return() + Logger.log_message(error_position=stmt.get_source_position(), log_level=LoggingLevel.ERROR, + code=code, message=message) + + +class AssignImplicitConversionFactorVisitor(ASTVisitor): + """ + This visitor checks that all expression correspond to the expected type. + """ + + def visit_declaration(self, node): + """ + Visits a single declaration and asserts that type of lhs is equal to type of rhs. + :param node: a single declaration. + :type node: ASTDeclaration + """ + assert isinstance(node, ASTDeclaration) + if node.has_expression(): + if node.get_expression().get_source_position().equals(ASTSourceLocation.get_added_source_position()): + # no type checks are executed for added nodes, since we assume correctness + return + lhs_type = node.get_data_type().get_type_symbol() + rhs_type = node.get_expression().type + if isinstance(rhs_type, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + if self.__types_do_not_match(lhs_type, rhs_type): + TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression()) + + def visit_inline_expression(self, node): + """ + Visits a single inline expression and asserts that type of lhs is equal to type of rhs. + """ + assert isinstance(node, ASTInlineExpression) + lhs_type = node.get_data_type().get_type_symbol() + rhs_type = node.get_expression().type + if isinstance(rhs_type, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + if self.__types_do_not_match(lhs_type, rhs_type): + TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression()) + + def visit_assignment(self, node): + """ + Visits a single expression and assures that type(lhs) == type(rhs). + :param node: a single assignment. + :type node: ASTAssignment + """ + from pynestml.meta_model.ast_assignment import ASTAssignment + assert isinstance(node, ASTAssignment) + + if node.get_source_position().equals(ASTSourceLocation.get_added_source_position()): + # no type checks are executed for added nodes, since we assume correctness + return + if node.is_direct_assignment: # case a = b is simple + self.handle_simple_assignment(node) + else: + self.handle_compound_assignment(node) # e.g. a *= b + + def handle_compound_assignment(self, node): + rhs_expr = node.get_expression() + lhs_variable_symbol = node.get_variable().resolve_in_own_scope() + rhs_type_symbol = rhs_expr.type + + if lhs_variable_symbol is None: + code, message = Messages.get_equation_var_not_in_state_block(node.get_variable().get_complete_name()) + Logger.log_message(code=code, message=message, error_position=node.get_source_position(), + log_level=LoggingLevel.ERROR) + return + + if isinstance(rhs_type_symbol, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + lhs_type_symbol = lhs_variable_symbol.get_type_symbol() + + if node.is_compound_product: + if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol, + node.get_expression()) + return + return + + if node.is_compound_quotient: + if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol, + node.get_expression()) + return + return + + assert node.is_compound_sum or node.is_compound_minus + if self.__types_do_not_match(lhs_type_symbol, rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, rhs_type_symbol, + node.get_expression()) + + @staticmethod + def __types_do_not_match(lhs_type_symbol, rhs_type_symbol): + if lhs_type_symbol is None: + return True + + return not lhs_type_symbol.equals(rhs_type_symbol) + + def handle_simple_assignment(self, node): + from pynestml.symbols.symbol import SymbolKind + lhs_variable_symbol = node.get_scope().resolve_to_symbol(node.get_variable().get_complete_name(), + SymbolKind.VARIABLE) + + rhs_type_symbol = node.get_expression().type + if isinstance(rhs_type_symbol, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + if lhs_variable_symbol is not None and self.__types_do_not_match(lhs_variable_symbol.get_type_symbol(), + rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_variable_symbol.get_type_symbol(), rhs_type_symbol, + node.get_expression()) + + def visit_function_call(self, node): + """ + Check consistency for a single function call: check if the called function has been declared, whether the number and types of arguments correspond to the declaration, etc. + + :param node: a single function call. + :type node: ASTFunctionCall + """ + func_name = node.get_name() + + if func_name == 'convolve': + return + + symbol = node.get_scope().resolve_to_symbol(node.get_name(), SymbolKind.FUNCTION) + + if symbol is None and ASTUtils.is_function_delay_variable(node): + return + + # first check if the function has been declared + if symbol is None: + code, message = Messages.get_function_not_declared(node.get_name()) + Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR, + code=code, message=message) + return + + # check if the number of arguments is the same as in the symbol; accept anything for variadic types + is_variadic: bool = len(symbol.get_parameter_types()) == 1 and isinstance(symbol.get_parameter_types()[0], VariadicTypeSymbol) + if (not is_variadic) and len(node.get_args()) != len(symbol.get_parameter_types()): + code, message = Messages.get_wrong_number_of_args(str(node), len(symbol.get_parameter_types()), + len(node.get_args())) + Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR, + error_position=node.get_source_position()) + return + + # finally check if the call is correctly typed + expected_types = symbol.get_parameter_types() + actual_args = node.get_args() + actual_types = [arg.type for arg in actual_args] + for actual_arg, actual_type, expected_type in zip(actual_args, actual_types, expected_types): + if isinstance(actual_type, ErrorTypeSymbol): + code, message = Messages.get_type_could_not_be_derived(actual_arg) + Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR, + error_position=actual_arg.get_source_position()) + return + + if isinstance(expected_type, VariadicTypeSymbol): + # variadic type symbol accepts anything + return + + if not actual_type.equals(expected_type) and not isinstance(expected_type, TemplateTypeSymbol): + TypeCaster.try_to_recover_or_error(expected_type, actual_type, actual_arg) diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index 68cc70a62..a06260824 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -23,6 +23,7 @@ from typing import Any, Sequence, Mapping, Optional, Union +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_equations_block import ASTEquationsBlock @@ -563,11 +564,6 @@ def mark_post_port(_expr=None): # replace occurrences of the variables in expressions in the original synapse with calls to the corresponding neuron getters # - # make sure the moved symbols can be resolved in the scope of the neuron (that's where ``ASTExternalVariable._altscope`` will be pointing to) - ast_symbol_table_visitor = ASTSymbolTableVisitor() - ast_symbol_table_visitor.after_ast_rewrite_ = True - new_neuron.accept(ast_symbol_table_visitor) - Logger.log_message( None, -1, "In synapse: replacing variables with suffixed external variable references", None, LoggingLevel.INFO) for state_var in syn_to_neuron_state_vars: @@ -609,7 +605,6 @@ def mark_post_port(_expr=None): new_neuron.accept(ASTParentVisitor()) new_synapse.accept(ASTParentVisitor()) ast_symbol_table_visitor = ASTSymbolTableVisitor() - ast_symbol_table_visitor.after_ast_rewrite_ = True new_neuron.accept(ast_symbol_table_visitor) new_synapse.accept(ast_symbol_table_visitor) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index d3d6f6ef5..a3983694d 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -28,7 +28,6 @@ from pynestml.codegeneration.printers.ast_printer import ASTPrinter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter -from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.meta_model.ast_assignment import ASTAssignment @@ -66,7 +65,6 @@ from pynestml.utils.messages import Messages from pynestml.utils.string_utils import removesuffix from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor -from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_visitor import ASTVisitor @@ -1766,10 +1764,12 @@ def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: @classmethod def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ - Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model - before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox. + Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox. """ from pynestml.utils.model_parser import ModelParser + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + assert len(model.get_equations_blocks()) == 1, "Only one equation block should be present" if not model.get_state_blocks(): @@ -1782,10 +1782,6 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict if cls.is_ode_variable(var.get_name(), model): assert cls.variable_in_solver(cls.to_ode_toolbox_processed_name(var_name), solver_dicts) - # replace the left-hand side variable name by the ode-toolbox format - var.set_name(cls.to_ode_toolbox_processed_name(var.get_complete_name())) - var.set_differential_order(0) - # replace the defining expression by the ode-toolbox result iv_expr = cls.get_initial_value_from_ode_toolbox_result( cls.to_ode_toolbox_processed_name(var_name), solver_dicts) @@ -1794,6 +1790,9 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict iv_expr.update_scope(state_block.get_scope()) iv_decl.set_expression(iv_expr) + model.accept(ASTParentVisitor()) + model.accept(ASTSymbolTableVisitor()) + @classmethod def integrate_odes_args_strs_from_function_call(cls, function_call: ASTFunctionCall): arg_names = [] @@ -2296,6 +2295,7 @@ def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block: r""" Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`. """ + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor def replace_function_call_through_var(_expr=None): if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve": @@ -2326,6 +2326,7 @@ def func(x): return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True equations_block.accept(ASTHigherOrderVisitor(func)) + equations_block.accept(ASTSymbolTableVisitor()) @classmethod def update_blocktype_for_common_parameters(cls, node): diff --git a/pynestml/utils/logger.py b/pynestml/utils/logger.py index 06e95b804..8404f1245 100644 --- a/pynestml/utils/logger.py +++ b/pynestml/utils/logger.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import List, Mapping, Optional, Tuple +from typing import List, Mapping, Optional, Tuple, Union from collections import OrderedDict from enum import Enum @@ -75,6 +75,7 @@ class Logger: def init_logger(cls, logging_level: LoggingLevel): """ Initializes the logger. + :param logging_level: the logging level as required :type logging_level: LoggingLevel """ @@ -82,7 +83,6 @@ def init_logger(cls, logging_level: LoggingLevel): cls.curr_message = 0 cls.log = {} cls.log_frozen = False - return @classmethod def freeze_log(cls, do_freeze: bool = True): @@ -95,6 +95,7 @@ def freeze_log(cls, do_freeze: bool = True): def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]: """ Returns the overall log of messages. The structure of the log is: (NODE, LEVEL, MESSAGE) + :return: mapping from id to ASTNode, log level and message. """ return cls.log @@ -103,6 +104,7 @@ def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]: def set_log(cls, log, counter): """ Restores log from the 'log' variable + :param log: the log :param counter: the counter """ @@ -113,20 +115,19 @@ def set_log(cls, log, counter): def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: str = None, error_position: ASTSourceLocation = None, log_level: LoggingLevel = None): """ Logs the handed over message on the handed over node. If the current logging is appropriate, the message is also printed. + :param node: the node in which the error occurred :param code: a single error code - :type code: ErrorCode :param error_position: the position on which the error occurred. - :type error_position: SourcePosition :param message: a message. - :type message: str :param log_level: the corresponding log level. - :type log_level: LoggingLevel """ if cls.log_frozen: return + if cls.curr_message is None: cls.init_logger(LoggingLevel.INFO) + from pynestml.meta_model.ast_node import ASTNode from pynestml.utils.ast_source_location import ASTSourceLocation assert (node is None or isinstance(node, ASTNode)), \ @@ -134,15 +135,23 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st assert (error_position is None or isinstance(error_position, ASTSourceLocation)), \ '(PyNestML.Logger) Wrong type of error position provided (%s)!' % type(error_position) from pynestml.meta_model.ast_model import ASTModel + if isinstance(node, ASTModel): cls.log[cls.curr_message] = ( node.get_artifact_name(), node, log_level, code, error_position, message) - elif cls.current_node is not None: - cls.log[cls.curr_message] = (cls.current_node.get_artifact_name(), cls.current_node, + else: + if cls.current_node is not None: + artifact_name = cls.current_node.get_artifact_name() + else: + artifact_name = "" + + cls.log[cls.curr_message] = (artifact_name, cls.current_node, log_level, code, error_position, message) + cls.curr_message += 1 if cls.no_print: return + if cls.logging_level.value <= log_level.value: if isinstance(node, ASTInlineExpression): node_name = node.variable_name @@ -163,10 +172,9 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st def string_to_level(cls, string: str) -> LoggingLevel: """ Returns the logging level corresponding to the handed over string. If no such exits, returns None. + :param string: a single string representing the level. - :type string: str :return: a single logging level. - :rtype: LoggingLevel """ if string == 'DEBUG': return LoggingLevel.DEBUG @@ -183,7 +191,7 @@ def string_to_level(cls, string: str) -> LoggingLevel: if string == 'NO' or string == 'NONE': return LoggingLevel.NO - raise Exception('Tried to convert unknown string \"' + string + '\" to logging level') + raise Exception("Tried to convert unknown string '" + string + "' to logging level") @classmethod def level_to_string(cls, level: LoggingLevel) -> str: @@ -207,7 +215,7 @@ def level_to_string(cls, level: LoggingLevel) -> str: if level == LoggingLevel.NO: return 'NO' - raise Exception('Tried to convert unknown logging level \"' + str(level) + '\" to string') + raise Exception("Tried to convert unknown logging level '" + str(level) + "' to string") @classmethod def set_logging_level(cls, level: LoggingLevel) -> None: @@ -218,79 +226,89 @@ def set_logging_level(cls, level: LoggingLevel) -> None: """ if cls.log_frozen: return + cls.logging_level = level @classmethod def set_current_node(cls, node: Optional[ASTNode]) -> None: """ - Sets the handed over node as the currently processed one. This enables a retrieval of messages for a - specific node. - :param node: a single node instance + Sets the handed over node as the currently processed one. This enables a retrieval of messages for a specific node. + + :param node: a single node instance """ cls.current_node = node @classmethod - def get_all_messages_of_level_and_or_node(cls, node: ASTNode, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: + def get_all_messages_of_level_and_or_node(cls, node: Union[ASTNode, str], level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ - Returns all messages which have a certain logging level, or have been reported for a certain node, or - both. + Returns all messages which have a certain logging level, or have been reported for a certain node, or both. + :param node: a single node instance :param level: a logging level - :type level: LoggingLevel :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if level is None and node is None: return cls.get_log() + + if isinstance(node, str): + # search by artifact name + node_artifact_name = node + node = None + else: + # search by artifact class object + node_artifact_name = None + ret = list() for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values(): - if (level == logLevel if level is not None else True) and ( - node if node is not None else True) and ( - node.get_artifact_name() == artifactName if node is not None else True): + if (level == logLevel if level is not None else True) and (node if node is not None else True) and (node_artifact_name == artifactName if node is not None else True): ret.append((node, logLevel, message)) + return ret @classmethod def get_all_messages_of_level(cls, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ Returns all messages which have a certain logging level. + :param level: a logging level - :type level: LoggingLevel :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if level is None: return cls.get_log() + ret = list() for (artifactName, node, logLevel, code, errorPosition, message) in cls.log.values(): if level == logLevel: ret.append((node, logLevel, message)) + return ret @classmethod def get_all_messages_of_node(cls, node: ASTNode) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ Returns all messages which have been reported for a certain node. + :param node: a single node instance :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if node is None: return cls.get_log() + ret = list() for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values(): if (node_i == node if node is not None else True) and \ (node.get_artifact_name() == artifactName if node is not None else True): ret.append((node, logLevel, message)) + return ret @classmethod def has_errors(cls, node: ASTNode) -> bool: """ Indicates whether the handed over node, thus the corresponding model, has errors. + :param node: a single node instance. :return: True if errors detected, otherwise False - :rtype: bool """ return len(cls.get_all_messages_of_level_and_or_node(node, LoggingLevel.ERROR)) > 0 @@ -311,6 +329,7 @@ def get_json_format(cls) -> str: (node.get_name() if node is not None else 'GLOBAL') + '", ' + \ '"severity":"' \ + str(logLevel.name) + '", ' + if code is not None: ret += '"code":"' + \ code.name + \ @@ -323,10 +342,12 @@ def get_json_format(cls) -> str: '", ' + \ '"message":"' + str(message).replace('"', "'") + '"}' ret += ',' + if len(cls.log.keys()) == 0: parsed = json.loads('[]', object_pairs_hook=OrderedDict) else: ret = ret[:-1] # delete the last "," ret += ']' parsed = json.loads(ret, object_pairs_hook=OrderedDict) + return json.dumps(parsed, indent=2, sort_keys=False) diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py index 456ece178..ea645a02c 100644 --- a/pynestml/utils/mechs_info_enricher.py +++ b/pynestml/utils/mechs_info_enricher.py @@ -22,13 +22,14 @@ from collections import defaultdict from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from pynestml.utils.ast_utils import ASTUtils -from pynestml.visitors.ast_visitor import ASTVisitor from pynestml.utils.model_parser import ModelParser -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.symbol import SymbolKind +from pynestml.visitors.ast_visitor import ASTVisitor class MechsInfoEnricher: @@ -57,33 +58,6 @@ def transform_ode_solutions(cls, neuron, mechs_info): solution_transformed["states"] = defaultdict() solution_transformed["propagators"] = defaultdict() - for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): - variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, - SymbolKind.VARIABLE) - - expression = ModelParser.parse_expression(rhs_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as synapses have been - # defined to get here - expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) - expression.accept(ASTSymbolTableVisitor()) - - update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ - variable_name] - update_expr_ast = ModelParser.parse_expression( - update_expr_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as differential equations - # must have been defined to get here - update_expr_ast.update_scope( - neuron.get_equations_blocks()[0].get_scope()) - update_expr_ast.accept(ASTSymbolTableVisitor()) - - solution_transformed["states"][variable_name] = { - "ASTVariable": variable, - "init_expression": expression, - "update_expression": update_expr_ast, - } for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE) @@ -118,6 +92,36 @@ def transform_ode_solutions(cls, neuron, mechs_info): PredefinedFunctions.TIME_RESOLUTION: mechanism_info["time_resolution_var"] = variable + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_scope()) + update_expr_ast.accept(ASTParentVisitor()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + neuron.accept(ASTSymbolTableVisitor()) + + solution_transformed["states"][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) neuron.accept(ASTParentVisitor()) diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index 69b32a8f4..90efd06b7 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -18,11 +18,15 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from enum import Enum + +from __future__ import annotations + from typing import Tuple -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from collections.abc import Iterable +from enum import Enum + +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_function import ASTFunction @@ -132,6 +136,7 @@ class MessageCode(Enum): VOID_FUNCTION_ON_RHS = 110 NON_CONSTANT_EXPONENT = 111 EXPONENT_MUST_BE_INTEGER = 112 + RANDOM_FUNCTIONS_LEGALLY_USED = 113 class Messages: @@ -158,8 +163,8 @@ def get_input_path_not_found(cls, path): return MessageCode.INPUT_PATH_NOT_FOUND, message @classmethod - def get_unknown_target(cls, target): - message = 'Unknown target ("%s")' % (target) + def get_unknown_target_platform(cls, target: str): + message = "Unknown target: '" + target + "'" return MessageCode.UNKNOWN_TARGET, message @classmethod @@ -313,22 +318,13 @@ def get_different_type_rhs_lhs( return MessageCode.CAST_NOT_POSSIBLE, message @classmethod - def get_type_different_from_expected(cls, expected_type, got_type): + def get_type_different_from_expected(cls, expected_type, got_type) -> Tuple[MessageCode, str]: """ Returns a message indicating that the received type is different from the expected one. :param expected_type: the expected type - :type expected_type: TypeSymbol :param got_type: the actual type - :type got_type: type_symbol :return: a message - :rtype: (MessageCode,str) """ - from pynestml.symbols.type_symbol import TypeSymbol - assert (expected_type is not None and isinstance(expected_type, TypeSymbol)), \ - '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type( - expected_type) - assert (got_type is not None and isinstance(got_type, TypeSymbol)), \ - '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(got_type) message = 'Actual type different from expected. Expected: \'%s\', got: \'%s\'!' % ( expected_type.print_symbol(), got_type.print_symbol()) return MessageCode.TYPE_DIFFERENT_FROM_EXPECTED, message @@ -430,11 +426,10 @@ def get_module_generated(cls, path: str) -> Tuple[MessageCode, str]: return MessageCode.MODULE_SUCCESSFULLY_GENERATED, message @classmethod - def get_variable_used_before_declaration(cls, variable_name): + def get_variable_used_before_declaration(cls, variable_name: str): """ Returns a message indicating that a variable is used before declaration. :param variable_name: a variable name - :type variable_name: str :return: a message :rtype: (MessageCode,str) """ @@ -701,7 +696,7 @@ def get_model_redeclared(cls, name: str) -> Tuple[MessageCode, str]: '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name) assert (name is not None and isinstance(name, str)), \ '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name) - message = 'model \'%s\' redeclared!' % name + message = 'Model \'%s\' redeclared!' % name return MessageCode.MODEL_REDECLARED, message @classmethod @@ -1375,3 +1370,8 @@ def get_non_constant_exponent(cls) -> Tuple[MessageCode, str]: message = "Cannot calculate value of exponent. Must be a constant value!" return MessageCode.NON_CONSTANT_EXPONENT, message + + @classmethod + def get_random_functions_legally_used(cls, name): + message = "The function '" + name + "' can only be used in the update, onReceive, or onCondition blocks." + return MessageCode.RANDOM_FUNCTIONS_LEGALLY_USED, message diff --git a/pynestml/utils/model_parser.py b/pynestml/utils/model_parser.py index 7fabf361e..62a8669bb 100644 --- a/pynestml/utils/model_parser.py +++ b/pynestml/utils/model_parser.py @@ -24,6 +24,7 @@ from antlr4 import CommonTokenStream, FileStream, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy, DefaultErrorStrategy from antlr4.error.ErrorListener import ConsoleErrorListener +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.generated.PyNestMLParser import PyNestMLParser @@ -65,6 +66,7 @@ from pynestml.meta_model.ast_variable import ASTVariable from pynestml.meta_model.ast_while_stmt import ASTWhileStmt from pynestml.symbol_table.symbol_table import SymbolTable +from pynestml.transformers.assign_implicit_conversion_factors_transformer import AssignImplicitConversionFactorsTransformer from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.error_listener import NestMLErrorListener from pynestml.utils.logger import Logger, LoggingLevel @@ -142,10 +144,14 @@ def parse_file(cls, file_path=None): for model in ast.get_model_list(): model.accept(ASTSymbolTableVisitor()) SymbolTable.add_model_scope(model.get_name(), model.get_scope()) + Logger.set_current_node(model) + AssignImplicitConversionFactorsTransformer().transform(model) + Logger.set_current_node(None) # store source paths for model in ast.get_model_list(): model.file_path = file_path + ast.file_path = file_path return ast diff --git a/pynestml/utils/type_caster.py b/pynestml/utils/type_caster.py index 34e4e6ccc..4ce2624dd 100644 --- a/pynestml/utils/type_caster.py +++ b/pynestml/utils/type_caster.py @@ -28,12 +28,11 @@ class TypeCaster: @staticmethod def do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression): """ - determine conversion factor from rhs to lhs, register it with the relevant expression + Determine conversion factor from rhs to lhs, register it with the relevant expression """ _containing_expression.set_implicit_conversion_factor( - UnitTypeSymbol.get_conversion_factor(_lhs_type_symbol.astropy_unit, - _rhs_type_symbol.astropy_unit)) - _containing_expression.type = _lhs_type_symbol + UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, + _lhs_type_symbol.astropy_unit)) code, message = Messages.get_implicit_magnitude_conversion(_lhs_type_symbol, _rhs_type_symbol, _containing_expression.get_implicit_conversion_factor()) Logger.log_message(code=code, message=message, @@ -45,18 +44,26 @@ def try_to_recover_or_error(_lhs_type_symbol, _rhs_type_symbol, _containing_expr if _rhs_type_symbol.is_castable_to(_lhs_type_symbol): if isinstance(_lhs_type_symbol, UnitTypeSymbol) \ and isinstance(_rhs_type_symbol, UnitTypeSymbol): - conversion_factor = UnitTypeSymbol.get_conversion_factor( - _lhs_type_symbol.astropy_unit, _rhs_type_symbol.astropy_unit) + conversion_factor = UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, _lhs_type_symbol.astropy_unit) + + if conversion_factor is None: + # error during conversion + code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) + Logger.log_message(error_position=_containing_expression.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) + return + if not conversion_factor == 1.: # the units are mutually convertible, but require a factor unequal to 1 (e.g. mV and A*Ohm) - TypeCaster.do_magnitude_conversion_rhs_to_lhs( - _rhs_type_symbol, _lhs_type_symbol, _containing_expression) + TypeCaster.do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression) + # the units are mutually convertible (e.g. V and A*Ohm) code, message = Messages.get_implicit_cast_rhs_to_lhs(_rhs_type_symbol.print_symbol(), _lhs_type_symbol.print_symbol()) Logger.log_message(error_position=_containing_expression.get_source_position(), code=code, message=message, log_level=LoggingLevel.INFO) - else: - code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) - Logger.log_message(error_position=_containing_expression.get_source_position(), - code=code, message=message, log_level=LoggingLevel.ERROR) + return + + code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) + Logger.log_message(error_position=_containing_expression.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) diff --git a/pynestml/visitors/ast_builder_visitor.py b/pynestml/visitors/ast_builder_visitor.py index 0e766d530..bfc4dd902 100644 --- a/pynestml/visitors/ast_builder_visitor.py +++ b/pynestml/visitors/ast_builder_visitor.py @@ -52,16 +52,17 @@ def visitNestMLCompilationUnit(self, ctx): models = list() for child in ctx.model(): models.append(self.visit(child)) + # extract the name of the artifact from the context if hasattr(ctx.start.source[1], 'fileName'): artifact_name = ntpath.basename(ctx.start.source[1].fileName) else: artifact_name = 'parsed_from_string' + compilation_unit = ASTNodeFactory.create_ast_nestml_compilation_unit(list_of_models=models, source_position=create_source_pos(ctx), artifact_name=artifact_name) - # first ensure certain properties of the model - CoCosManager.check_model_names_unique(compilation_unit) + return compilation_unit # Visit a parse tree produced by PyNESTMLParser#datatype. @@ -387,15 +388,6 @@ def visitDeclaration(self, ctx): expression = self.visit(ctx.rhs) if ctx.rhs is not None else None invariant = self.visit(ctx.invariant) if ctx.invariant is not None else None - # print("Visiting variable \"" + str(str(ctx.NAME())) + "\"...") - # # check if this variable was decorated as homogeneous - # import pynestml.generated.PyNestMLLexer - # is_homogeneous = any([isinstance(ch, pynestml.generated.PyNestMLParser.PyNestMLParser.AnyDecoratorContext) \ - # and len(ch.getTokens(pynestml.generated.PyNestMLLexer.PyNestMLLexer.DECORATOR_HOMOGENEOUS)) > 0 \ - # for ch in ctx.parentCtx.children]) - # if is_homogeneous: - # print("\t----> is homogeneous") - declaration = ASTNodeFactory.create_ast_declaration(is_recordable=is_recordable, variables=variables, data_type=data_type, diff --git a/pynestml/visitors/ast_function_call_visitor.py b/pynestml/visitors/ast_function_call_visitor.py index 7d7bf75c4..e4ec8650e 100644 --- a/pynestml/visitors/ast_function_call_visitor.py +++ b/pynestml/visitors/ast_function_call_visitor.py @@ -94,7 +94,6 @@ def visit_simple_expression(self, node: ASTSimpleExpression) -> None: # return type of the convolve function is the type of the second parameter multiplied by the unit of time (s) if function_name == PredefinedFunctions.CONVOLVE: - # Deviations from the assumptions made here are handled in the convolveCoco buffer_parameter = node.get_function_call().get_args()[1] if buffer_parameter.get_variable() is not None: diff --git a/pynestml/visitors/ast_symbol_table_visitor.py b/pynestml/visitors/ast_symbol_table_visitor.py index 011182543..bc85d4cdd 100644 --- a/pynestml/visitors/ast_symbol_table_visitor.py +++ b/pynestml/visitors/ast_symbol_table_visitor.py @@ -19,7 +19,6 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_model_body import ASTModelBody from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator @@ -53,7 +52,6 @@ def __init__(self): self.symbol_stack = Stack() self.scope_stack = Stack() self.block_type_stack = Stack() - self.after_ast_rewrite_ = False def visit_model(self, node: ASTModel) -> None: """ @@ -79,10 +77,6 @@ def visit_model(self, node: ASTModel) -> None: node.get_scope().add_symbol(types[symbol]) def endvisit_model(self, node: ASTModel): - # before following checks occur, we need to ensure several simple properties - CoCosManager.post_symbol_table_builder_checks( - node, after_ast_rewrite=self.after_ast_rewrite_) - # update the equations for equation_block in node.get_equations_blocks(): ASTUtils.assign_ode_to_variables(equation_block) @@ -287,8 +281,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None: namespace_decorators = {} for d in node.get_decorators(): if isinstance(d, ASTNamespaceDecorator): - namespace_decorators[str(d.get_namespace())] = str( - d.get_name()) + namespace_decorators[str(d.get_namespace())] = str(d.get_name()) else: decorators.append(d) @@ -296,6 +289,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None: block_type = None if not self.block_type_stack.is_empty(): block_type = self.block_type_stack.top() + for var in node.get_variables(): # for all variables declared create a new symbol var.update_scope(node.get_scope()) @@ -324,11 +318,14 @@ def visit_declaration(self, node: ASTDeclaration) -> None: symbol.set_comment(node.get_comment()) node.get_scope().add_symbol(symbol) var.set_type_symbol(type_symbol) + # the data type node.get_data_type().update_scope(node.get_scope()) + # the rhs update if node.has_expression(): node.get_expression().update_scope(node.get_scope()) + # the invariant update if node.has_invariant(): node.get_invariant().update_scope(node.get_scope()) diff --git a/tests/cocos_test.py b/tests/cocos_test.py deleted file mode 100644 index f557faaf0..000000000 --- a/tests/cocos_test.py +++ /dev/null @@ -1,698 +0,0 @@ -# -*- coding: utf-8 -*- -# -# cocos_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -from __future__ import print_function - -import os -import unittest - -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.logger import LoggingLevel, Logger -from pynestml.utils.model_parser import ModelParser - - -class CoCosTest(unittest.TestCase): - - def setUp(self): - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table( - ASTSourceLocation( - start_line=0, - start_column=0, - end_line=0, - end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - - def test_invalid_element_defined_after_usage(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableDefinedAfterUsage.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_element_defined_after_usage(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableDefinedAfterUsage.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_element_in_same_line(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoElementInSameLine.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_element_in_same_line(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoElementInSameLine.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_integrate_odes_called_if_equations_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_integrate_odes_called_if_equations_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_element_not_defined_in_scope(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableNotDefined.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 5) - - def test_valid_element_not_defined_in_scope(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableNotDefined.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), - 0) - - def test_variable_with_same_name_as_unit(self): - Logger.set_logging_level(LoggingLevel.NO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableWithSameNameAsUnit.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), - 3) - - def test_invalid_variable_redeclaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableRedeclared.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_variable_redeclaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableRedeclared.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_each_block_unique(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoEachBlockUnique.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_each_block_unique(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoEachBlockUnique.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_function_unique_and_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoFunctionNotUnique.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5) - - def test_valid_function_unique_and_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoFunctionNotUnique.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expressions_have_rhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInlineExpressionHasNoRhs.nestml')) - assert model is None - - def test_valid_inline_expressions_have_rhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInlineExpressionHasNoRhs.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expression_has_several_lhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInlineExpressionWithSeveralLhs.nestml')) - assert model is None - - def test_valid_inline_expression_has_several_lhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInlineExpressionWithSeveralLhs.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_no_values_assigned_to_input_ports(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoValueAssignedToInputPort.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_no_values_assigned_to_input_ports(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoValueAssignedToInputPort.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_order_of_equations_correct(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoNoOrderOfEquations.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_order_of_equations_correct(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoNoOrderOfEquations.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_numerator_of_unit_one(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoUnitNumeratorNotOne.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 2) - - def test_valid_numerator_of_unit_one(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoUnitNumeratorNotOne.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_names_of_neurons_unique(self): - Logger.init_logger(LoggingLevel.INFO) - ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoMultipleNeuronsWithEqualName.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 1) - - def test_valid_names_of_neurons_unique(self): - Logger.init_logger(LoggingLevel.INFO) - ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoMultipleNeuronsWithEqualName.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 0) - - def test_invalid_no_nest_collision(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoNestNamespaceCollision.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_no_nest_collision(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoNestNamespaceCollision.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_redundant_input_port_keywords_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInputPortWithRedundantTypes.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_redundant_input_port_keywords_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInputPortWithRedundantTypes.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_parameters_assigned_only_in_parameters_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoParameterAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_parameters_assigned_only_in_parameters_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoParameterAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expressions_assigned_only_in_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoAssignmentToInlineExpression.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_invalid_internals_assigned_only_in_internals_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInternalAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_internals_assigned_only_in_internals_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInternalAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_function_with_wrong_arg_number_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_function_with_wrong_arg_number_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_init_values_have_rhs_and_ode(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInitValuesWithoutOde.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2) - - def test_valid_init_values_have_rhs_and_ode(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInitValuesWithoutOde.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2) - - def test_invalid_incorrect_return_stmt_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIncorrectReturnStatement.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4) - - def test_valid_incorrect_return_stmt_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIncorrectReturnStatement.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_ode_vars_outside_init_block_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOdeVarNotInInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_ode_vars_outside_init_block_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOdeVarNotInInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_convolve_correctly_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoConvolveNotCorrectlyProvided.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 3) - - def test_valid_convolve_correctly_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoConvolveNotCorrectlyProvided.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_in_non_vector_declaration_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorInNonVectorDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_in_non_vector_declaration_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorInNonVectorDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorParameterDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_parameter_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorParameterDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_type(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorParameterType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_parameter_type(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorParameterType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_size(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorDeclarationSize.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_vector_parameter_size(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorDeclarationSize.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_convolve_correctly_parameterized(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoConvolveNotCorrectlyParametrized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_convolve_correctly_parameterized(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoConvolveNotCorrectlyParametrized.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 0) - - def test_invalid_invariant_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInvariantNotBool.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_invariant_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInvariantNotBool.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIllegalExpression.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 6) - - def test_valid_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIllegalExpression.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_compound_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5) - - def test_valid_compound_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_ode_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOdeIncorrectlyTyped.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_valid_ode_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOdeCorrectlyTyped.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_output_block_defined_if_emit_call(self): - """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOutputPortDefinedIfEmitCall.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_invalid_output_port_defined_if_emit_call(self): - """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOutputPortDefinedIfEmitCall-2.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_valid_output_port_defined_if_emit_call(self): - """test that no error is raised when the output block is missing, but not emit_spike() functions are called""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOutputPortDefinedIfEmitCall.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_valid_coco_kernel_type(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoKernelType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_coco_kernel_type(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoKernelType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_invalid_coco_kernel_type_initial_values(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoKernelTypeInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4) - - def test_valid_coco_state_variables_initialized(self): - """ - Test that the CoCo condition is applicable for all the variables in the state block initialized with a value - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoStateVariablesInitialized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_coco_state_variables_initialized(self): - """ - Test that the CoCo condition is applicable for all the variables in the state block not initialized - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoStateVariablesInitialized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_invalid_co_co_priorities_correctly_specified(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoPrioritiesCorrectlySpecified.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_co_co_priorities_correctly_specified(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoPrioritiesCorrectlySpecified.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_co_co_resolution_legally_used(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoResolutionLegallyUsed.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_co_co_resolution_legally_used(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoResolutionLegallyUsed.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_valid_co_co_vector_input_port(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorInputPortSizeAndType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_co_co_vector_input_port(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorInputPortSizeAndType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) diff --git a/tests/function_parameter_templating_test.py b/tests/function_parameter_templating_test.py deleted file mode 100644 index e3cb89e41..000000000 --- a/tests/function_parameter_templating_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# -# function_parameter_templating_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -import os -import unittest - -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.model_parser import ModelParser - -# minor setup steps required -SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) -PredefinedUnits.register_units() -PredefinedTypes.register_types() -PredefinedVariables.register_variables() -PredefinedFunctions.register_functions() - - -class FunctionParameterTemplatingTest(unittest.TestCase): - """ - This test is used to test the correct derivation of types when functions use templated type parameters. - """ - - def test(self): - Logger.init_logger(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), - "resources", "FunctionParameterTemplatingTest.nestml")))) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 7) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/nest_compartmental_tests/test__cocos.py b/tests/nest_compartmental_tests/test__cocos.py index dc4daa28c..7ee55f8a1 100644 --- a/tests/nest_compartmental_tests/test__cocos.py +++ b/tests/nest_compartmental_tests/test__cocos.py @@ -21,41 +21,39 @@ from __future__ import print_function +from typing import Optional + import os import pytest -from pynestml.frontend.frontend_configuration import FrontendConfiguration - -from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.meta_model.ast_model import ASTModel from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.logger import LoggingLevel, Logger from pynestml.utils.model_parser import ModelParser -@pytest.fixture -def setUp(): - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table( - ASTSourceLocation( - start_line=0, - start_column=0, - end_line=0, - end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - FrontendConfiguration.target_platform = "NEST_COMPARTMENTAL" - - class TestCoCos: - def test_invalid_cm_variables_declared(self, setUp): - model = ModelParser.parse_file( + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + SymbolTable.initialize_symbol_table( + ASTSourceLocation( + start_line=0, + start_column=0, + end_line=0, + end_column=0)) + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + def test_invalid_cm_variables_declared(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -63,11 +61,10 @@ def test_invalid_cm_variables_declared(self, setUp): 'invalid')), 'CoCoCmVariablesDeclared.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 5 + model, LoggingLevel.ERROR)) == 6 - def test_valid_cm_variables_declared(self, setUp): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + def test_valid_cm_variables_declared(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -75,12 +72,12 @@ def test_valid_cm_variables_declared(self, setUp): 'valid')), 'CoCoCmVariablesDeclared.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_variable_has_rhs(self, setUp): - model = ModelParser.parse_file( + def test_invalid_cm_variable_has_rhs(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -88,11 +85,11 @@ def test_invalid_cm_variable_has_rhs(self, setUp): 'invalid')), 'CoCoCmVariableHasRhs.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 2 + model, LoggingLevel.ERROR)) == 2 - def test_valid_cm_variable_has_rhs(self, setUp): + def test_valid_cm_variable_has_rhs(self): Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -100,12 +97,12 @@ def test_valid_cm_variable_has_rhs(self, setUp): 'valid')), 'CoCoCmVariableHasRhs.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_v_comp_exists(self, setUp): - model = ModelParser.parse_file( + def test_invalid_cm_v_comp_exists(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -113,11 +110,11 @@ def test_invalid_cm_v_comp_exists(self, setUp): 'invalid')), 'CoCoCmVcompExists.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 4 + model, LoggingLevel.ERROR)) == 4 - def test_valid_cm_v_comp_exists(self, setUp): + def test_valid_cm_v_comp_exists(self): Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -125,4 +122,23 @@ def test_valid_cm_v_comp_exists(self, setUp): 'valid')), 'CoCoCmVcompExists.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 + + def _parse_and_validate_model(self, fname: str) -> Optional[str]: + from pynestml.frontend.pynestml_frontend import generate_target + + Logger.init_logger(LoggingLevel.DEBUG) + + try: + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + except BaseException: + return None + + ast_compilation_unit = ModelParser.parse_file(fname) + if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0: + return None + + model: ASTModel = ast_compilation_unit.get_model_list()[0] + model_name = model.get_name() + + return model_name diff --git a/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py b/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py new file mode 100644 index 000000000..70e94706b --- /dev/null +++ b/tests/nest_continuous_benchmarking/test_nest_continuous_benchmarking.py @@ -0,0 +1,311 @@ +# -*- coding: utf-8 -*- +# +# test_nest_continuous_benchmarking.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import numpy as np +import os +import pytest + +import nest + +from pynestml.frontend.pynestml_frontend import generate_nest_target + +try: + import matplotlib + matplotlib.use('Agg') + import matplotlib.ticker + import matplotlib.pyplot as plt + TEST_PLOTS = True +except Exception: + TEST_PLOTS = False + +sim_mdl = True +sim_ref = True + + +class TestNESTContinuousBenchmarking: + + neuron_model_name = "iaf_psc_exp_neuron_nestml__with_stdp_nn_symm_synapse_nestml" + ref_neuron_model_name = "iaf_psc_exp_neuron_nestml_non_jit" + + synapse_model_name = "stdp_nn_symm_synapse_nestml__with_iaf_psc_exp_neuron_nestml" + ref_synapse_model_name = "stdp_nn_symm_synapse" + + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + """Generate the neuron model code""" + + # generate the "jit" model (co-generated neuron and synapse), that does not rely on ArchivingNode + files = [os.path.join("models", "neurons", "iaf_psc_exp_neuron.nestml"), + os.path.join("models", "synapses", "stdp_nn_symm_synapse.nestml")] + input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( + os.pardir, os.pardir, s))) for s in files] + generate_nest_target(input_path=input_path, + target_path="/tmp/nestml-jit", + logging_level="INFO", + module_name="nestml_jit_module", + suffix="_nestml", + codegen_opts={"neuron_parent_class": "StructuralPlasticityNode", + "neuron_parent_class_include": "structural_plasticity_node.h", + "neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "synapse": "stdp_nn_symm_synapse", + "post_ports": ["post_spikes"]}], + "delay_variable": {"stdp_nn_symm_synapse": "d"}, + "weight_variable": {"stdp_nn_symm_synapse": "w"}}) + + # generate the "non-jit" model, that relies on ArchivingNode + generate_nest_target(input_path=os.path.realpath(os.path.join(os.path.dirname(__file__), + os.path.join(os.pardir, os.pardir, "models", "neurons", "iaf_psc_exp_neuron.nestml"))), + target_path="/tmp/nestml-non-jit", + logging_level="INFO", + module_name="nestml_non_jit_module", + suffix="_nestml_non_jit", + codegen_opts={"neuron_parent_class": "ArchivingNode", + "neuron_parent_class_include": "archiving_node.h"}) + + @pytest.mark.benchmark + def test_stdp_nn_synapse(self, benchmark): + return benchmark(self._test_stdp_nn_synapse) + + def _test_stdp_nn_synapse(self): + + fname_snip = "" + + pre_spike_times = [1., 11., 21.] # [ms] + post_spike_times = [6., 16., 26.] # [ms] + + post_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms] + pre_spike_times = np.sort(np.unique(1 + np.round(10 * np.sort(np.abs(np.random.randn(10)))))) # [ms] + + post_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms] + pre_spike_times = np.sort(np.unique(1 + np.round(100 * np.sort(np.abs(np.random.randn(100)))))) # [ms] + + pre_spike_times = np.array([2., 4., 7., 8., 12., 13., 19., 23., 24., 28., 29., 30., 33., 34., + 35., 36., 38., 40., 42., 46., 51., 53., 54., 55., 56., 59., 63., 64., + 65., 66., 68., 72., 73., 76., 79., 80., 83., 84., 86., 87., 90., 95., + 99., 100., 103., 104., 105., 111., 112., 126., 131., 133., 134., 139., 147., 150., + 152., 155., 172., 175., 176., 181., 196., 197., 199., 202., 213., 215., 217., 265.]) + post_spike_times = np.array([4., 5., 6., 7., 10., 11., 12., 16., 17., 18., 19., 20., 22., 23., + 25., 27., 29., 30., 31., 32., 34., 36., 37., 38., 39., 42., 44., 46., + 48., 49., 50., 54., 56., 57., 59., 60., 61., 62., 67., 74., 76., 79., + 80., 81., 83., 88., 93., 94., 97., 99., 100., 105., 111., 113., 114., 115., + 116., 119., 123., 130., 132., 134., 135., 145., 152., 155., 158., 166., 172., 174., + 188., 194., 202., 245., 249., 289., 454.]) + + self.run_synapse_test(neuron_model_name=self.neuron_model_name, + ref_neuron_model_name=self.ref_neuron_model_name, + synapse_model_name=self.synapse_model_name, + ref_synapse_model_name=self.ref_synapse_model_name, + resolution=1., # [ms] + delay=1., # [ms] + pre_spike_times=pre_spike_times, + post_spike_times=post_spike_times, + fname_snip=fname_snip) + + def run_synapse_test(self, neuron_model_name, + ref_neuron_model_name, + synapse_model_name, + ref_synapse_model_name, + resolution=1., # [ms] + delay=1., # [ms] + sim_time=None, # if None, computed from pre and post spike times + pre_spike_times=None, + post_spike_times=None, + fname_snip=""): + + if pre_spike_times is None: + pre_spike_times = [] + + if post_spike_times is None: + post_spike_times = [] + + if sim_time is None: + sim_time = max(np.amax(pre_spike_times), np.amax(post_spike_times)) + 5 * delay + + nest.ResetKernel() + nest.set_verbosity("M_ALL") + nest.SetKernelStatus({'resolution': resolution}) + + if sim_mdl: + try: + nest.Install("nestml_jit_module") + except Exception: + # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions + pass + + if sim_ref: + try: + nest.Install("nestml_non_jit_module") + except Exception: + # ResetKernel() does not unload modules for NEST Simulator < v3.7; ignore exception if module is already loaded on earlier versions + pass + + print("Pre spike times: " + str(pre_spike_times)) + print("Post spike times: " + str(post_spike_times)) + + wr = nest.Create('weight_recorder') + wr_ref = nest.Create('weight_recorder') + if sim_mdl: + nest.CopyModel(synapse_model_name, "stdp_nestml_rec", + {"weight_recorder": wr[0], "w": 1., "d": 1., "receptor_type": 0}) + if sim_ref: + nest.CopyModel(ref_synapse_model_name, "stdp_ref_rec", + {"weight_recorder": wr_ref[0], "weight": 1., "delay": 1., "receptor_type": 0}) + + # create spike_generators with these times + pre_sg = nest.Create("spike_generator", + params={"spike_times": pre_spike_times}) + post_sg = nest.Create("spike_generator", + params={"spike_times": post_spike_times, + 'allow_offgrid_times': True}) + + # create parrot neurons and connect spike_generators + if sim_mdl: + pre_neuron = nest.Create("parrot_neuron") + post_neuron = nest.Create(neuron_model_name) + + if sim_ref: + pre_neuron_ref = nest.Create("parrot_neuron") + post_neuron_ref = nest.Create(ref_neuron_model_name) + + if sim_mdl: + spikedet_pre = nest.Create("spike_recorder") + spikedet_post = nest.Create("spike_recorder") + mm = nest.Create("multimeter", params={"record_from": ["V_m", "post_trace__for_stdp_nn_symm_synapse_nestml"]}) + if sim_ref: + spikedet_pre_ref = nest.Create("spike_recorder") + spikedet_post_ref = nest.Create("spike_recorder") + mm_ref = nest.Create("multimeter", params={"record_from": ["V_m"]}) + + if sim_mdl: + nest.Connect(pre_sg, pre_neuron, "one_to_one", syn_spec={"delay": 1.}) + nest.Connect(post_sg, post_neuron, "one_to_one", syn_spec={"delay": 1., "weight": 9999.}) + nest.Connect(pre_neuron, post_neuron, "all_to_all", syn_spec={'synapse_model': 'stdp_nestml_rec'}) + nest.Connect(mm, post_neuron) + nest.Connect(pre_neuron, spikedet_pre) + nest.Connect(post_neuron, spikedet_post) + if sim_ref: + nest.Connect(pre_sg, pre_neuron_ref, "one_to_one", syn_spec={"delay": 1.}) + nest.Connect(post_sg, post_neuron_ref, "one_to_one", syn_spec={"delay": 1., "weight": 9999.}) + nest.Connect(pre_neuron_ref, post_neuron_ref, "all_to_all", + syn_spec={'synapse_model': ref_synapse_model_name}) + nest.Connect(mm_ref, post_neuron_ref) + nest.Connect(pre_neuron_ref, spikedet_pre_ref) + nest.Connect(post_neuron_ref, spikedet_post_ref) + + # get STDP synapse and weight before protocol + if sim_mdl: + syn = nest.GetConnections(source=pre_neuron, synapse_model="stdp_nestml_rec") + if sim_ref: + syn_ref = nest.GetConnections(source=pre_neuron_ref, synapse_model=ref_synapse_model_name) + + n_steps = int(np.ceil(sim_time / resolution)) + 1 + t = 0. + t_hist = [] + if sim_mdl: + w_hist = [] + if sim_ref: + w_hist_ref = [] + while t <= sim_time: + nest.Simulate(resolution) + t += resolution + t_hist.append(t) + if sim_ref: + w_hist_ref.append(nest.GetStatus(syn_ref)[0]['weight']) + if sim_mdl: + w_hist.append(nest.GetStatus(syn)[0]['w']) + + # plot + if TEST_PLOTS: + fig, ax = plt.subplots(nrows=3) + ax1, ax2, ax3 = ax + + if sim_mdl: + pre_spike_times_ = nest.GetStatus(spikedet_pre, "events")[0]["times"] + print("Actual pre spike times: " + str(pre_spike_times_)) + if sim_ref: + pre_ref_spike_times_ = nest.GetStatus(spikedet_pre_ref, "events")[0]["times"] + print("Actual pre ref spike times: " + str(pre_ref_spike_times_)) + + if sim_mdl: + n_spikes = len(pre_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nestml" + else: + _lbl = None + ax1.plot(2 * [pre_spike_times_[i] + delay], [0, 1], linewidth=2, color="blue", alpha=.4, label=_lbl) + + if sim_mdl: + post_spike_times_ = nest.GetStatus(spikedet_post, "events")[0]["times"] + print("Actual post spike times: " + str(post_spike_times_)) + if sim_ref: + post_ref_spike_times_ = nest.GetStatus(spikedet_post_ref, "events")[0]["times"] + print("Actual post ref spike times: " + str(post_ref_spike_times_)) + + if sim_ref: + n_spikes = len(pre_ref_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nest ref" + else: + _lbl = None + ax1.plot(2 * [pre_ref_spike_times_[i] + delay], [0, 1], + linewidth=2, color="cyan", label=_lbl, alpha=.4) + ax1.set_ylabel("Pre spikes") + + if sim_mdl: + n_spikes = len(post_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nestml" + else: + _lbl = None + ax2.plot(2 * [post_spike_times_[i]], [0, 1], linewidth=2, color="black", alpha=.4, label=_lbl) + if sim_ref: + n_spikes = len(post_ref_spike_times_) + for i in range(n_spikes): + if i == 0: + _lbl = "nest ref" + else: + _lbl = None + ax2.plot(2 * [post_ref_spike_times_[i]], [0, 1], linewidth=2, color="red", alpha=.4, label=_lbl) + if sim_mdl: + ax2.plot(nest.GetStatus(mm, "events")[0]["times"], nest.GetStatus(mm, "events")[ + 0]["post_trace__for_stdp_nn_symm_synapse_nestml"], label="nestml post tr") + ax2.set_ylabel("Post spikes") + + if sim_mdl: + ax3.plot(t_hist, w_hist, marker="o", label="nestml") + if sim_ref: + ax3.plot(t_hist, w_hist_ref, linestyle="--", marker="x", label="ref") + + ax3.set_xlabel("Time [ms]") + ax3.set_ylabel("w") + for _ax in ax: + _ax.grid(which="major", axis="both") + _ax.xaxis.set_major_locator(matplotlib.ticker.FixedLocator(np.arange(0, np.ceil(sim_time)))) + _ax.set_xlim(0., sim_time) + _ax.legend() + fig.savefig("/tmp/stdp_synapse_test" + fname_snip + ".png", dpi=300) + + # verify + MAX_ABS_ERROR = 1E-6 + assert np.all(np.abs(np.array(w_hist) - np.array(w_hist_ref)) < MAX_ABS_ERROR) diff --git a/tests/nest_tests/nest_delay_based_variables_test.py b/tests/nest_tests/nest_delay_based_variables_test.py index 51f863e19..a11c280f2 100644 --- a/tests/nest_tests/nest_delay_based_variables_test.py +++ b/tests/nest_tests/nest_delay_based_variables_test.py @@ -19,13 +19,12 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + import numpy as np import os -from typing import List import pytest -import nest - try: import matplotlib import matplotlib.pyplot as plt @@ -34,15 +33,12 @@ except BaseException: TEST_PLOTS = False +import nest + from pynestml.codegeneration.nest_tools import NESTTools from pynestml.frontend.pynestml_frontend import generate_nest_target -target_path = "target_delay" -logging_level = "DEBUG" -suffix = "_nestml" - - def plot_fig(times, recordable_events_delay: dict, recordable_events: dict, filename: str): fig, axes = plt.subplots(len(recordable_events), 1, figsize=(7, 9), sharex=True) for i, recordable_name in enumerate(recordable_events_delay.keys()): @@ -86,6 +82,9 @@ def run_simulation(neuron_model_name: str, module_name: str, recordables: List[s ("DelayDifferentialEquationsWithNumericSolver.nestml", "dde_numeric_nestml", ["x", "z"]), ("DelayDifferentialEquationsWithMixedSolver.nestml", "dde_mixed_nestml", ["x", "z"])]) def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, recordables: List[str]): + target_path = "target_delay" + logging_level = "DEBUG" + suffix = "_nestml" input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", file_name))) module_name = neuron_model_name + "_module" print("Module name: ", module_name) @@ -112,16 +111,3 @@ def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, record if neuron_model_name == "dde_analytic_nestml": np.testing.assert_allclose(recordable_events_delay[recordables[1]][int(delay):], recordable_events[recordables[1]][:-int(delay)]) - - @pytest.fixture(scope="function", autouse=True) - def cleanup(self): - # Run the test - yield - - # clean up - import shutil - if self.target_path: - try: - shutil.rmtree(self.target_path) - except Exception: - pass diff --git a/tests/nest_tests/nest_random_functions_test.py b/tests/nest_tests/nest_random_functions_test.py new file mode 100644 index 000000000..3d24ed963 --- /dev/null +++ b/tests/nest_tests/nest_random_functions_test.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# +# nest_random_functions_test.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . +import os + +import pytest + +from pynestml.frontend.pynestml_frontend import generate_nest_target + + +class TestNestRandomFunctions: + """ + Tests that, for the NEST target, random number functions are called only in ``update``, ``onReceive``, and ``onCondition`` block + """ + + @pytest.mark.xfail(strict=True, raises=Exception) + def test_nest_random_function_neuron_illegal(self): + input_path = os.path.realpath(os.path.join(os.path.dirname(__file__), + "resources", "random_functions_illegal_neuron.nestml")) + generate_nest_target(input_path=input_path, + target_path="target", + logging_level="INFO", + suffix="_nestml") + + @pytest.mark.xfail(strict=True, raises=Exception) + def test_nest_random_function_synapse_illegal(self): + input_path = [ + os.path.realpath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, "models", "neurons", + "iaf_psc_exp_neuron.nestml")), + os.path.realpath(os.path.join(os.path.dirname(__file__), + "resources", "random_functions_illegal_synapse.nestml"))] + + generate_nest_target(input_path=input_path, + target_path="target", + logging_level="INFO", + suffix="_nestml", + codegen_opts={"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "synapse": "random_functions_illegal_synapse", + "post_ports": ["post_spikes"]}], + "weight_variable": {"stdp_synapse": "w"}}) diff --git a/tests/nest_tests/resources/integrate_odes_test_params.nestml b/tests/nest_tests/resources/integrate_odes_test_params.nestml index d07fe8fd4..d6430e537 100644 --- a/tests/nest_tests/resources/integrate_odes_test_params.nestml +++ b/tests/nest_tests/resources/integrate_odes_test_params.nestml @@ -8,7 +8,6 @@ model integrate_odes_test: update: integrate_odes(2 * test_1) - integrate_odes(test_3) integrate_odes(100 ms) integrate_odes(test_1) integrate_odes(test_2) diff --git a/tests/nest_tests/resources/integrate_odes_test_params2.nestml b/tests/nest_tests/resources/integrate_odes_test_params2.nestml new file mode 100644 index 000000000..616401e48 --- /dev/null +++ b/tests/nest_tests/resources/integrate_odes_test_params2.nestml @@ -0,0 +1,10 @@ +""" +Model for testing the integrate_odes() function. +""" +model integrate_odes_test: + state: + test_1 real = 0. + test_2 real = 0. + + update: + integrate_odes(test_3) diff --git a/tests/nest_tests/resources/random_functions_illegal_neuron.nestml b/tests/nest_tests/resources/random_functions_illegal_neuron.nestml new file mode 100644 index 000000000..9954459d3 --- /dev/null +++ b/tests/nest_tests/resources/random_functions_illegal_neuron.nestml @@ -0,0 +1,46 @@ +""" +random_functions_illegal_neuron.nestml +###################################### + + +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +""" + +model random_functions_illegal_neuron: + state: + noise real = random_normal(0,sigma_noise) + v mV = -15 mV + + parameters: + rate ms**-1 = 15.5 s**-1 + sigma_noise real = 16. + u real = random_uniform(0,1) + + internals: + poisson_input integer = random_poisson(rate * resolution() * 1E-3) + + update: + if u < 0.5: + noise = 0. + else: + noise = random_normal(0,sigma_noise) + + v += (poisson_input + noise) * mV diff --git a/tests/nest_tests/resources/random_functions_illegal_synapse.nestml b/tests/nest_tests/resources/random_functions_illegal_synapse.nestml new file mode 100644 index 000000000..473791ec7 --- /dev/null +++ b/tests/nest_tests/resources/random_functions_illegal_synapse.nestml @@ -0,0 +1,73 @@ +""" +random_functions_illegal_synapse.nestml +####################################### + + +Copyright statement ++++++++++++++++++++ + +This file is part of NEST. + +Copyright (C) 2004 The NEST Initiative + +NEST is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, either version 2 of the License, or +(at your option) any later version. + +NEST is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with NEST. If not, see . +""" + +model random_functions_illegal_synapse: + state: + w real = 1 # Synaptic weight + pre_trace real = 0. + post_trace real = 0. + + parameters: + d ms = 1 ms # Synaptic transmission delay + lambda real = .01 + tau_tr_pre ms = random_normal(110 ms, 55 ms) + tau_tr_post ms = random_normal(5 ms, 2.5 ms) + alpha real = 1 + mu_plus real = 1 + mu_minus real = 1 + Wmax real = 100. + Wmin real = 0. + + equations: + pre_trace' = -pre_trace / tau_tr_pre + post_trace' = -post_trace / tau_tr_post + + input: + pre_spikes <- spike + post_spikes <- spike + + output: + spike + + onReceive(post_spikes): + post_trace += 1 + + # potentiate synapse + w_ real = Wmax * ( w / Wmax + (lambda * ( 1. - ( w / Wmax ) )**mu_plus * pre_trace )) + w = min(Wmax, w_) + + onReceive(pre_spikes): + pre_trace += 1 + + # depress synapse + w_ real = Wmax * ( w / Wmax - ( alpha * lambda * ( w / Wmax )**mu_minus * post_trace )) + w = max(Wmin, w_) + + # deliver spike to postsynaptic partner + emit_spike(w, d) + + update: + integrate_odes() diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py index 99b94c6ca..6ddb699b4 100644 --- a/tests/nest_tests/test_integrate_odes.py +++ b/tests/nest_tests/test_integrate_odes.py @@ -27,16 +27,9 @@ import nest -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.codegeneration.nest_tools import NESTTools -from pynestml.frontend.pynestml_frontend import generate_nest_target +from pynestml.frontend.pynestml_frontend import generate_nest_target, generate_target from pynestml.utils.logger import LoggingLevel, Logger -from pynestml.utils.model_parser import ModelParser try: import matplotlib @@ -227,12 +220,15 @@ def test_integrate_odes_nonlinear(self): def test_integrate_odes_params(self): r"""Test the integrate_odes() function, in particular with respect to the parameter types.""" - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))) - assert len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)) == 6 + fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 + + def test_integrate_odes_params2(self): + r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables.""" + + fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 diff --git a/tests/test_cocos.py b/tests/test_cocos.py new file mode 100644 index 000000000..81f519eaf --- /dev/null +++ b/tests/test_cocos.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +# +# test_cocos.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from __future__ import print_function + +from typing import Optional + +import os +import pytest + +from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbol_table.symbol_table import SymbolTable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.predefined_units import PredefinedUnits +from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.logger import LoggingLevel, Logger +from pynestml.utils.model_parser import ModelParser + + +class TestCoCos: + + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + SymbolTable.initialize_symbol_table( + ASTSourceLocation( + start_line=0, + start_column=0, + end_line=0, + end_column=0)) + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + def test_invalid_element_defined_after_usage(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableDefinedAfterUsage.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_element_defined_after_usage(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableDefinedAfterUsage.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_element_in_same_line(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoElementInSameLine.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_element_in_same_line(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoElementInSameLine.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_integrate_odes_called_if_equations_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_integrate_odes_called_if_equations_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_element_not_defined_in_scope(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableNotDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 6 + + def test_valid_element_not_defined_in_scope(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableNotDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_variable_with_same_name_as_unit(self): + Logger.set_logging_level(LoggingLevel.NO) + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableWithSameNameAsUnit.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 3 + + def test_invalid_variable_redeclaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableRedeclared.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_variable_redeclaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableRedeclared.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_each_block_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoEachBlockUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_each_block_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoEachBlockUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_function_unique_and_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionNotUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8 + + def test_valid_function_unique_and_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionNotUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expressions_have_rhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionHasNoRhs.nestml')) + assert model is None + + def test_valid_inline_expressions_have_rhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionHasNoRhs.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expression_has_several_lhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionWithSeveralLhs.nestml')) + assert model is None + + def test_valid_inline_expression_has_several_lhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionWithSeveralLhs.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_no_values_assigned_to_input_ports(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoValueAssignedToInputPort.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_no_values_assigned_to_input_ports(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoValueAssignedToInputPort.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_order_of_equations_correct(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNoOrderOfEquations.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_order_of_equations_correct(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNoOrderOfEquations.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_numerator_of_unit_one(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoUnitNumeratorNotOne.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_numerator_of_unit_one(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoUnitNumeratorNotOne.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_names_of_neurons_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoMultipleNeuronsWithEqualName.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 3 + + def test_valid_names_of_neurons_unique(self): + self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoMultipleNeuronsWithEqualName.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)) == 0 + + def test_invalid_no_nest_collision(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNestNamespaceCollision.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_no_nest_collision(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNestNamespaceCollision.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_redundant_input_port_keywords_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInputPortWithRedundantTypes.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_redundant_input_port_keywords_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInputPortWithRedundantTypes.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_parameters_assigned_only_in_parameters_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoParameterAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_parameters_assigned_only_in_parameters_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoParameterAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expressions_assigned_only_in_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoAssignmentToInlineExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_invalid_internals_assigned_only_in_internals_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInternalAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_internals_assigned_only_in_internals_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInternalAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_function_with_wrong_arg_number_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_function_with_wrong_arg_number_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_init_values_have_rhs_and_ode(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInitValuesWithoutOde.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2 + + def test_valid_init_values_have_rhs_and_ode(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInitValuesWithoutOde.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2 + + def test_invalid_incorrect_return_stmt_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIncorrectReturnStatement.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8 + + def test_valid_incorrect_return_stmt_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIncorrectReturnStatement.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_ode_vars_outside_init_block_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeVarNotInInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_ode_vars_outside_init_block_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeVarNotInInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_convolve_correctly_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyProvided.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_convolve_correctly_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyProvided.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_in_non_vector_declaration_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInNonVectorDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_in_non_vector_declaration_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInNonVectorDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_parameter_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_type(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_parameter_type(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_size(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorDeclarationSize.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_vector_parameter_size(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorDeclarationSize.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_convolve_correctly_parameterized(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyParametrized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_convolve_correctly_parameterized(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyParametrized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_invariant_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInvariantNotBool.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_invariant_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInvariantNotBool.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIllegalExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIllegalExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_compound_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 10 + + def test_valid_compound_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_ode_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeIncorrectlyTyped.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_valid_ode_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeCorrectlyTyped.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_output_block_defined_if_emit_call(self): + """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_invalid_output_port_defined_if_emit_call(self): + """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall-2.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_valid_output_port_defined_if_emit_call(self): + """test that no error is raised when the output block is missing, but not emit_spike() functions are called""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOutputPortDefinedIfEmitCall.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_valid_coco_kernel_type(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoKernelType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_coco_kernel_type(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_invalid_coco_kernel_type_initial_values(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelTypeInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 4 + + def test_valid_coco_state_variables_initialized(self): + """ + Test that the CoCo condition is applicable for all the variables in the state block initialized with a value + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoStateVariablesInitialized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_coco_state_variables_initialized(self): + """ + Test that the CoCo condition is applicable for all the variables in the state block not initialized + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoStateVariablesInitialized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_invalid_co_co_priorities_correctly_specified(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoPrioritiesCorrectlySpecified.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_co_co_priorities_correctly_specified(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoPrioritiesCorrectlySpecified.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_co_co_resolution_legally_used(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoResolutionLegallyUsed.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_co_co_resolution_legally_used(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoResolutionLegallyUsed.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_valid_co_co_vector_input_port(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInputPortSizeAndType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_co_co_vector_input_port(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInputPortSizeAndType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def _parse_and_validate_model(self, fname: str) -> Optional[str]: + from pynestml.frontend.pynestml_frontend import generate_target + + Logger.init_logger(LoggingLevel.DEBUG) + + try: + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + except BaseException: + return None + + ast_compilation_unit = ModelParser.parse_file(fname) + if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0: + return None + + model: ASTModel = ast_compilation_unit.get_model_list()[0] + model_name = model.get_name() + + return model_name diff --git a/tests/test_function_parameter_templating.py b/tests/test_function_parameter_templating.py new file mode 100644 index 000000000..b93e06780 --- /dev/null +++ b/tests/test_function_parameter_templating.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# test_function_parameter_templating.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os + +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.frontend.pynestml_frontend import generate_target + + +class TestFunctionParameterTemplating: + """ + This test is used to test the correct derivation of types when functions use templated type parameters. + """ + + def test(self): + fname = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "FunctionParameterTemplatingTest.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + assert len(Logger.get_all_messages_of_level_and_or_node("templated_function_parameters_type_test", LoggingLevel.ERROR)) == 5 diff --git a/tests/test_unit_system.py b/tests/test_unit_system.py new file mode 100644 index 000000000..2cad0b98d --- /dev/null +++ b/tests/test_unit_system.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# +# test_unit_system.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os +import pytest + +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter +from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter +from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter +from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter +from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter +from pynestml.frontend.pynestml_frontend import generate_target +from pynestml.symbol_table.symbol_table import SymbolTable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.predefined_units import PredefinedUnits +from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.utils.model_parser import ModelParser + + +class TestUnitSystem: + r""" + Test class for units system. + """ + + @pytest.fixture(scope="class", autouse=True) + def setUp(self, request): + Logger.set_logging_level(LoggingLevel.INFO) + + SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) + + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + Logger.init_logger(LoggingLevel.INFO) + + variable_printer = NestMLVariablePrinter(None) + function_call_printer = NESTCppFunctionCallPrinter(None) + cpp_variable_printer = CppVariablePrinter(None) + self.printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer, + ConstantPrinter(), + function_call_printer)) + cpp_variable_printer._expression_printer = self.printer + variable_printer._expression_printer = self.printer + function_call_printer._expression_printer = self.printer + + request.cls.printer = self.printer + + def get_first_statement_in_update_block(self, model): + if model.get_model_list()[0].get_update_blocks()[0]: + return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0] + + return None + + def get_first_declaration_in_state_block(self, model): + assert len(model.get_model_list()[0].get_state_blocks()) == 1 + + return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0] + + def get_first_declared_function(self, model): + return model.get_model_list()[0].get_functions()[0] + + def print_rhs_of_first_assignment_in_update_block(self, model): + assignment = self.get_first_statement_in_update_block(model).small_stmt.get_assignment() + expression = assignment.get_expression() + + return self.printer.print(expression) + + def print_first_function_call_in_update_block(self, model): + function_call = self.get_first_statement_in_update_block(model).small_stmt.get_function_call() + + return self.printer.print(function_call) + + def print_rhs_of_first_declaration_in_state_block(self, model): + declaration = self.get_first_declaration_in_state_block(model) + expression = declaration.get_expression() + + return self.printer.print(expression) + + def print_first_return_statement_in_first_declared_function(self, model): + func = self.get_first_declared_function(model) + return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression() + return self.printer.print(return_expression) + + def test_expression_after_magnitude_conversion_in_direct_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V))' + + def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))' + + def test_expression_after_magnitude_conversion_in_compound_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(0.001 * (1200 * mV))' + + def test_expression_after_magnitude_conversion_in_declaration(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml')) + printed_rhs_expression = self.print_rhs_of_first_declaration_in_state_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V))' + + def test_expression_after_type_conversion_in_declaration(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnits.nestml')) + declaration = self.get_first_declaration_in_state_block(model) + from astropy import units as u + + assert declaration.get_expression().type.unit.unit == u.mV + + def test_declaration_with_same_variable_name_as_unit(self): + Logger.init_logger(LoggingLevel.DEBUG) + + generate_target(input_path=os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithSameVariableNameAsUnit.nestml'), target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.ERROR)) == 0 + assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.WARNING)) == 3 + + def test_expression_after_magnitude_conversion_in_standalone_function_call(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionCallWithDifferentButCompatibleUnits.nestml')) + printed_function_call = self.print_first_function_call_in_update_block(model) + + assert printed_function_call == 'foo((1000.0 * (10 * V)))' + + def test_expression_after_magnitude_conversion_in_rhs_function_call(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml')) + printed_function_call = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_function_call == 'foo((1000.0 * (10 * V)))' + + def test_return_stmt_after_magnitude_conversion_in_function_body(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml')) + printed_return_stmt = self.print_first_return_statement_in_first_declared_function(model) + + assert printed_return_stmt == '(0.001 * (bar))' diff --git a/tests/unit_system_test.py b/tests/unit_system_test.py deleted file mode 100644 index 1f7817b91..000000000 --- a/tests/unit_system_test.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: utf-8 -*- -# -# unit_system_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -import os -import unittest -from pynestml.codegeneration.printers.constant_printer import ConstantPrinter - -from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter -from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter -from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter -from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter -from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter -from pynestml.codegeneration.printers.cpp_function_call_printer import CppFunctionCallPrinter -from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.model_parser import ModelParser - - -SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) - -PredefinedUnits.register_units() -PredefinedTypes.register_types() -PredefinedVariables.register_variables() -PredefinedFunctions.register_functions() - -Logger.init_logger(LoggingLevel.INFO) - -type_symbol_printer = CppTypeSymbolPrinter() -variable_printer = NestMLVariablePrinter(None) -function_call_printer = NESTCppFunctionCallPrinter(None) -cpp_variable_printer = CppVariablePrinter(None) -printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer, - ConstantPrinter(), - function_call_printer)) -cpp_variable_printer._expression_printer = printer -variable_printer._expression_printer = printer -function_call_printer._expression_printer = printer - - -def get_first_statement_in_update_block(model): - if model.get_model_list()[0].get_update_blocks()[0]: - return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0] - return None - - -def get_first_declaration_in_state_block(model): - assert len(model.get_model_list()[0].get_state_blocks()) == 1 - return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0] - - -def get_first_declared_function(model): - return model.get_model_list()[0].get_functions()[0] - - -def print_rhs_of_first_assignment_in_update_block(model): - assignment = get_first_statement_in_update_block(model).small_stmt.get_assignment() - expression = assignment.get_expression() - return printer.print(expression) - - -def print_first_function_call_in_update_block(model): - function_call = get_first_statement_in_update_block(model).small_stmt.get_function_call() - return printer.print(function_call) - - -def print_rhs_of_first_declaration_in_state_block(model): - declaration = get_first_declaration_in_state_block(model) - expression = declaration.get_expression() - return printer.print(expression) - - -def print_first_return_statement_in_first_declared_function(model): - func = get_first_declared_function(model) - return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression() - return printer.print(return_expression) - - -class UnitSystemTest(unittest.TestCase): - """ - Test class for everything Unit related. - """ - - def setUp(self): - Logger.set_logging_level(LoggingLevel.INFO) - - def test_expression_after_magnitude_conversion_in_direct_assignment(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DirectAssignmentWithDifferentButCompatibleUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))') - - def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))') - - def test_expression_after_magnitude_conversion_in_compound_assignment(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - self.assertEqual(printed_rhs_expression, '(0.001 * (1200 * mV))') - - def test_expression_after_magnitude_conversion_in_declaration(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml')) - printed_rhs_expression = print_rhs_of_first_declaration_in_state_block(model) - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))') - - def test_expression_after_type_conversion_in_declaration(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithDifferentButCompatibleUnits.nestml')) - declaration = get_first_declaration_in_state_block(model) - from astropy import units as u - self.assertTrue(declaration.get_expression().type.unit.unit == u.mV) - - def test_declaration_with_same_variable_name_as_unit(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithSameVariableNameAsUnit.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 3) - - def test_expression_after_magnitude_conversion_in_standalone_function_call(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'FunctionCallWithDifferentButCompatibleUnits.nestml')) - printed_function_call = print_first_function_call_in_update_block(model) - self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))') - - def test_expression_after_magnitude_conversion_in_rhs_function_call(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml')) - printed_function_call = print_rhs_of_first_assignment_in_update_block(model) - self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))') - - def test_return_stmt_after_magnitude_conversion_in_function_body(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml')) - printed_return_stmt = print_first_return_statement_in_first_declared_function(model) - self.assertEqual(printed_return_stmt, '(0.001 * (bar))')