From 38ddba51c963a28959bb86dd5ad8074ce8428a68 Mon Sep 17 00:00:00 2001 From: Chad Mitchell Date: Thu, 5 Dec 2024 16:46:41 -0800 Subject: [PATCH 1/2] Draft of SmallMatrix Python bindings. --- src/python/CMakeLists.txt | 1 + src/python/SmallMatrix.cpp | 72 ++++++++++++++++++++++++++++++++++++++ src/python/pyImpactX.H | 1 + 3 files changed, 74 insertions(+) create mode 100644 src/python/SmallMatrix.cpp diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index d74fe1cc5..3265b6eb5 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -11,4 +11,5 @@ target_sources(pyImpactX ReferenceParticle.cpp transformation.cpp WakeConvolution.cpp + SmallMatrix.cpp ) diff --git a/src/python/SmallMatrix.cpp b/src/python/SmallMatrix.cpp new file mode 100644 index 000000000..052533551 --- /dev/null +++ b/src/python/SmallMatrix.cpp @@ -0,0 +1,72 @@ +/* Copyright 2021-2023 The ImpactX Community + * + * Authors: Ryan Sandberg, Axel Huebl + * License: BSD-3-Clause-LBNL + */ +#include "pyImpactX.H" +#include + +namespace py = pybind11; + +namespace pybind11 { +namespace detail { + +template +struct pybind11::detail::type_caster> { +public: + PYBIND11_TYPE_CASTER(amrex::SmallMatrix, + _("SmallMatrix[") + py::detail::make_caster::name() + _("]")); + + // Conversion from Python to C++ + bool load(handle src, bool) { + // Ensure we have a numpy array + py::array_t arr = py::cast>(src); + py::buffer_info buf = arr.request(); + + // Check dimensions and shape + if (buf.ndim != 2) { + throw std::runtime_error("SmallMatrix requires a 2D array."); + } + if (buf.shape[0] != NRows || buf.shape[1] != NCols) { + throw std::runtime_error("SmallMatrix array shape must match NRows x NCols."); + } + + // Create a SmallMatrix and copy data + amrex::SmallMatrix mat; + T* ptr = static_cast(buf.ptr); + for (int i = 0; i < NRows * NCols; ++i) { + mat.m_mat[i] = ptr[i]; + } + + value = mat; + return true; + } + + // Conversion from C++ to Python + static handle cast(const amrex::SmallMatrix& src, + return_value_policy /* policy */, handle /* parent */) { + py::array_t arr({NRows, NCols}); + py::buffer_info buf = arr.request(); + T* ptr = static_cast(buf.ptr); + for (int i = 0; i < NRows * NCols; ++i) { + ptr[i] = src.m_mat[i]; + } + return arr.release(); + } +}; + +} // namespace detail +} // namespace pybind11 + + +PYBIND11_MODULE(example, m) { + // You can now just bind constructors and methods normally without defining conversion code: + py::class_>(m, "SmallMatrix6x6") + .def(py::init<>()) // Default init + .def("as_array", [](const amrex::SmallMatrix& mat) { + return mat; // Will use type_caster to return a numpy array + }); + + // Now Python functions expecting a SmallMatrix can pass a numpy array directly: + // def some_func(mat: SmallMatrix6x6): ... +} diff --git a/src/python/pyImpactX.H b/src/python/pyImpactX.H index 4401cf81a..748c846bc 100644 --- a/src/python/pyImpactX.H +++ b/src/python/pyImpactX.H @@ -13,6 +13,7 @@ #include #include #include +#include #include From f7ed9272f9ef1964bc3f1462fe1bae26e9d7279d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Dec 2024 00:50:42 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/python/SmallMatrix.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/python/SmallMatrix.cpp b/src/python/SmallMatrix.cpp index 052533551..cf7330225 100644 --- a/src/python/SmallMatrix.cpp +++ b/src/python/SmallMatrix.cpp @@ -66,7 +66,7 @@ PYBIND11_MODULE(example, m) { .def("as_array", [](const amrex::SmallMatrix& mat) { return mat; // Will use type_caster to return a numpy array }); - + // Now Python functions expecting a SmallMatrix can pass a numpy array directly: // def some_func(mat: SmallMatrix6x6): ... }