diff --git a/src/Base/CMakeLists.txt b/src/Base/CMakeLists.txt index 2786e46d..1fa7fc51 100644 --- a/src/Base/CMakeLists.txt +++ b/src/Base/CMakeLists.txt @@ -13,6 +13,7 @@ foreach(D IN LISTS AMReX_SPACEDIM) DistributionMapping.cpp FArrayBox.cpp Geometry.cpp + IndexType.cpp IntVect.cpp RealVect.cpp MultiFab.cpp diff --git a/src/Base/IndexType.cpp b/src/Base/IndexType.cpp new file mode 100644 index 00000000..6a94cdc5 --- /dev/null +++ b/src/Base/IndexType.cpp @@ -0,0 +1,123 @@ +/* Copyright 2021-2022 The AMReX Community + * + * Authors: David Grote + * License: BSD-3-Clause-LBNL + */ +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace py = pybind11; +using namespace amrex; + +namespace { + int check_index(const int i) + { + const int ii = (i >= 0) ? i : AMREX_SPACEDIM + i; + if ((ii < 0) || (ii >= AMREX_SPACEDIM)) + throw py::index_error( "IndexType index " + std::to_string(i) + " out of bounds"); + return ii; + } +} + +void init_IndexType(py::module &m) { + py::class_< IndexType > index_type(m, "IndexType"); + index_type.def("__repr__", + [](py::object& obj) { + py::str py_name = obj.attr("__class__").attr("__name__"); + const std::string name = py_name; + const auto iv = obj.cast(); + std::stringstream s; + s << iv; + return ""; + } + ) + .def("__str", + [](const IndexType& iv) { + std::stringstream s; + s << iv; + return s.str(); + }) + + .def(py::init<>()) + .def(py::init()) +#if (AMREX_SPACEDIM > 1) + .def(py::init()) +#endif + + .def("__getitem__", + [](const IndexType& v, const int i) { + const int ii = check_index(i); + return v[ii]; + }) + + .def("__len__", [](IndexType const &) { return AMREX_SPACEDIM; }) + .def("__eq__", + py::overload_cast(&IndexType::operator==, py::const_)) + .def("__ne__", + py::overload_cast(&IndexType::operator!=, py::const_)) + .def("__lt__", &IndexType::operator<) + + .def("set", [](IndexType& v, int i) { + const int ii = check_index(i); + v.set(ii); + }) + .def("unset", [](IndexType& v, int i) { + const int ii = check_index(i); + v.unset(ii); + }) + .def("test", [](const IndexType& v, int i) { + const int ii = check_index(i); + return v.test(ii); + }) + .def("setall", &IndexType::setall) + .def("clear", &IndexType::clear) + .def("any", &IndexType::any) + .def("ok", &IndexType::ok) + .def("flip", [](IndexType& v, int i) { + const int ii = check_index(i); + v.flip(ii); + }) + + .def("cell_centered", py::overload_cast<>(&IndexType::cellCentered, py::const_)) + .def("cell_centered", [](const IndexType& v, int i) { + const int ii = check_index(i); + return v.cellCentered(ii); + }) + .def("node_centered", py::overload_cast<>(&IndexType::nodeCentered, py::const_)) + .def("node_centered", [](const IndexType& v, int i) { + const int ii = check_index(i); + return v.nodeCentered(ii); + }) + + .def("set_type", [](IndexType& v, int i, IndexType::CellIndex t) { + const int ii = check_index(i); + v.setType(ii, t); + }) + .def("ix_type", py::overload_cast<>(&IndexType::ixType, py::const_)) + .def("ix_type", [](const IndexType& v, int i) { + const int ii = check_index(i); + return v.ixType(ii); + }) + .def("to_IntVect", &IndexType::toIntVect) + + .def_static("cell_type", &IndexType::TheCellType) + .def_static("node_type", &IndexType::TheNodeType) + + ; + + py::enum_(index_type, "CellIndex") + .value("CELL", IndexType::CellIndex::CELL) + .value("NODE", IndexType::CellIndex::NODE) + .export_values(); + +} diff --git a/src/pyAMReX.cpp b/src/pyAMReX.cpp index eb232911..cb6aa542 100644 --- a/src/pyAMReX.cpp +++ b/src/pyAMReX.cpp @@ -28,6 +28,7 @@ void init_Dim3(py::module&); void init_DistributionMapping(py::module&); void init_FArrayBox(py::module&); void init_Geometry(py::module&); +void init_IndexType(py::module &); void init_IntVect(py::module &); void init_RealVect(py::module &); void init_AmrMesh(py::module &); @@ -70,6 +71,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) { Dim3 FArrayBox IntVect + IndexType RealVect MultiFab ParallelDescriptor @@ -88,6 +90,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) { init_Arena(m); init_Dim3(m); init_IntVect(m); + init_IndexType(m); init_RealVect(m); init_Periodicity(m); init_Array4(m); diff --git a/tests/test_indextype.py b/tests/test_indextype.py new file mode 100644 index 00000000..29ebc54a --- /dev/null +++ b/tests/test_indextype.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- + +import pytest + +import amrex.space3d as amr + + +@pytest.mark.skipif(amr.Config.spacedim != 1, reason="Requires AMREX_SPACEDIM = 1") +def test_indextype_1d(): + obj = amr.IndexType(amr.IndexType.CellIndex.NODE) + assert obj.node_centered() + assert not obj.cell_centered() + with pytest.raises(IndexError): + obj[-2] + with pytest.raises(IndexError): + obj[1] + + +@pytest.mark.skipif(amr.Config.spacedim != 2, reason="Requires AMREX_SPACEDIM = 2") +def test_indextype_2d(): + obj = amr.IndexType(amr.IndexType.CellIndex.NODE, amr.IndexType.CellIndex.CELL) + assert obj.node_centered(0) + assert obj.cell_centered(1) + assert obj.node_centered(-2) + assert obj.cell_centered(-1) + + with pytest.raises(IndexError): + obj[-3] + with pytest.raises(IndexError): + obj[2] + + +@pytest.mark.skipif(amr.Config.spacedim != 3, reason="Requires AMREX_SPACEDIM = 3") +def test_indextype_3d(): + obj = amr.IndexType( + amr.IndexType.CellIndex.NODE, + amr.IndexType.CellIndex.CELL, + amr.IndexType.CellIndex.NODE, + ) + + # Check indexing + assert obj.node_centered(0) + assert obj.cell_centered(1) + assert obj.node_centered(2) + assert obj.node_centered(-3) + assert obj.cell_centered(-2) + assert obj.node_centered(-1) + with pytest.raises(IndexError): + obj[-4] + with pytest.raises(IndexError): + obj[3] + + # Check methods + obj.set(1) + assert obj.node_centered() + obj.unset(1) + assert not obj.node_centered() + + +def test_indextype_static(): + cell = amr.IndexType.cell_type() + for i in range(amr.Config.spacedim): + assert not cell.test(i) + + node = amr.IndexType.node_type() + for i in range(amr.Config.spacedim): + assert node[i] + + assert cell == amr.IndexType.cell_type() + assert node == amr.IndexType.node_type() + assert cell < node + + +def test_indextype_conversions(): + node = amr.IndexType.node_type() + assert node.ix_type() == amr.IntVect(1) + assert node.to_IntVect() == amr.IntVect(1)