Skip to content

Commit 238a289

Browse files
sclaus2, Claus Susanne, and mscroggs
authored
Add void* to tabulate_tensor kernel (#753)
* add void* to tabulate_tensor * try to trigger CI * try to trigger CI * run ruff * rename user_data custom_data * add numba functions to obtain empty void* and conversion of numpy array to void* * fix ruff check * add line to remove noqa * expand comment for numba intrinsic function * add test to use a struct in C-function similar to tabulate_tensor using void* * changes to custom data test for CI * specify void* branch for dolfinx test in github actions * trying to set dolfinx refs for ffcx testing for pull request * incorporate review suggestions * add void* argument to test_ds_prisms --------- Co-authored-by: Claus Susanne <[email protected]> Co-authored-by: Matthew Scroggs <[email protected]>
1 parent 27dce3b commit 238a289

11 files changed

+245
-8
lines changed

.github/workflows/dolfinx-tests.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010
inputs:
1111
dolfinx_ref:
1212
description: "DOLFINx branch or tag"
13-
default: "main"
13+
default: "sclaus2/add-void-to-kernels"
1414
type: string
1515
basix_ref:
1616
description: "Basix branch or tag"
@@ -54,7 +54,7 @@ jobs:
5454
with:
5555
path: ./dolfinx
5656
repository: FEniCS/dolfinx
57-
ref: main
57+
ref: sclaus2/add-void-to-kernels
5858
- name: Get DOLFINx source (specified branch/tag)
5959
if: github.event_name == 'workflow_dispatch'
6060
uses: actions/checkout@v4

ffcx/codegeneration/C/expressions_template.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
const {scalar_type}* restrict c,
2424
const {geom_type}* restrict coordinate_dofs,
2525
const int* restrict entity_local_index,
26-
const uint8_t* restrict quadrature_permutation)
26+
const uint8_t* restrict quadrature_permutation,
27+
void* custom_data)
2728
{{
2829
{tabulate_expression}
2930
}}

ffcx/codegeneration/C/integrals_template.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
const {scalar_type}* restrict c,
1717
const {geom_type}* restrict coordinate_dofs,
1818
const int* restrict entity_local_index,
19-
const uint8_t* restrict quadrature_permutation)
19+
const uint8_t* restrict quadrature_permutation,
20+
void* custom_data)
2021
{{
2122
{tabulate_tensor}
2223
}}

ffcx/codegeneration/ufcx.h

+11-4
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,15 @@ extern "C"
8686
/// For integrals not on interior facets, this argument has no effect and a
8787
/// null pointer can be passed. For interior facets the array will have size 2
8888
/// (one permutation for each cell adjacent to the facet).
89+
/// @param[in] custom_data Custom user data passed to the tabulate function.
90+
/// For example, a struct with additional data needed for the tabulate function.
91+
/// See the implementation of runtime integrals for further details.
8992
typedef void(ufcx_tabulate_tensor_float32)(
9093
float* restrict A, const float* restrict w, const float* restrict c,
9194
const float* restrict coordinate_dofs,
9295
const int* restrict entity_local_index,
93-
const uint8_t* restrict quadrature_permutation);
96+
const uint8_t* restrict quadrature_permutation,
97+
void* custom_data);
9498

9599
/// Tabulate integral into tensor A with compiled
96100
/// quadrature rule and double precision
@@ -100,7 +104,8 @@ extern "C"
100104
double* restrict A, const double* restrict w, const double* restrict c,
101105
const double* restrict coordinate_dofs,
102106
const int* restrict entity_local_index,
103-
const uint8_t* restrict quadrature_permutation);
107+
const uint8_t* restrict quadrature_permutation,
108+
void* custom_data);
104109

105110
#ifndef __STDC_NO_COMPLEX__
106111
/// Tabulate integral into tensor A with compiled
@@ -111,7 +116,8 @@ extern "C"
111116
float _Complex* restrict A, const float _Complex* restrict w,
112117
const float _Complex* restrict c, const float* restrict coordinate_dofs,
113118
const int* restrict entity_local_index,
114-
const uint8_t* restrict quadrature_permutation);
119+
const uint8_t* restrict quadrature_permutation,
120+
void* custom_data);
115121
#endif // __STDC_NO_COMPLEX__
116122

117123
#ifndef __STDC_NO_COMPLEX__
@@ -123,7 +129,8 @@ extern "C"
123129
double _Complex* restrict A, const double _Complex* restrict w,
124130
const double _Complex* restrict c, const double* restrict coordinate_dofs,
125131
const int* restrict entity_local_index,
126-
const uint8_t* restrict quadrature_permutation);
132+
const uint8_t* restrict quadrature_permutation,
133+
void* custom_data);
127134
#endif // __STDC_NO_COMPLEX__
128135

129136
typedef struct ufcx_integral

ffcx/codegeneration/utils.py

+78
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
import numpy as np
1111
import numpy.typing as npt
1212

13+
try:
14+
import numba
15+
except ImportError:
16+
numba = None
17+
1318

1419
def dtype_to_c_type(dtype: typing.Union[npt.DTypeLike, str]) -> str:
1520
"""For a NumPy dtype, return the corresponding C type.
@@ -80,6 +85,79 @@ def numba_ufcx_kernel_signature(dtype: npt.DTypeLike, xdtype: npt.DTypeLike):
8085
types.CPointer(from_dtype(xdtype)),
8186
types.CPointer(types.intc),
8287
types.CPointer(types.uint8),
88+
types.CPointer(types.void),
8389
)
8490
except ImportError as e:
8591
raise e
92+
93+
94+
if numba is not None:
95+
96+
@numba.extending.intrinsic
97+
def empty_void_pointer(typingctx):
98+
"""Custom intrinsic to return an empty void* pointer.
99+
100+
This function creates a void pointer initialized to null (0).
101+
This is used to pass a nullptr to the UFCx tabulate_tensor interface.
102+
103+
Args:
104+
typingctx: The typing context.
105+
106+
Returns:
107+
A Numba signature and a code generation function that returns a void pointer.
108+
"""
109+
110+
def codegen(context, builder, signature, args):
111+
null_ptr = context.get_constant(numba.types.voidptr, 0)
112+
return null_ptr
113+
114+
sig = numba.types.voidptr()
115+
return sig, codegen
116+
117+
@numba.extending.intrinsic
118+
def get_void_pointer(typingctx, arr):
119+
"""Custom intrinsic to get a void* pointer from a NumPy array.
120+
121+
This function takes a NumPy array and returns a void pointer to the array's data.
122+
This is used to pass custom data organised in a NumPy array
123+
to the UFCx tabulate_tensor interface.
124+
125+
Args:
126+
typingctx: The typing context.
127+
arr: The NumPy array to get the void pointer to the first element from.
128+
In a multi-dimensional NumPy array, the memory is laid out in a contiguous
129+
block of memory, see
130+
https://numpy.org/doc/stable/reference/arrays.ndarray.html#internal-memory-layout-of-an-ndarray
131+
132+
Returns:
133+
sig: A Numba signature, which specifies the numba type (here voidptr),
134+
codegen: A code generation function, which returns the LLVM IR to cast
135+
the raw data pointer to the first element of the contiguous block of memory
136+
of the NumPy array to void*.
137+
"""
138+
if not isinstance(arr, numba.types.Array):
139+
raise TypeError("Expected a NumPy array")
140+
141+
def codegen(context, builder, signature, args):
142+
"""Generate LLVM IR code to convert a NumPy array to a void* pointer.
143+
144+
This function generates the necessary LLVM IR instructions to:
145+
1. Allocate memory for the array on the stack.
146+
2. Cast the allocated memory to a void* pointer.
147+
148+
Args:
149+
context: The LLVM context.
150+
builder: The LLVM IR builder.
151+
signature: The function signature.
152+
args: The input arguments (NumPy array).
153+
154+
Returns:
155+
A void* pointer to the array's data.
156+
"""
157+
[arr] = args
158+
raw_ptr = numba.core.cgutils.alloca_once_value(builder, arr)
159+
void_ptr = builder.bitcast(raw_ptr, context.get_value_type(numba.types.voidptr))
160+
return void_ptr
161+
162+
sig = numba.types.voidptr(arr)
163+
return sig, codegen

test/test_add_mode.py

+3
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def test_additive_facet_integral(dtype, compile_args):
8686
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
8787
ffi.cast("int *", facets.ctypes.data),
8888
ffi.cast("uint8_t *", perm.ctypes.data),
89+
ffi.NULL,
8990
)
9091
assert np.isclose(A.sum(), np.sqrt(12) * (i + 1))
9192

@@ -158,6 +159,7 @@ def test_additive_cell_integral(dtype, compile_args):
158159
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
159160
ffi.NULL,
160161
ffi.NULL,
162+
ffi.NULL,
161163
)
162164

163165
A0 = np.array(A)
@@ -169,6 +171,7 @@ def test_additive_cell_integral(dtype, compile_args):
169171
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
170172
ffi.NULL,
171173
ffi.NULL,
174+
ffi.NULL,
172175
)
173176

174177
assert np.all(np.isclose(A, (i + 2) * A0))

test/test_custom_data.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
# Copyright (C) 2024 Susanne Claus
2+
#
3+
# This file is part of FFCx. (https://www.fenicsproject.org)
4+
#
5+
# SPDX-License-Identifier: LGPL-3.0-or-later
6+
7+
import numpy as np
8+
import pytest
9+
10+
11+
def test_tabulate_tensor_integral_add_values():
12+
pytest.importorskip("cffi")
13+
14+
from cffi import FFI
15+
16+
# Define custom tabulate tensor function in C with a struct
17+
# Step 1: Define the function in C and set up the CFFI builder
18+
ffibuilder = FFI()
19+
ffibuilder.set_source(
20+
"_cffi_kernelA",
21+
r"""
22+
typedef struct {
23+
size_t size;
24+
double* values;
25+
} cell_data;
26+
27+
void tabulate_tensor_integral_add_values(double* restrict A,
28+
const double* restrict w,
29+
const double* restrict c,
30+
const double* restrict coordinate_dofs,
31+
const int* restrict entity_local_index,
32+
const uint8_t* restrict quadrature_permutation,
33+
void* custom_data)
34+
{
35+
// Cast the void* custom_data to cell_data*
36+
cell_data* custom_data_ptr = (cell_data*)custom_data;
37+
38+
// Access the custom data
39+
size_t size = custom_data_ptr->size;
40+
double* values = custom_data_ptr->values;
41+
42+
// Use the values in your computations
43+
for (size_t i = 0; i < size; i++) {
44+
A[0] += values[i];
45+
}
46+
}
47+
""",
48+
)
49+
ffibuilder.cdef(
50+
"""
51+
typedef struct {
52+
size_t size;
53+
double* values;
54+
} cell_data;
55+
56+
void tabulate_tensor_integral_add_values(double* restrict A,
57+
const double* restrict w,
58+
const double* restrict c,
59+
const double* restrict coordinate_dofs,
60+
const int* restrict entity_local_index,
61+
const uint8_t* restrict quadrature_permutation,
62+
void* custom_data);
63+
"""
64+
)
65+
66+
# Step 2: Compile the C code
67+
ffibuilder.compile(verbose=True)
68+
69+
# Step 3: Import the compiled library
70+
from _cffi_kernelA import ffi, lib
71+
72+
# Define cell data
73+
values = np.array([2.0, 1.0], dtype=np.float64)
74+
size = len(values)
75+
expected_result = np.array([3.0], dtype=np.float64)
76+
77+
# Define the input arguments
78+
A = np.zeros(1, dtype=np.float64)
79+
w = np.array([1.0], dtype=np.float64)
80+
c = np.array([0.0], dtype=np.float64)
81+
coordinate_dofs = np.array(
82+
[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], dtype=np.float64
83+
)
84+
entity_local_index = np.array([0], dtype=np.int32)
85+
quadrature_permutation = np.array([0], dtype=np.uint8)
86+
87+
# Cast the arguments to the appropriate C types
88+
A_ptr = ffi.cast("double*", A.ctypes.data)
89+
w_ptr = ffi.cast("double*", w.ctypes.data)
90+
c_ptr = ffi.cast("double*", c.ctypes.data)
91+
coordinate_dofs_ptr = ffi.cast("double*", coordinate_dofs.ctypes.data)
92+
entity_local_index_ptr = ffi.cast("int*", entity_local_index.ctypes.data)
93+
quadrature_permutation_ptr = ffi.cast("uint8_t*", quadrature_permutation.ctypes.data)
94+
95+
# Use ffi.cast on the array's ctypes data address to create a CFFI pointer to the NumPy array
96+
values_ptr = ffi.cast("double*", values.ctypes.data)
97+
98+
# Allocate memory for the struct
99+
custom_data = ffi.new("cell_data*")
100+
custom_data.size = size
101+
custom_data.values = values_ptr
102+
103+
# Cast the struct to void*
104+
custom_data_ptr = ffi.cast("void*", custom_data)
105+
106+
# Call the function
107+
lib.tabulate_tensor_integral_add_values(
108+
A_ptr,
109+
w_ptr,
110+
c_ptr,
111+
coordinate_dofs_ptr,
112+
entity_local_index_ptr,
113+
quadrature_permutation_ptr,
114+
custom_data_ptr,
115+
)
116+
117+
# Assert the result
118+
np.testing.assert_allclose(A, expected_result, rtol=1e-5)

test/test_jit_expression.py

+7
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_matvec(compile_args):
6464
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
6565
ffi.cast("int *", entity_index.ctypes.data),
6666
ffi.cast("uint8_t *", quad_perm.ctypes.data),
67+
ffi.NULL,
6768
)
6869

6970
# Check the computation against correct NumPy value
@@ -133,6 +134,7 @@ def test_rank1(compile_args):
133134
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
134135
ffi.cast("int *", entity_index.ctypes.data),
135136
ffi.cast("uint8_t *", quad_perm.ctypes.data),
137+
ffi.NULL,
136138
)
137139

138140
f = np.array([[1.0, 2.0, 3.0], [-4.0, -5.0, 6.0]])
@@ -203,6 +205,7 @@ def test_elimiate_zero_tables_tensor(compile_args):
203205
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
204206
ffi.cast("int *", entity_index.ctypes.data),
205207
ffi.cast("uint8_t *", quad_perm.ctypes.data),
208+
ffi.NULL,
206209
)
207210

208211
def exact_expr(x):
@@ -261,6 +264,7 @@ def test_grad_constant(compile_args):
261264
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
262265
ffi.cast("int *", entity_index.ctypes.data),
263266
ffi.cast("uint8_t *", quad_perm.ctypes.data),
267+
ffi.NULL,
264268
)
265269

266270
assert output[0] == pytest.approx(consts[1] * 2 * points[0, 0])
@@ -316,6 +320,7 @@ def test_facet_expression(compile_args):
316320
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
317321
ffi.cast("int *", entity_index.ctypes.data),
318322
ffi.cast("uint8_t *", quad_perm.ctypes.data),
323+
ffi.NULL,
319324
)
320325
# Assert that facet normal is perpendicular to tangent
321326
assert np.isclose(np.dot(output, tangent), 0)
@@ -366,6 +371,7 @@ def check_expression(expression_class, output_shape, entity_values, reference_va
366371
ffi_data["coords"],
367372
ffi_data["entity_index"],
368373
ffi_data["quad_perm"],
374+
ffi.NULL,
369375
)
370376
np.testing.assert_allclose(output, ref_val)
371377

@@ -430,5 +436,6 @@ def test_facet_geometry_expressions_3D(compile_args):
430436
ffi.cast(f"{c_xtype} *", coords.ctypes.data),
431437
ffi.cast("int *", entity_index.ctypes.data),
432438
ffi.cast("uint8_t *", quad_perm.ctypes.data),
439+
ffi.NULL,
433440
)
434441
np.testing.assert_allclose(output, np.asarray(ref_fev)[:3, :])

0 commit comments

Comments
 (0)