Skip to content

Commit

Permalink
add in-place mutation ops for NamedArrays, support lists/1-d JAX Arra…
Browse files Browse the repository at this point in the history
…ys for indexing (#85)
  • Loading branch information
dlwh authored Apr 14, 2024
1 parent 5599b78 commit 864a615
Show file tree
Hide file tree
Showing 6 changed files with 530 additions and 65 deletions.
5 changes: 3 additions & 2 deletions docs/fp8.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# FP8 Training

!!! warning
FP8 training in Haliax is currently experimental and may change in the future.

FP8 training in Haliax is currently experimental and may change in the future.

FP8 refers to 8-bit floating point numbers. FP8 is a massively reduced precision compared to the 32-bit floating point numbers
or 16-bit floating point numbers that are typically used in deep learning: there are only 256 possible values in FP8, compared to
Expand Down Expand Up @@ -123,7 +124,7 @@ you will get the gradient computation as normal, but you'll also get the updated
This updated state needs to directly replace the state in the module (rather than be used for a gradient step), which is
why you need to use the `partition_for_grad_overwrite`

The FP8 `dot_general` module is implemented in [haliax.quantization.Fp8DotGeneral][]. It's actually not that complicated:
The FP8 `dot_general` module is implemented in [haliax.quantization.Fp8DotGeneralOp][]. It's actually not that complicated:

1) It holds a scaling factor and history of maximum values for each of (lhs, rhs, output) and updates them based on the
gradients.
Expand Down
93 changes: 90 additions & 3 deletions docs/indexing.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,42 @@ Otherwise, the idea is pretty straightforward: any unspecified axes are treated
slices are kept in reduced dimensions, and integers eliminate dimensions. If all dimensions are eliminated, a scalar
JAX ndarray is returned.

The following types are supported for indexing:

* Integers, including scalar JAX arrays
* Slices
* [haliax.dslice][] objects (See [Dynamic Slices](#dynamic-slices) below.)
* Lists of integers
* Named arrays (See [Advanced Indexing](#advanced-indexing) below.)
* 1-D JAX Arrays of integers

1-D JAX Arrays are interpreted as NamedArrays with a single axis with the same name as
the one they are slicing. That is:

```python
import haliax as hax
import jax
import jax.numpy as jnp

X = hax.Axis("X", 10)
Y = hax.Axis("Y", 20)

a = hax.random.uniform(jax.random.PRNGKey(0), (X, Y))

sliced = a["X", jnp.array([1, 2, 3])]

# same as
a.array[jnp.array([1, 2, 3]), :]
```

Note that boolean arrays are not supported, as JAX does not support them in JIT-compiled code. You
can use [haliax.where][] for most of the same functionality, though.

### Shapes in JAX

Before we continue note on shapes in JAX. Most JAX code will be used inside `jit`, which means that the sizes of all arrays
must be determined at compile time (i.e. when JAX interprets your functions abstractly). This is a hard requirement in
XLA.
Before we continue, a note on shapes in JAX. Most JAX code will be used inside `jit`, which means that the sizes of all
arrays must be determined at compile time (i.e. when JAX interprets your functions abstractly). This is a hard
requirement in XLA. It might worked around one day, but it's the way things are for now.

A consequence of this restriction is that certain indexing patterns aren't allowed in `jit`-ed JAX code:

Expand Down Expand Up @@ -177,3 +207,60 @@ a[{"Y": ind1}] # error, "X" is not eliminated by the indexing operation

a[{"X": ind2, "Y": ind1}] # ok, because X and Y are eliminated by the indexing operation
```

## Index Update

JAX is a functional version of NumPy, so it doesn't directly support in-place updates. It does
however [provide an `at` syntax](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at)
to express the same logic (and that will typically be optimized to be as efficient as an in-place update). Haliax
provides a similar syntax for updating arrays.

```python
import haliax as hax

X = hax.Axis("X", 10)
Y = hax.Axis("Y", 20)
Z = hax.Axis("Z", 30)

a = hax.zeros((X, Y, Z))

a.at[{"X": 1, "Y": 2, "Z": 3}].set(1.0) # sets a[1, 2, 3] to 1.0
a.at["X", 1].set(2.0) # sets a[1, :, :] to 2.0

a.at[{"X": 1, "Y": hax.ds(3, 5), "Z": 3}].add(1.0) # adds 1.0 to a[1, 3:8, 3]
```

Haliax supports the same `at` functionality as JAX, just with named arrays and additionally dslices. A summary of the
`at` syntax is as follows:

| Alternate Syntax | Equivalent In-Place Operation |
|------------------------------|-------------------------------|
| `x = x.at[idx].set(y)` | `x[idx] = y` |
| `x = x.at[idx].add(y)` | `x[idx] += y` |
| `x = x.at[idx].multiply(y)` | `x[idx] *= y` |
| `x = x.at[idx].divide(y)` | `x[idx] /= y` |
| `x = x.at[idx].power(y)` | `x[idx] **= y` |
| `x = x.at[idx].min(y)` | `x[idx] = minimum(x[idx], y)` |
| `x = x.at[idx].max(y)` | `x[idx] = maximum(x[idx], y)` |
| `x = x.at[idx].apply(ufunc)` | `ufunc.at(x, idx)` |
| `x = x.at[idx].get()` | `x = x[idx]` |

These methods also have options to control out-of-bounds behavior, as well as allowing you
to specify that the indices are sorted or unique. (If they are, XLA can sometimes optimize the
operation more effectively.)

!!! note

These named arguments are not passed to `at`, but to the next method in the chain.

(This is copied from the JAX documentation:)

* `mode`: One of `"promise_in_bounds"`, `"clip"`, `"drop"`, or `"fill"`. See [JAX's documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.GatherScatterMode.html#jax.lax.GatherScatterMode) for more details.
* `indices_are_sorted`: If `True`, the implementation will assume that the indices passed to `at` are sorted in ascending order, which can lead to more efficient execution on some backends.
* `unique_indices`: If `True`, the implementation will assume that the indices passed to `at` are unique, which can result in more efficient execution on some backends.
* `fill_value`: Only applies to the `get()` method: the fill value to return for out-of-bounds slices when mode is 'fill'. Ignored otherwise. Defaults to NaN for inexact types, the largest negative value for signed types, the largest positive value for unsigned types, and True for booleans.

!!! tip

It's worth emphasizing that these functions are typically compiled to scatter-add and friends (as appropriate).
This is the preferred way to do scatter/gather operations in JAX, as well as in Haliax.
1 change: 1 addition & 0 deletions docs/rearrange.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ If you're used to einops, the syntax should be familiar, with the main differenc
and the additional "unordered" syntax for selecting dimensions by name.

!!! warning

This syntax is fairly new. It is pretty well-tested, but it is possible that there are bugs.

### Examples
Expand Down
Loading

0 comments on commit 864a615

Please sign in to comment.