From ab81bafefc5d1641c597fd980e074eb5388fbc9d Mon Sep 17 00:00:00 2001 From: kennykos Date: Wed, 18 Dec 2024 11:27:43 -0600 Subject: [PATCH 1/3] Interface: add support for Kokkos Mathematical Special Functions Special functions defined in https://github.com/kokkos/kokkos/blob/9fa2a01747f0a30ca2723c8d7d0a22c95a05717a/core/src/Kokkos_MathematicalSpecialFunctions.hpp#L445 List of special functions: * expint1 * erfcx * cyl_bessel_j0 * cyl_bessel_y0 * cyl_bessel_i0 * cyl_bessel_k0 * cyl_bessel_j1 * cyl_bessel_y1 * cyl_bessel_i1 * cyl_bessel_k1 * cyl_bessel_h10 * cyl_bessel_h11 * cyl_bessel_h20 * cyl_bessel_h21 --- pykokkos/core/visitors/workunit_visitor.py | 17 ++++++++- pykokkos/interface/__init__.py | 5 ++- .../mathematical_special_functions.py | 37 ++++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 3d52a7b..77e8742 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -307,9 +307,22 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return rand_call - if name in {"cyl_bessel_j0", "cyl_bessel_j1"}: + if name in {"expint1", "erfcx"}: if len(args) != 1: - self.error(node, "pk.cyl_bessel_j0/j1 accepts only one argument") + self.error(node, f"pk.{name}() accepts only one argument") + + s = cppast.Serializer() + math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}"), args) + + return math_call + + if name in { + "cyl_bessel_j0", "cyl_bessel_y0", "cyl_bessel_i0", "cyl_bessel_k0", + "cyl_bessel_j1", "cyl_bessel_y1", "cyl_bessel_i1", "cyl_bessel_k1", + "cyl_bessel_h10", "cyl_bessel_h11", "cyl_bessel_h20", "cyl_bessel_h21" + }: + if len(args) != 1: + self.error(node, f"pk.{name}() accepts only one argument") s = cppast.Serializer() arg_str = s.serialize(args[0]) diff --git a/pykokkos/interface/__init__.py b/pykokkos/interface/__init__.py index bf6f921..903188e 100644 --- a/pykokkos/interface/__init__.py +++ b/pykokkos/interface/__init__.py @@ -33,7 +33,10 @@ AUTO, TeamMember, PerTeam, PerThread, single ) from .mathematical_special_functions import ( - cyl_bessel_j0, cyl_bessel_j1 + expint1, erfcx, + cyl_bessel_j0, cyl_bessel_y0, cyl_bessel_k0, cyl_bessel_i0, + cyl_bessel_j1, cyl_bessel_y1, cyl_bessel_k1, cyl_bessel_i1, + cyl_bessel_h10, cyl_bessel_h11, cyl_bessel_h20, cyl_bessel_h21 ) from .memory_space import MemorySpace, get_default_memory_space from .parallel_dispatch import ( diff --git a/pykokkos/interface/mathematical_special_functions.py b/pykokkos/interface/mathematical_special_functions.py index b67be1e..eb4f498 100644 --- a/pykokkos/interface/mathematical_special_functions.py +++ b/pykokkos/interface/mathematical_special_functions.py @@ -1,6 +1,41 @@ +def expint1(input: float) -> float: + pass + +def erfc(input: float) -> float: + pass def cyl_bessel_j0(input: float) -> float: pass +def cyl_bessel_y0(input: float) -> float: + pass + +def cyl_bessel_i0(input: float) -> float: + pass + +def cyl_bessel_k0(input: float) -> float: + pass + def cyl_bessel_j1(input: float) -> float: - pass \ No newline at end of file + pass + +def cyl_bessel_y1(input: float) -> float: + pass + +def cyl_bessel_i1(input: float) -> float: + pass + +def cyl_bessel_k1(input: float) -> float: + pass + +def cyl_bessel_h10(input: float) -> float: + pass + +def cyl_bessel_h11(input: float) -> float: + pass + +def cyl_bessel_h20(input: float) -> float: + pass + +def cyl_bessel_h21(input: float) -> float: + pass From 838d7d83d68feeb06cb75502cfcc61f2f63c2f96 Mon Sep 17 00:00:00 2001 From: kennykos Date: Wed, 18 Dec 2024 11:39:26 -0600 Subject: [PATCH 2/3] Bug: update call name from erfc -> erfcx --- pykokkos/interface/mathematical_special_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pykokkos/interface/mathematical_special_functions.py b/pykokkos/interface/mathematical_special_functions.py index eb4f498..096f493 100644 --- a/pykokkos/interface/mathematical_special_functions.py +++ b/pykokkos/interface/mathematical_special_functions.py @@ -1,7 +1,7 @@ def expint1(input: float) -> float: pass -def erfc(input: float) -> float: +def erfcx(input: float) -> float: pass def cyl_bessel_j0(input: float) -> float: From a92b14b16285b0da622777f602583b3606f7db70 Mon Sep 17 00:00:00 2001 From: kennykos Date: Tue, 7 Jan 2025 12:08:39 -0600 Subject: [PATCH 3/3] Interface: function visitor to handle math special functions --- pykokkos/core/visitors/pykokkos_visitor.py | 17 +++++++++++++++++ pykokkos/core/visitors/visitors_util.py | 19 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/pykokkos/core/visitors/pykokkos_visitor.py b/pykokkos/core/visitors/pykokkos_visitor.py index 8a8606b..f349b77 100644 --- a/pykokkos/core/visitors/pykokkos_visitor.py +++ b/pykokkos/core/visitors/pykokkos_visitor.py @@ -427,6 +427,23 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence"]: return cppast.CallExpr(function, args) + if visitors_util.is_math_special_function(name): + if name in ["expint1", "erfcx"]: + if len(args) != 1: + self.error(node, f"pk.{name}() accepts only one argument") + s = cppast.Serializer() + math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}"), args) + + return math_call + else: + if len(args) != 1: + self.error(node, f"pk.{name}() accepts only one argument") + s = cppast.Serializer() + arg_str = s.serialize(args[0]) + math_call = cppast.CallExpr(cppast.DeclRefExpr(f"Kokkos::Experimental::{name}, double, int>"), args) + real_number_call = cppast.MemberCallExpr(math_call, cppast.DeclRefExpr("real"), []) + return real_number_call + if function in self.kokkos_functions: if "PK_RESTRICT" in os.environ: return adjust_kokkos_function_call(function, args, self.restrict_views, self.views) diff --git a/pykokkos/core/visitors/visitors_util.py b/pykokkos/core/visitors/visitors_util.py index bd4c907..9fc10ed 100644 --- a/pykokkos/core/visitors/visitors_util.py +++ b/pykokkos/core/visitors/visitors_util.py @@ -116,6 +116,23 @@ def pretty_print(node): "nan", } +math_special_functions: Set = { + "expint1", + "erfcx", + "cyl_bessel_j0", + "cyl_bessel_y0", + "cyl_bessel_i0", + "cyl_bessel_k0", + "cyl_bessel_j1", + "cyl_bessel_y1", + "cyl_bessel_i1", + "cyl_bessel_k1", + "cyl_bessel_h10", + "cyl_bessel_h11", + "cyl_bessel_h20", + "cyl_bessel_h21", +} + math_constants: Dict[str, str] = { "e": "M_E", "pi": "M_PI", @@ -171,6 +188,8 @@ def get_allowed_type_str(python_type: str) -> str: def is_math_function(function: str) -> bool: return function in math_functions +def is_math_special_function(function: str) -> bool: + return function in math_special_functions def get_node_name(node: Union[ast.Attribute, ast.Name]) -> str: name: str = ""