Skip to content

Commit

Permalink
add include statement
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Oct 7, 2024
1 parent ac766f0 commit eecc955
Show file tree
Hide file tree
Showing 11 changed files with 310 additions and 78 deletions.
9 changes: 9 additions & 0 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_if_clause import ASTIfClause
from pynestml.meta_model.ast_if_stmt import ASTIfStmt
from pynestml.meta_model.ast_include_stmt import ASTIncludeStmt
from pynestml.meta_model.ast_input_block import ASTInputBlock
from pynestml.meta_model.ast_input_port import ASTInputPort
from pynestml.meta_model.ast_input_qualifier import ASTInputQualifier
Expand Down Expand Up @@ -443,6 +444,11 @@ def print_return_stmt(self, node: ASTReturnStmt):
ret += "return " + (self.print(node.get_expression()) if node.has_expression() else "")
return ret

def print_include_stmt(self, node: ASTIncludeStmt):
ret = print_n_spaces(self.indent)
ret += "include \"" + node.get_filename() + "\""
return ret

def print_simple_expression(self, node: ASTSimpleExpression) -> str:
if node.is_function_call():
return self.print(node.function_call)
Expand Down Expand Up @@ -481,7 +487,10 @@ def print_small_stmt(self, node: ASTSmallStmt) -> str:
ret += print_sl_comment(node.in_comment) + "\n"
elif node.is_declaration():
ret = self.print(node.get_declaration())
elif node.is_include_stmt():
ret = self.print(node.get_include_stmt())
else:
assert node.is_return_stmt()
ret = self.print(node.get_return_stmt())
return ret

Expand Down
5 changes: 5 additions & 0 deletions pynestml/frontend/pynestml_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.utils.model_parser import ModelParser
from pynestml.visitors.ast_include_statement_visitor import ASTIncludeStatementVisitor
from pynestml.visitors.ast_parent_visitor import ASTParentVisitor
from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor

Expand Down Expand Up @@ -435,6 +436,10 @@ def get_parsed_models() -> List[ASTModel]:
CoCosManager.check_model_names_unique(compilation_unit)
models.extend(compilation_unit.get_model_list())

# swap include statements for included file
for model in models:
model.accept(ASTIncludeStatementVisitor(os.path.dirname(model.file_path)))

# check that no models with duplicate names have been defined
CoCosManager.check_no_duplicate_compilation_unit_names(models)

Expand Down
10 changes: 5 additions & 5 deletions pynestml/grammars/PyNestMLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ parser grammar PyNestMLParser;
| functionCall
| declaration
| returnStmt
| includeStatement) NEWLINE;
| includeStmt) NEWLINE;

assignment : lhs_variable=variable
(directAssignment=EQUALS |
Expand All @@ -153,8 +153,8 @@ parser grammar PyNestMLParser;
compoundQuotient=FORWARD_SLASH_EQUALS)
expression;

includeStatement : INCLUDE_KEYWORD STRING_LITERAL;
includeStatement_newline : includeStatement NEWLINE;
include : INCLUDE_KEYWORD STRING_LITERAL;
includeStmt_newline : includeStmt NEWLINE;

/** ASTDeclaration A variable declaration. It can be a simple declaration defining one or multiple variables:
'a,b,c real = 0'. Or an function declaration 'function a = b + c'.
Expand Down Expand Up @@ -237,7 +237,7 @@ parser grammar PyNestMLParser;
@attribute function: A block declaring a user-defined function.
*/
modelBody: COLON
NEWLINE INDENT ( includeStatement_newline | blockWithVariables | equationsBlock | inputBlock | outputBlock | function | onReceiveBlock | onConditionBlock | updateBlock )+ DEDENT;
NEWLINE INDENT ( includeStmt_newline | blockWithVariables | equationsBlock | inputBlock | outputBlock | function | onReceiveBlock | onConditionBlock | updateBlock )+ DEDENT;

/** ASTOnReceiveBlock
@attribute block implementation of the dynamics
Expand All @@ -263,7 +263,7 @@ parser grammar PyNestMLParser;
blockWithVariables:
blockType=(STATE_KEYWORD | PARAMETERS_KEYWORD | INTERNALS_KEYWORD)
COLON
NEWLINE INDENT (includeStatement_newline | declaration_newline)+ DEDENT;
NEWLINE INDENT (includeStmt_newline | declaration_newline)+ DEDENT;

/** ASTUpdateBlock The definition of a block where the dynamical behavior of the neuron is stated:
update:
Expand Down
75 changes: 75 additions & 0 deletions pynestml/meta_model/ast_include_stmt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
#
# ast_include_stmt.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List

from pynestml.meta_model.ast_node import ASTNode


class ASTIncludeStmt(ASTNode):
"""
This class is used to store an include statement.
"""

def __init__(self, filename: str, *args, **kwargs):
"""
Standard constructor.
Parameters for superclass (ASTNode) can be passed through :python:`*args` and :python:`**kwargs`.
:param filename: the filename of the included file (can be a single file name, or file path)
"""
super(ASTIncludeStmt, self).__init__(*args, **kwargs)
self.filename = filename

def clone(self):
"""
Return a clone ("deep copy") of this node.
:return: new AST node instance
:rtype: ASTIncludeStmt
"""
dup = ASTIncludeStmt(filename=self.filename,
# ASTNode common attributes:
source_position=self.source_position,
scope=self.scope,
comment=self.comment,
pre_comments=[s for s in self.pre_comments],
in_comment=self.in_comment,
implicit_conversion_factor=self.implicit_conversion_factor)

return dup

def get_filename(self) -> str:
"""
Returns the filename of the included file (can be a single file name, or file path).
:return: filename
"""
return self.filename

def equals(self, other: ASTNode) -> bool:
r"""
The equality method.
"""
if not isinstance(other, ASTIncludeStmt):
return False

return self.get_filename().equals(other.get_filename())
46 changes: 35 additions & 11 deletions pynestml/meta_model/ast_small_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,44 @@
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

from typing import List
from typing import List, Optional
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_declaration import ASTDeclaration
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_include_stmt import ASTIncludeStmt

from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_return_stmt import ASTReturnStmt


class ASTSmallStmt(ASTNode):
"""
This class is used to store small statements, e.g., a declaration.
Grammar:
smallStmt : assignment
| functionCall
| declaration
| returnStmt;
Attributes:
assignment (ast_assignment): A assignment reference.
function_call (ast_function_call): A function call reference.
declaration (ast_declaration): A declaration reference.
return_stmt (ast_return_stmt): A reference to the returns statement.
"""

def __init__(self, assignment=None, function_call=None, declaration=None, return_stmt=None, *args, **kwargs):
def __init__(self, assignment: Optional[ASTAssignment] = None, function_call: Optional[ASTFunctionCall] = None, declaration: Optional[ASTDeclaration] = None, return_stmt: Optional[ASTReturnStmt] = None, include_stmt: Optional[ASTIncludeStmt] = None, *args, **kwargs):
"""
Standard constructor.
Parameters for superclass (ASTNode) can be passed through :python:`*args` and :python:`**kwargs`.
:param assignment: an meta_model-assignment object.
:type assignment: ASTAssignment
:param function_call: an meta_model-function call object.
:type function_call: ASTFunctionCall
:param declaration: an meta_model-declaration object.
:type declaration: ASTDeclaration
:param return_stmt: an meta_model-return statement object.
:type return_stmt: ASTReturnStmt
"""
super(ASTSmallStmt, self).__init__(*args, **kwargs)
self.assignment = assignment
self.function_call = function_call
self.declaration = declaration
self.return_stmt = return_stmt
self.include_stmt = include_stmt

def clone(self):
"""
Expand All @@ -79,10 +77,14 @@ def clone(self):
return_stmt_dup = None
if self.return_stmt:
return_stmt_dup = self.return_stmt.clone()
include_stmt_dup = None
if self.include_stmt:
include_stmt_dup = self.include_stmt.clone()
dup = ASTSmallStmt(assignment=assignment_dup,
function_call=function_call_dup,
declaration=declaration_dup,
return_stmt=return_stmt_dup,
include_stmt=include_stmt_dup,
# ASTNode common attributes:
source_position=self.source_position,
scope=self.scope,
Expand Down Expand Up @@ -157,6 +159,20 @@ def get_return_stmt(self):
"""
return self.return_stmt

def is_include_stmt(self) -> bool:
"""
Returns whether it is a include statement or not.
:return: True if include stmt, False else.
"""
return self.include_stmt is not None

def get_include_stmt(self) -> Optional[ASTIncludeStmt]:
"""
Returns the include statement.
:return: the include statement.
"""
return self.include_stmt

def get_children(self) -> List[ASTNode]:
r"""
Returns the children of this node, if any.
Expand All @@ -174,6 +190,9 @@ def get_children(self) -> List[ASTNode]:
if self.is_return_stmt():
return [self.get_return_stmt()]

if self.is_include_stmt():
return [self.get_include_stmt()]

return []

def equals(self, other: ASTNode) -> bool:
Expand Down Expand Up @@ -201,4 +220,9 @@ def equals(self, other: ASTNode) -> bool:
if self.is_return_stmt() and other.is_return_stmt() and not self.get_return_stmt().equals(
other.get_return_stmt()):
return False
if self.is_include_stmt() + other.is_include_stmt() == 1:
return False
if self.is_include_stmt() and other.is_include_stmt() and not self.get_include_stmt().equals(
other.get_include_stmt()):
return False
return True
Loading

0 comments on commit eecc955

Please sign in to comment.