Skip to content

Commit 0ee47c9

Browse files
Now using strict shape/dtype promotion rules.
This means that:

1. Tests now pass using `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and these are enabled in tests by default.

2. The values passed to `diffeqsolve` now more carefully determine the dtype used in the integration. (Previously things were mostly left to behave in an ad-hoc fashion: the dtype was whatever the various interacting arrays happened to promote to.)

   a. The dtype of timelike values is the `jnp.result_type` of `t0`, `t1`, `dt0`, and `SaveAt(ts=...)`. If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype.

   b. The dtype of each leaf of `y0` is the `jnp.result_type` of the time dtype and that leaf.

3. Of course, `diffeqsolve` accepts user-specified functions (e.g. the vector field of an `ODETerm`), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enable `JAX_NUMPY_DTYPE_PROMOTION=strict` and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.)

So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!
1 parent 7965e89 commit 0ee47c9

13 files changed

+184
-71
lines changed

diffrax/_brownian/path.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
import jax.numpy as jnp
88
import jax.random as jr
99
import jax.tree_util as jtu
10+
import lineax.internal as lxi
1011
from jaxtyping import Array, PRNGKeyArray, PyTree
1112

1213
from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
1314
from .._misc import (
14-
default_floating_dtype,
1515
force_bitcast_convert_type,
1616
is_tuple_of_ints,
1717
split_by_tree,
@@ -52,7 +52,7 @@ def __init__(
5252
levy_area: LevyArea = "",
5353
):
5454
self.shape = (
55-
jax.ShapeDtypeStruct(shape, default_floating_dtype())
55+
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
5656
if is_tuple_of_ints(shape)
5757
else shape
5858
)
@@ -87,8 +87,14 @@ def evaluate(
8787
) -> Union[PyTree[Array], LevyVal]:
8888
del left
8989
if t1 is None:
90+
dtype = jnp.result_type(t0)
9091
t1 = t0
91-
t0 = 0
92+
t0 = jnp.array(0, dtype)
93+
else:
94+
with jax.numpy_dtype_promotion("standard"):
95+
dtype = jnp.result_type(t0, t1)
96+
t0 = jnp.astype(t0, dtype)
97+
t1 = jnp.astype(t1, dtype)
9298
t0 = eqxi.nondifferentiable(t0, name="t0")
9399
t1 = eqxi.nondifferentiable(t1, name="t1")
94100
t1 = cast(RealScalarLike, t1)

diffrax/_brownian/tree.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import jax.numpy as jnp
1010
import jax.random as jr
1111
import jax.tree_util as jtu
12+
import lineax.internal as lxi
1213
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
1314

1415
from .._custom_types import (
@@ -20,7 +21,6 @@
2021
RealScalarLike,
2122
)
2223
from .._misc import (
23-
default_floating_dtype,
2424
is_tuple_of_ints,
2525
linear_rescale,
2626
split_by_tree,
@@ -179,7 +179,7 @@ def __init__(
179179
self.levy_area = levy_area
180180
self._spline = _spline
181181
self.shape = (
182-
jax.ShapeDtypeStruct(shape, default_floating_dtype())
182+
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
183183
if is_tuple_of_ints(shape)
184184
else shape
185185
)

diffrax/_integrate.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import jax.core
1111
import jax.numpy as jnp
1212
import jax.tree_util as jtu
13+
import lineax.internal as lxi
1314
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real
1415

1516
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
@@ -299,10 +300,10 @@ def body_fun_aux(state):
299300

300301
# Count the number of steps, just for statistical purposes.
301302
num_steps = state.num_steps + 1
302-
num_accepted_steps = state.num_accepted_steps + keep_step
303+
num_accepted_steps = state.num_accepted_steps + jnp.where(keep_step, 1, 0)
303304
# Not just ~keep_step, which does the wrong thing when keep_step is a non-array
304305
# bool True/False.
305-
num_rejected_steps = state.num_rejected_steps + jnp.invert(keep_step)
306+
num_rejected_steps = state.num_rejected_steps + jnp.where(keep_step, 0, 1)
306307

307308
#
308309
# Store the output produced from this numerical step.
@@ -369,7 +370,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
369370
subsaveat.fn(tprev, y, args),
370371
save_state.ys,
371372
)
372-
save_index = save_state.save_index + keep_step
373+
save_index = save_state.save_index + jnp.where(keep_step, 1, 0)
373374
save_state = eqx.tree_at(
374375
lambda s: [s.ts, s.ys, s.save_index],
375376
save_state,
@@ -388,7 +389,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
388389
dense_info,
389390
dense_infos,
390391
)
391-
dense_save_index = dense_save_index + keep_step
392+
dense_save_index = dense_save_index + jnp.where(keep_step, 1, 0)
392393

393394
new_state = State(
394395
y=y,
@@ -625,7 +626,7 @@ def diffeqsolve(
625626
f"t0 with value {t0} and type {type(t0)}, "
626627
f"dt0 with value {dt0} and type {type(dt0)}"
627628
)
628-
with jax.ensure_compile_time_eval():
629+
with jax.ensure_compile_time_eval(), jax.numpy_dtype_promotion("standard"):
629630
pred = (t1 - t0) * dt0 < 0
630631
dt0 = eqxi.error_if(jnp.array(dt0), pred, msg)
631632

@@ -641,7 +642,8 @@ def diffeqsolve(
641642
)
642643
warnings.warn(
643644
"Complex dtype support is work in progress, please read "
644-
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
645+
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.",
646+
stacklevel=2,
645647
)
646648

647649
# Backward compatibility
@@ -653,7 +655,8 @@ def diffeqsolve(
653655
f"{solver.__class__.__name__} is deprecated in favour of "
654656
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
655657
"the same terms can now be passed used for both general and SDE-specific "
656-
"solvers!"
658+
"solvers!",
659+
stacklevel=2,
657660
)
658661
terms = MultiTerm(*terms)
659662

@@ -668,7 +671,8 @@ def diffeqsolve(
668671
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
669672
warnings.warn(
670673
f"`{type(solver).__name__}` is not marked as converging to either the "
671-
"Itô or the Stratonovich solution."
674+
"Itô or the Stratonovich solution.",
675+
stacklevel=2,
672676
)
673677
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
674678
# Specific check to not work even if using HalfSolver(Euler())
@@ -684,11 +688,22 @@ def diffeqsolve(
684688
)
685689

686690
# Allow setting e.g. t0 as an int with dt0 as a float.
687-
timelikes = [jnp.array(0.0), t0, t1, dt0] + [
691+
timelikes = [t0, t1, dt0] + [
688692
s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)
689693
]
690694
timelikes = [x for x in timelikes if x is not None]
691-
time_dtype = jnp.result_type(*timelikes)
695+
with jax.numpy_dtype_promotion("standard"):
696+
time_dtype = jnp.result_type(*timelikes)
697+
if jnp.issubdtype(time_dtype, jnp.complexfloating):
698+
raise ValueError(
699+
"Cannot use complex dtype for `t0`, `t1`, `dt0`, or `SaveAt(ts=...)`."
700+
)
701+
elif jnp.issubdtype(time_dtype, jnp.floating):
702+
pass
703+
elif jnp.issubdtype(time_dtype, jnp.integer):
704+
time_dtype = lxi.default_floating_dtype()
705+
else:
706+
raise ValueError(f"Unrecognised time dtype {time_dtype}.")
692707
t0 = jnp.asarray(t0, dtype=time_dtype)
693708
t1 = jnp.asarray(t1, dtype=time_dtype)
694709
if dt0 is not None:
@@ -708,7 +723,8 @@ def _get_subsaveat_ts(saveat):
708723
# fixing issue with float64 and weak dtypes, see discussion at:
709724
# https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527
710725
def _promote(yi):
711-
_dtype = jnp.result_type(yi, time_dtype) # noqa: F821
726+
with jax.numpy_dtype_promotion("standard"):
727+
_dtype = jnp.result_type(yi, time_dtype) # noqa: F821
712728
return jnp.asarray(yi, dtype=_dtype)
713729

714730
y0 = jtu.tree_map(_promote, y0)

diffrax/_misc.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,30 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
162162
return lax.select(pred, a, b)
163163

164164

165-
def default_floating_dtype():
166-
if jax.config.jax_enable_x64: # pyright: ignore
167-
return jnp.float64
168-
else:
169-
return jnp.float32
165+
def upcast_or_raise(
166+
x: ArrayLike, array_for_dtype: ArrayLike, x_name: str, dtype_name: str
167+
):
168+
"""If `JAX_NUMPY_DTYPE_PROMOTION=strict`, then this will raise an error if
169+
`jnp.result_type(x, array_for_dtype)` is not the same as `array_for_dtype.dtype`.
170+
It will then cast `x` to `jnp.result_type(x, array_for_dtype)`.
171+
172+
Thus if `JAX_NUMPY_DTYPE_PROMOTION=standard`, then the usual anything-goes behaviour
173+
will apply. If `JAX_NUMPY_DTYPE_PROMOTION=strict` then we loosen from prohibiting
174+
all dtype casting, to still allowing upcasting.
175+
"""
176+
x_dtype = jnp.result_type(x)
177+
target_dtype = jnp.result_type(array_for_dtype)
178+
with jax.numpy_dtype_promotion("standard"):
179+
promote_dtype = jnp.result_type(x_dtype, target_dtype)
180+
config_value = jax.config.jax_numpy_dtype_promotion
181+
if config_value == "strict":
182+
if target_dtype != promote_dtype:
183+
raise ValueError(
184+
f"When `JAX_NUMPY_DTYPE_PROMOTION=strict`, then {x_name} must have "
185+
f"a dtype that can be promoted to the dtype of {dtype_name}. "
186+
f"However {x_name} had dtype {x_dtype} and {dtype_name} had dtype "
187+
f"{target_dtype}."
188+
)
189+
elif config_value != "standard":
190+
assert False, f"Unrecognised `JAX_NUMPY_DTYPE_PROMOTION={config_value}`"
191+
return jnp.astype(x, promote_dtype)

diffrax/_step_size_controller/adaptive.py

+45-12
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
VF,
2828
Y,
2929
)
30+
from .._misc import upcast_or_raise
3031
from .._solution import RESULTS
3132
from .._term import AbstractTerm, ODETerm
3233
from .base import AbstractStepSizeController
@@ -325,6 +326,14 @@ class PIDController(
325326
safety: RealScalarLike = 0.9
326327
error_order: Optional[RealScalarLike] = None
327328

329+
def __check_init__(self):
330+
if self.jump_ts is not None and not jnp.issubdtype(
331+
self.jump_ts.dtype, jnp.inexact
332+
):
333+
raise ValueError(
334+
f"jump_ts must be floating point, not {self.jump_ts.dtype}"
335+
)
336+
328337
def wrap(self, direction: IntScalarLike):
329338
step_ts = None if self.step_ts is None else self.step_ts * direction
330339
jump_ts = None if self.jump_ts is None else self.jump_ts * direction
@@ -632,18 +641,30 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik
632641
if self.step_ts is None:
633642
return t1
634643

644+
step_ts0 = upcast_or_raise(
645+
self.step_ts,
646+
t0,
647+
"`PIDController.step_ts`",
648+
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
649+
)
650+
step_ts1 = upcast_or_raise(
651+
self.step_ts,
652+
t1,
653+
"`PIDController.step_ts`",
654+
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
655+
)
635656
# TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping
636657
# track of where we were last, and using that as a hint for the next search.
637-
t0_index = jnp.searchsorted(self.step_ts, t0, side="right")
638-
t1_index = jnp.searchsorted(self.step_ts, t1, side="right")
658+
t0_index = jnp.searchsorted(step_ts0, t0, side="right")
659+
t1_index = jnp.searchsorted(step_ts1, t1, side="right")
639660
# This minimum may or may not actually be necessary. The left branch is taken
640661
# iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must
641662
# already satisfy the minimum.
642663
# However, that branch is actually executed unconditionally and then where'd,
643664
# so we clamp it just to be sure we're not hitting undefined behaviour.
644665
t1 = jnp.where(
645666
t0_index < t1_index,
646-
self.step_ts[jnp.minimum(t0_index, len(self.step_ts) - 1)],
667+
step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)],
647668
t1,
648669
)
649670
return t1
@@ -653,23 +674,35 @@ def _clip_jump_ts(
653674
) -> tuple[RealScalarLike, BoolScalarLike]:
654675
if self.jump_ts is None:
655676
return t1, False
656-
if self.jump_ts is not None and not jnp.issubdtype(
657-
self.jump_ts.dtype, jnp.inexact
658-
):
677+
assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact)
678+
if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact):
659679
raise ValueError(
660-
f"jump_ts must be floating point, not {self.jump_ts.dtype}"
680+
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
681+
f"Got {jnp.result_type(t0)}."
661682
)
662683
if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact):
663684
raise ValueError(
664-
"t0, t1, dt0 must be floating point when specifying jump_t. Got "
665-
f"{jnp.result_type(t1)}."
685+
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
686+
f"Got {jnp.result_type(t1)}."
666687
)
667-
t0_index = jnp.searchsorted(self.jump_ts, t0, side="right")
668-
t1_index = jnp.searchsorted(self.jump_ts, t1, side="right")
688+
jump_ts0 = upcast_or_raise(
689+
self.jump_ts,
690+
t0,
691+
"`PIDController.jump_ts`",
692+
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
693+
)
694+
jump_ts1 = upcast_or_raise(
695+
self.jump_ts,
696+
t1,
697+
"`PIDController.jump_ts`",
698+
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
699+
)
700+
t0_index = jnp.searchsorted(jump_ts0, t0, side="right")
701+
t1_index = jnp.searchsorted(jump_ts1, t1, side="right")
669702
next_made_jump = t0_index < t1_index
670703
t1 = jnp.where(
671704
next_made_jump,
672-
eqxi.prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
705+
eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
673706
t1,
674707
)
675708
return t1, next_made_jump

diffrax/_term.py

+25-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef
1313

1414
from ._custom_types import Args, Control, IntScalarLike, RealScalarLike, VF, Y
15+
from ._misc import upcast_or_raise
1516
from ._path import AbstractPath
1617

1718

@@ -159,7 +160,8 @@ class ODETerm(AbstractTerm):
159160
appearing on the right hand side of an ODE, in which the control is time.
160161
161162
`vector_field` should return some PyTree, with the same structure as the initial
162-
state `y0`, and with every leaf broadcastable to the equivalent leaf in `y0`.
163+
state `y0`, and with every leaf shape-broadcastable and dtype-upcastable to the
164+
equivalent leaf in `y0`.
163165
164166
!!! example
165167
@@ -179,13 +181,33 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF:
179181
"The vector field inside `ODETerm` must return a pytree with the "
180182
"same structure as `y0`."
181183
)
182-
return jtu.tree_map(lambda o, yi: jnp.broadcast_to(o, jnp.shape(yi)), out, y)
184+
185+
def _broadcast_and_upcast(oi, yi):
186+
oi = jnp.broadcast_to(oi, jnp.shape(yi))
187+
oi = upcast_or_raise(
188+
oi,
189+
yi,
190+
"the vector field passed to `ODETerm`",
191+
"the corresponding leaf of `y`",
192+
)
193+
return oi
194+
195+
return jtu.tree_map(_broadcast_and_upcast, out, y)
183196

184197
def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike:
185198
return t1 - t0
186199

187200
def prod(self, vf: VF, control: RealScalarLike) -> Y:
188-
return jtu.tree_map(lambda v: control * v, vf)
201+
def _mul(v):
202+
c = upcast_or_raise(
203+
control,
204+
v,
205+
"the output of `ODETerm.contr(...)`",
206+
"the output of `ODETerm.vf(...)`",
207+
)
208+
return c * v
209+
210+
return jtu.tree_map(_mul, vf)
189211

190212

191213
ODETerm.__init__.__doc__ = """**Arguments:**

test/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55

66
jax.config.update("jax_enable_x64", True) # pyright: ignore
7+
jax.config.update("jax_numpy_rank_promotion", "raise") # pyright: ignore
8+
jax.config.update("jax_numpy_dtype_promotion", "strict") # pyright: ignore
79

810

911
@pytest.fixture()

test/test_adaptive_stepsize_controller.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def run(ys, controller, state):
8383
_, tprev, tnext, _, state, _ = controller.adapt_step_size(
8484
0, 1, y0, y1_candidate, None, y_error, 5, state
8585
)
86-
return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))
86+
with jax.numpy_dtype_promotion("standard"):
87+
return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))
8788

8889
y0 = jnp.array(1.0)
8990
y1_candidate = jnp.array(2.0)

0 commit comments

Comments
 (0)