Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 23, 2024
1 parent fdc6312 commit e8b94b9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
20 changes: 11 additions & 9 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,11 +817,12 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg

increment_expr = 'i+0+1'
if isinstance(node.iter, ast_internal_classes.BinOp_Node):
increment_expr = ast_utils.ProcessedWriter(sdfg,
self.name_mapping,
placeholders=self.placeholders,
placeholders_offsets=self.placeholders_offsets,
rename_dict=self.replace_names).write_code(node.iter.rval)
increment_rhs = ast_utils.ProcessedWriter(sdfg,
self.name_mapping,
placeholders=self.placeholders,
placeholders_offsets=self.placeholders_offsets,
rename_dict=self.replace_names).write_code(node.iter.rval)
increment_expr = f'{iter_name} = {increment_rhs}'

loop_region = LoopRegion(name, condition, iter_name, init_expr, increment_expr, inverted=False, sdfg=sdfg)

Expand Down Expand Up @@ -2106,17 +2107,17 @@ def subroutine2sdfg(self, node: ast_internal_classes.Subroutine_Subprogram_Node,
self.transient_mode = True
for j in node.specification_part.symbols:
if isinstance(j, ast_internal_classes.Symbol_Decl_Node):
self.symbol2sdfg(j, new_sdfg)
self.symbol2sdfg(j, new_sdfg, new_sdfg)
else:
raise NotImplementedError("Symbol not implemented")

for j in node.specification_part.specifications:
self.declstmt2sdfg(j, new_sdfg)
self.declstmt2sdfg(j, new_sdfg, new_sdfg)
self.transient_mode = old_mode

for i in assigns:
self.translate(i, new_sdfg)
self.translate(node.execution_part, new_sdfg)
self.translate(i, new_sdfg, new_sdfg)
self.translate(node.execution_part, new_sdfg, new_sdfg)

if self.multiple_sdfgs == True:
internal_sdfg.path = self.sdfg_path + new_sdfg.name + ".sdfg"
Expand Down Expand Up @@ -2886,6 +2887,7 @@ def create_sdfg_from_internal_ast(own_ast: ast_components.InternalFortranAst, pr
ast2sdfg.top_level = program
ast2sdfg.globalsdfg = g
ast2sdfg.translate(program, g, g)
g.reset_cfg_list()
g.apply_transformations(IntrinsicSDFGTransformation)
g.expand_library_nodes()
gmap[ep] = g
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/fortran/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,7 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_


class IntrinsicSDFGTransformation(xf.SingleStateTransformation):

array1 = xf.PatternNode(nodes.AccessNode)
array2 = xf.PatternNode(nodes.AccessNode)
tasklet = xf.PatternNode(nodes.Tasklet)
Expand Down Expand Up @@ -1409,6 +1410,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):


class MathFunctions(IntrinsicTransformation):

MathTransformation = namedtuple("MathTransformation", "function return_type")
MathReplacement = namedtuple("MathReplacement", "function replacement_function return_type")

Expand Down
3 changes: 2 additions & 1 deletion tests/fortran/fortran_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dace.frontend.fortran.ast_internal_classes import Name_Node
from dace.frontend.fortran.fortran_parser import ParseConfig, create_internal_ast, SDFGConfig, \
create_sdfg_from_internal_ast
from dace.sdfg.sdfg import SDFG


@dataclass
Expand Down Expand Up @@ -276,7 +277,7 @@ def NAMED(cls, name: str):
return cls(Name_Node, {'name': cls(has_value=name)})


def create_singular_sdfg_from_string(sources: Dict[str, str], entry_point: str, normalize_offsets: bool = True):
def create_singular_sdfg_from_string(sources: Dict[str, str], entry_point: str, normalize_offsets: bool = True) -> SDFG:
entry_point = entry_point.split('.')

cfg = ParseConfig(main=sources['main.f90'], sources=sources, entry_points=tuple(entry_point))
Expand Down

0 comments on commit e8b94b9

Please sign in to comment.