From d798f56722682353295539089ed3c283b1d2555d Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 2 Aug 2024 11:34:49 -0500 Subject: [PATCH] Interface: add cyl_bessel_j0 and cyl_bessel_j1 as mathemtical functions inside kernels (only accept real values for now) --- pykokkos/core/visitors/workunit_visitor.py | 11 +++++++++++ pykokkos/interface/__init__.py | 3 +++ pykokkos/interface/mathematical_special_functions.py | 6 ++++++ 3 files changed, 20 insertions(+) create mode 100644 pykokkos/interface/mathematical_special_functions.py diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 43439a93..ecb14317 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -307,6 +307,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return rand_call + if name in {"cyl_bessel_j0", "cyl_bessel_j1"}: + if len(args) != 1: + self.error(node, "pk.cyl_bessel_j0/j1 accepts only one argument") + + s = cppast.Serializer() + arg_str = s.serialize(args[0]) + math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}, double, int>"), args) + real_number_call = cppast.MemberCallExpr(math_call, cppast.DeclRefExpr("real"), []) + + return real_number_call + return super().visit_Call(node) def is_nested_call(self, node: ast.FunctionDef) -> bool: diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index d4dd9ff2..701ae3e9 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -32,6 +32,9 @@ from .hierarchical import ( AUTO, TeamMember, PerTeam, PerThread, single ) +from .mathematical_special_functions import ( + cyl_bessel_j0, cyl_bessel_j1 +) from .memory_space import MemorySpace, get_default_memory_space from .parallel_dispatch import ( execute, flush, diff --git a/pykokkos/interface/mathematical_special_functions.py b/pykokkos/interface/mathematical_special_functions.py new file mode 100644 index 00000000..b67be1e8 --- /dev/null +++ b/pykokkos/interface/mathematical_special_functions.py @@ -0,0 +1,6 @@ + +def cyl_bessel_j0(input: float) -> float: + pass + +def cyl_bessel_j1(input: float) -> float: + pass \ No newline at end of file