diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 43439a93..ecb14317 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -307,6 +307,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return rand_call + if name in {"cyl_bessel_j0", "cyl_bessel_j1"}: + if len(args) != 1: + self.error(node, "pk.cyl_bessel_j0/j1 accepts only one argument") + + s = cppast.Serializer() + arg_str = s.serialize(args[0]) + math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}, double, int>"), args) + real_number_call = cppast.MemberCallExpr(math_call, cppast.DeclRefExpr("real"), []) + + return real_number_call + return super().visit_Call(node) def is_nested_call(self, node: ast.FunctionDef) -> bool: diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index d4dd9ff2..701ae3e9 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -32,6 +32,9 @@ from .hierarchical import ( AUTO, TeamMember, PerTeam, PerThread, single ) +from .mathematical_special_functions import ( + cyl_bessel_j0, cyl_bessel_j1 +) from .memory_space import MemorySpace, get_default_memory_space from .parallel_dispatch import ( execute, flush, diff --git a/pykokkos/interface/mathematical_special_functions.py b/pykokkos/interface/mathematical_special_functions.py new file mode 100644 index 00000000..b67be1e8 --- /dev/null +++ b/pykokkos/interface/mathematical_special_functions.py @@ -0,0 +1,6 @@ + +def cyl_bessel_j0(input: float) -> float: + pass + +def cyl_bessel_j1(input: float) -> float: + pass \ No newline at end of file