Added new drjit.scatter_inc() operation for stream compaction
This commit adds a new and relatively advanced Dr.Jit operation named
``drjit.scatter_inc()`` that atomically increments a value within a
``uint32``-typed Dr.Jit array.

It works just like the standard ``drjit.scatter_reduce()`` operation for
32-bit unsigned integer operands, but with a fixed ``value=1`` argument
and ``op=ReduceOp.Add``.

The main difference is that this variant additionally returns the *old*
value of the target array prior to the atomic update, in contrast to the
more general scatter-reduction, which returns ``None``. The operation
also supports masking---the return value of masked-out entries is
undefined.
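
To make these semantics concrete, here is a small sketch (the counter
values after evaluation are deterministic, but the order in which entries
targeting the same slot receive their return values is not):

```python
ctr = UInt32(0, 0)     # Two counter slots
idx = UInt32(0, 0, 1)  # Two increments of slot 0, one of slot 1

old = dr.scatter_inc(target=ctr, index=idx)
dr.eval(ctr, old)

# Now ctr == [2, 1], and `old` holds the pre-update values,
# e.g. [0, 1, 0] (the two entries hitting slot 0 may also be swapped)
```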

This operation is a building block for stream compaction: threads can
scatter-increment a global counter to request a spot in an array and
then write their result there. The recipe looks as follows:

```python
ctr = UInt32(0)                       # Counter array
active = dr.ones(Bool, len(data_1))   # .. or a more complex condition

my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    active=active
)

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,
    active=active
)
```
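
The snippet assumes that the compacted output arrays were allocated
beforehand with sufficient capacity, along the lines of the following
sketch (assuming ``data_1`` and ``data_2`` are 1D Dr.Jit arrays):

```python
data_compact_1 = dr.empty(type(data_1), len(data_1))
data_compact_2 = dr.empty(type(data_2), len(data_2))
```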

When following this approach, be sure to provide the same mask value to
the ``dr.scatter_inc()`` and subsequent ``dr.scatter()`` operations.
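
For example, the following hypothetical variant is broken: entries masked
out in ``dr.scatter_inc()`` receive an *undefined* index, which the
unmasked ``dr.scatter()`` then happily uses for a write.

```python
my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index  # Bug: where active == False, my_index is undefined,
)                   # so this unmasked write may land anywhere
```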

``dr.scatter_inc()`` exhibits the following unusual behavior compared to
normal Dr.Jit operations: the return value references the instantaneous
state during a potentially large sequence of atomic operations. This
instantaneous state is not reproducible in later kernel evaluations, and
Dr.Jit will therefore raise an exception when the computed index is
reused in another kernel:

```python
my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    active=active
)

dr.eval(data_compact_1) # Run Kernel #1

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,  # <-- oops, reusing my_index in another kernel.
    active=active    #     This raises an exception.
)
```

To get the above code to work, you will need to evaluate ``my_index`` at
the same time to materialize it into a stored (and therefore trivially
reproducible) representation. For this, ensure that the size of the
``active`` mask matches ``len(data_*)`` and that it is not the trivial
``True`` default mask (otherwise, the evaluated ``my_index`` will be
scalar).

```python
dr.eval(data_compact_1, my_index)
```

Such multi-stage evaluation is potentially inefficient and may defeat
the purpose of performing stream compaction in the first place. In
general, prefer keeping all scatter operations involving the computed
index in the same kernel, and then this issue does not arise.
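
Putting the pieces together, a complete single-kernel compaction might
look as follows (a sketch under the same assumptions as above; the number
of valid entries is recovered from the counter after evaluation):

```python
ctr = UInt32(0)
data_compact_1 = dr.empty(type(data_1), len(data_1))

my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
dr.scatter(target=data_compact_1, value=data_1,
           index=my_index, active=active)

dr.eval(data_compact_1, ctr)  # Everything above runs in one kernel
count = ctr[0]                # Number of compacted entries in data_compact_1
```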
wjakob committed Oct 20, 2023
1 parent dece978 commit 754a541
Showing 8 changed files with 158 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/reference.rst
@@ -26,6 +26,7 @@ These operations are *horizontal* in the sense that [..]
.. autofunction:: gather
.. autofunction:: scatter
.. autofunction:: scatter_reduce
.. autofunction:: scatter_inc
.. autofunction:: ravel
.. autofunction:: unravel
.. autofunction:: slice
102 changes: 101 additions & 1 deletion drjit/router.py
@@ -973,6 +973,106 @@ def scatter_reduce(op, target, value, index, active=True):

    return value.scatter_reduce_(op, target, index, active)

def scatter_inc(target, index, active=True):
    '''
    Atomically increment a value within an unsigned 32-bit integer array
    and return the value prior to the update.

    This operation works just like the :py:func:`drjit.scatter_reduce()`
    operation for 32-bit unsigned integer operands, but with a fixed
    ``value=1`` argument and ``op=ReduceOp.Add``.

    The main difference is that this variant additionally returns the *old*
    value of the target array prior to the atomic update, in contrast to the
    more general scatter-reduction, which returns ``None``. The operation
    also supports masking---the return value of masked-out entries is
    undefined.

    Both ``target`` and ``index`` parameters must be 1D unsigned 32-bit
    arrays.

    This operation is a building block for stream compaction: threads can
    scatter-increment a global counter to request a spot in an array and then
    write their result there. The recipe looks as follows:

    .. code-block:: python

       ctr = UInt32(0)                       # Counter
       active = dr.ones(Bool, len(data_1))   # .. or a more complex condition

       my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)

       dr.scatter(
           target=data_compact_1,
           value=data_1,
           index=my_index,
           active=active
       )

       dr.scatter(
           target=data_compact_2,
           value=data_2,
           index=my_index,
           active=active
       )

    When following this approach, be sure to provide the same mask value to
    the :py:func:`drjit.scatter_inc()` and subsequent
    :py:func:`drjit.scatter()` operations.

    The function :py:func:`drjit.scatter_inc()` exhibits the following unusual
    behavior compared to regular Dr.Jit operations: the return value references
    the instantaneous state during a potentially large sequence of atomic
    operations. This instantaneous state is not reproducible in later kernel
    evaluations, and Dr.Jit will therefore raise an exception when the
    computed index is reused in another kernel. In essence, the variable is
    "consumed" by the process of evaluation.

    .. code-block:: python

       my_index = dr.scatter_inc(target=ctr, index=UInt32(0), active=active)
       dr.scatter(
           target=data_compact_1,
           value=data_1,
           index=my_index,
           active=active
       )

       dr.eval(data_compact_1) # Run Kernel #1

       dr.scatter(
           target=data_compact_2,
           value=data_2,
           index=my_index,  # <-- oops, reusing my_index in another kernel.
           active=active    #     This raises an exception.
       )

    To get the above code to work, you will need to evaluate ``my_index`` at
    the same time to materialize it into a stored (and therefore trivially
    reproducible) representation. For this, ensure that the size of the
    ``active`` mask matches ``len(data_*)`` and that it is not the trivial
    ``True`` default mask (otherwise, the evaluated ``my_index`` will be
    scalar).

    .. code-block:: python

       dr.eval(data_compact_1, my_index)

    Such multi-stage evaluation is potentially inefficient and may defeat the
    purpose of performing stream compaction in the first place. In general,
    prefer keeping all scatter operations involving the computed index in the
    same kernel, and then this issue does not arise.

    The implementation of :py:func:`drjit.scatter_inc()` performs a local
    reduction first, followed by a single atomic write per SIMD packet/warp.
    This is done to reduce contention from a potentially very large number of
    atomic operations targeting the same memory address. Fully masked updates
    do not cause memory traffic.
    '''

    if not _dr.is_jit_v(target) or target.Type != _dr.VarType.UInt32 or \
       type(index) is not type(target):
        raise Exception('scatter_inc(): invalid input types!')

    return type(target).scatter_inc_(target, index, active)
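
The "local reduction" mentioned in the final docstring paragraph can be
illustrated with a small Python model (purely conceptual; the actual
implementation happens in the JIT backend behind ``jit_var_scatter_inc()``):

```python
from collections import defaultdict

def scatter_inc_model(target, index, active):
    # Model of one SIMD packet/warp: group entries by destination, issue
    # a single fetch-and-add of the group size per destination, and hand
    # out consecutive slots within each group.
    result = [None] * len(index)
    groups = defaultdict(list)
    for lane, (i, m) in enumerate(zip(index, active)):
        if m:  # fully masked lanes cause no memory traffic
            groups[i].append(lane)
    for i, lanes in groups.items():
        base = target[i]          # one "atomic" fetch-add per group ...
        target[i] += len(lanes)
        for rank, lane in enumerate(lanes):
            result[lane] = base + rank  # ... plus a local prefix offset
    return result
```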


def ravel(array, order='A'):
    '''
@@ -3800,7 +3900,7 @@ def hypot(a, b):


def prefix_sum(value, exclusive=True):
-    '''
+    r'''
    Compute an exclusive or inclusive prefix sum of the 1D input array.
    By default, the function returns an output array :math:`\mathbf{y}` of the
17 changes: 14 additions & 3 deletions include/drjit/array_router.h
@@ -1173,20 +1173,31 @@ void scatter_reduce(ReduceOp op, Target &&target, const Value &value,
}

template <typename Target, typename Value, typename Index>
-void scatter_reduce_kahan(Target &&target_1, Target &&target_2,
+void scatter_reduce_kahan(Target &target_1, Target &target_2,
                          const Value &value, const Index &index,
                          const mask_t<Value> &mask = true) {
    static_assert(
        is_jit_v<Target> &&
        is_jit_v<Value> &&
        is_jit_v<Index> &&
-       array_depth_v<Value> == array_depth_v<Index> &&
-       array_depth_v<Value> == 1,
+       array_depth_v<Target> == 1 &&
+       array_depth_v<Value> == 1 &&
+       array_depth_v<Index> == 1,
        "Only flat JIT arrays are supported at the moment");

    value.scatter_reduce_kahan_(target_1, target_2, index, mask);
}

template <typename Index>
Index scatter_inc(Index &target,
                  const Index &index,
                  const mask_t<Index> &mask = true) {
    static_assert(is_jit_v<Index> && array_depth_v<Index> == 1,
                  "Only flat JIT arrays are supported at the moment");

    return Index::scatter_inc_(target, index, mask);
}

template <typename T, typename TargetType>
decltype(auto) migrate(const T &value, TargetType target) {
    static_assert(std::is_enum_v<TargetType>);
12 changes: 11 additions & 1 deletion include/drjit/autodiff.h
@@ -1413,7 +1413,7 @@ struct DiffArray : ArrayBase<value_t<Type_>, is_mask_v<Type_>, DiffArray<Type_>>
                              const MaskType &mask = true) const {
        if constexpr (std::is_scalar_v<Type>) {
            (void) dst_1; (void) dst_2; (void) offset; (void) mask;
-           drjit_raise("Array scatter_reduce operation not supported for scalar array type.");
+           drjit_raise("Array scatter_reduce_kahan operation not supported for scalar array type.");
        } else {
            scatter_reduce_kahan(dst_1.m_value, dst_2.m_value, m_value,
                                 offset.m_value, mask.m_value);
@@ -1430,6 +1430,16 @@ struct DiffArray : ArrayBase<value_t<Type_>, is_mask_v<Type_>, DiffArray<Type_>>
        }
    }

    static DiffArray scatter_inc_(DiffArray &dst, const DiffArray &offset,
                                  const MaskType &mask) {
        if constexpr (std::is_scalar_v<Type>) {
            (void) dst; (void) offset; (void) mask;
            drjit_raise("Array scatter_inc operation not supported for scalar array type.");
        } else {
            return Type::scatter_inc_(dst.m_value, offset.m_value, mask.m_value);
        }
    }

    template <bool>
    static DiffArray gather_(const void *src, const IndexType &offset,
                             const MaskType &mask = true) {
8 changes: 8 additions & 0 deletions include/drjit/jit.h
@@ -510,6 +510,14 @@ struct JitArray : ArrayBase<Value_, is_mask_v<Value_>, Derived_> {
                             m_index, index.index(), mask.index());
    }

    template <typename Mask>
    static Derived scatter_inc_(Derived &dst, const Derived &index,
                                const Mask &mask) {
        static_assert(
            std::is_same_v<detached_t<Mask>, detached_t<mask_t<Derived>>>);
        return steal(jit_var_scatter_inc(dst.index_ptr(), index.index(),
                                         mask.index()));
    }

    //! @}
    // -----------------------------------------------------------------------

6 changes: 6 additions & 0 deletions src/python/bind.h
@@ -346,6 +346,12 @@ auto bind_full(py::class_<Array> &cls, bool /* scalar_mode */ = false) {
                dr::scatter_reduce(op, target, value, index, mask);
            }, "op"_a, "target"_a.noconvert(), "index"_a, "mask"_a);
        }

        if constexpr (std::is_same_v<Scalar, uint32_t> && dr::is_jit_v<Array>) {
            cls.def_static("scatter_inc_",
                [](Array &target, const Array &index, const Mask &mask) {
                    return dr::scatter_inc(target, index, mask);
                });
        }
    }

    if constexpr (dr::is_jit_v<Array>) {
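On the Python side, this binding means that ``scatter_inc_`` (and hence
``dr.scatter_inc()``) is only exposed on unsigned 32-bit JIT array types,
matching the type check in ``drjit/router.py``. A hypothetical misuse,
for illustration:

```python
import drjit as dr
from drjit.cuda import UInt32, Float  # assuming a CUDA-capable build

dr.scatter_inc(UInt32(0), UInt32(0))  # OK: uint32 target and index
dr.scatter_inc(Float(0), UInt32(0))   # raises 'scatter_inc(): invalid input types!'
```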
16 changes: 16 additions & 0 deletions tests/python/test_reduce.py
@@ -232,3 +232,19 @@ def test08_prefix_sum(m):
    assert dr.prefix_sum(t(1, 2, 3)) == t(0, 1, 3)
    assert dr.prefix_sum(t(1, 2, 3), exclusive=False) == t(1, 3, 6)
    assert dr.cumsum(t(1, 2, 3)) == t(1, 3, 6)

def test09_scatter_inc(m):
    try:
        import numpy as np
    except ImportError:
        pytest.skip('NumPy is not installed')

    n = 10000
    counter = m.UInt32(0)
    index = dr.arange(m.UInt32, n)
    offset = dr.scatter_inc(counter, m.UInt32(0))

    out = dr.zeros(m.UInt32, n)
    dr.scatter(out, offset, index)
    out_np = np.array(out)
    # The atomics may execute in any order, so the returned offsets form
    # some permutation of 0..n-1; sorting makes the check order-independent.
    out_np.sort()
    assert np.all(out_np == np.arange(n))
