Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add IndexType #139

Merged
merged 11 commits into from
Jul 5, 2023
1 change: 1 addition & 0 deletions src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ foreach(D IN LISTS AMReX_SPACEDIM)
DistributionMapping.cpp
FArrayBox.cpp
Geometry.cpp
IndexType.cpp
IntVect.cpp
RealVect.cpp
MultiFab.cpp
Expand Down
123 changes: 123 additions & 0 deletions src/Base/IndexType.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/* Copyright 2021-2022 The AMReX Community
*
* Authors: David Grote
* License: BSD-3-Clause-LBNL
*/
#include <AMReX_Config.H>
#include <AMReX_Dim3.H>
#include <AMReX_IntVect.H>
#include <AMReX_IndexType.H>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <array>
#include <sstream>
#include <string>

namespace py = pybind11;
using namespace amrex;

namespace {
int check_index(const int i)
{
const int ii = (i >= 0) ? i : AMREX_SPACEDIM + i;
if ((ii < 0) || (ii >= AMREX_SPACEDIM))
throw py::index_error( "IndexType index " + std::to_string(i) + " out of bounds");
return ii;
}
}

void init_IndexType(py::module &m) {
py::class_< IndexType > index_type(m, "IndexType");
index_type.def("__repr__",
[](py::object& obj) {
py::str py_name = obj.attr("__class__").attr("__name__");
const std::string name = py_name;
const auto iv = obj.cast<IndexType>();
std::stringstream s;
s << iv;
return "<amrex." + name + " " + s.str() + ">";
}
)
.def("__str",
[](const IndexType& iv) {
std::stringstream s;
s << iv;
return s.str();
})

.def(py::init<>())
.def(py::init<IndexType>())
#if (AMREX_SPACEDIM > 1)
.def(py::init<AMREX_D_DECL(IndexType::CellIndex, IndexType::CellIndex, IndexType::CellIndex)>())
#endif

.def("__getitem__",
[](const IndexType& v, const int i) {
const int ii = check_index(i);
return v[ii];
})

.def("__len__", [](IndexType const &) { return AMREX_SPACEDIM; })
.def("__eq__",
py::overload_cast<const IndexType&>(&IndexType::operator==, py::const_))
.def("__ne__",
py::overload_cast<const IndexType&>(&IndexType::operator!=, py::const_))
.def("__lt__", &IndexType::operator<)

.def("set", [](IndexType& v, int i) {
const int ii = check_index(i);
v.set(ii);
})
.def("unset", [](IndexType& v, int i) {
const int ii = check_index(i);
v.unset(ii);
})
.def("test", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.test(ii);
})
.def("setall", &IndexType::setall)
.def("clear", &IndexType::clear)
.def("any", &IndexType::any)
.def("ok", &IndexType::ok)
.def("flip", [](IndexType& v, int i) {
const int ii = check_index(i);
v.flip(ii);
})

.def("cell_centered", py::overload_cast<>(&IndexType::cellCentered, py::const_))
.def("cell_centered", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.cellCentered(ii);
})
.def("node_centered", py::overload_cast<>(&IndexType::nodeCentered, py::const_))
.def("node_centered", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.nodeCentered(ii);
})

.def("set_type", [](IndexType& v, int i, IndexType::CellIndex t) {
const int ii = check_index(i);
v.setType(ii, t);
})
.def("ix_type", py::overload_cast<>(&IndexType::ixType, py::const_))
.def("ix_type", [](const IndexType& v, int i) {
const int ii = check_index(i);
return v.ixType(ii);
})
.def("to_IntVect", &IndexType::toIntVect)

.def_static("cell_type", &IndexType::TheCellType)
.def_static("node_type", &IndexType::TheNodeType)

;

py::enum_<IndexType::CellIndex>(index_type, "CellIndex")
.value("CELL", IndexType::CellIndex::CELL)
.value("NODE", IndexType::CellIndex::NODE)
.export_values();

}
3 changes: 3 additions & 0 deletions src/pyAMReX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void init_Dim3(py::module&);
void init_DistributionMapping(py::module&);
void init_FArrayBox(py::module&);
void init_Geometry(py::module&);
void init_IndexType(py::module &);
void init_IntVect(py::module &);
void init_RealVect(py::module &);
void init_AmrMesh(py::module &);
Expand Down Expand Up @@ -70,6 +71,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
Dim3
FArrayBox
IntVect
IndexType
RealVect
MultiFab
ParallelDescriptor
Expand All @@ -88,6 +90,7 @@ PYBIND11_MODULE(amrex_3d_pybind, m) {
init_Arena(m);
init_Dim3(m);
init_IntVect(m);
init_IndexType(m);
init_RealVect(m);
init_Periodicity(m);
init_Array4(m);
Expand Down
77 changes: 77 additions & 0 deletions tests/test_indextype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-

import pytest

import amrex.space3d as amr


@pytest.mark.skipif(amr.Config.spacedim != 1, reason="Requires AMREX_SPACEDIM = 1")
def test_indextype_1d():
obj = amr.IndexType(amr.IndexType.CellIndex.NODE)
assert obj.node_centered()
assert not obj.cell_centered()
with pytest.raises(IndexError):
obj[-2]
Dismissed Show dismissed Hide dismissed
with pytest.raises(IndexError):
obj[1]
Dismissed Show dismissed Hide dismissed


@pytest.mark.skipif(amr.Config.spacedim != 2, reason="Requires AMREX_SPACEDIM = 2")
def test_indextype_2d():
obj = amr.IndexType(amr.IndexType.CellIndex.NODE, amr.IndexType.CellIndex.CELL)
assert obj.node_centered(0)
assert obj.cell_centered(1)
assert obj.node_centered(-2)
assert obj.cell_centered(-1)

with pytest.raises(IndexError):
obj[-3]
Dismissed Show dismissed Hide dismissed
with pytest.raises(IndexError):
obj[2]
Dismissed Show dismissed Hide dismissed


@pytest.mark.skipif(amr.Config.spacedim != 3, reason="Requires AMREX_SPACEDIM = 3")
def test_indextype_3d():
obj = amr.IndexType(
amr.IndexType.CellIndex.NODE,
amr.IndexType.CellIndex.CELL,
amr.IndexType.CellIndex.NODE,
)

# Check indexing
assert obj.node_centered(0)
assert obj.cell_centered(1)
assert obj.node_centered(2)
assert obj.node_centered(-3)
assert obj.cell_centered(-2)
assert obj.node_centered(-1)
with pytest.raises(IndexError):
obj[-4]
Dismissed Show dismissed Hide dismissed
with pytest.raises(IndexError):
obj[3]
Dismissed Show dismissed Hide dismissed

# Check methods
obj.set(1)
assert obj.node_centered()
obj.unset(1)
assert not obj.node_centered()


def test_indextype_static():
cell = amr.IndexType.cell_type()
for i in range(amr.Config.spacedim):
assert not cell.test(i)

node = amr.IndexType.node_type()
for i in range(amr.Config.spacedim):
assert node[i]

assert cell == amr.IndexType.cell_type()
assert node == amr.IndexType.node_type()
assert cell < node


def test_indextype_conversions():
node = amr.IndexType.node_type()
assert node.ix_type() == amr.IntVect(1)
assert node.to_IntVect() == amr.IntVect(1)