Skip to content

Commit 98a1354

Browse files
authored
Merge branch 'patrick-kidger:main' into main
2 parents fe1ca9a + 6192f62 commit 98a1354

File tree

7 files changed

+52
-17
lines changed

7 files changed

+52
-17
lines changed

diffrax/integrate.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .saveat import SaveAt, SubSaveAt
2020
from .solution import is_okay, is_successful, RESULTS, Solution
2121
from .solver import (
22+
AbstractImplicitSolver,
2223
AbstractItoSolver,
2324
AbstractSolver,
2425
AbstractStratonovichSolver,
@@ -605,6 +606,18 @@ def diffeqsolve(
605606
pred = (t1 - t0) * dt0 < 0
606607
dt0 = eqxi.error_if(jnp.array(dt0), pred, msg)
607608

609+
# Error checking and warning for complex dtypes
610+
if any(jtu.tree_leaves(jtu.tree_map(jnp.iscomplexobj, y0))):
611+
if isinstance(solver, AbstractImplicitSolver):
612+
raise ValueError(
613+
"Implicit solvers in conjunction with complex dtypes is currently not "
614+
"supported."
615+
)
616+
warnings.warn(
617+
"Complex dtype support is work in progress, please read "
618+
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
619+
)
620+
608621
# Backward compatibility
609622
if isinstance(
610623
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
@@ -664,8 +677,10 @@ def _get_subsaveat_ts(saveat):
664677
)
665678

666679
# Time will affect state, so need to promote the state dtype as well if necessary.
680+
# fixing issue with float64 and weak dtypes, see discussion at:
681+
# https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527
667682
def _promote(yi):
668-
_dtype = jnp.result_type(yi, *timelikes) # noqa: F821
683+
_dtype = jnp.result_type(yi, dtype) # noqa: F821
669684
return jnp.asarray(yi, dtype=_dtype)
670685

671686
y0 = jtu.tree_map(_promote, y0)
@@ -759,7 +774,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
759774
save_index = 0
760775
ts = jnp.full(out_size, jnp.inf)
761776
struct = eqx.filter_eval_shape(subsaveat.fn, t0, y0, args)
762-
ys = jtu.tree_map(lambda y: jnp.full((out_size,) + y.shape, jnp.inf), struct)
777+
ys = jtu.tree_map(
778+
lambda y: jnp.full((out_size,) + y.shape, jnp.inf, dtype=y.dtype), struct
779+
)
763780
return SaveState(
764781
ts=ts, ys=ys, save_index=save_index, saveat_ts_index=saveat_ts_index
765782
)
@@ -779,7 +796,9 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
779796
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
780797
)
781798
dense_ts = jnp.full(max_steps + 1, jnp.inf)
782-
_make_full = lambda x: jnp.full((max_steps,) + jnp.shape(x), jnp.inf)
799+
_make_full = lambda x: jnp.full(
800+
(max_steps,) + jnp.shape(x), jnp.inf, dtype=x.dtype
801+
)
783802
dense_infos = jtu.tree_map(_make_full, dense_info)
784803
dense_save_index = 0
785804
else:

diffrax/misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def _rms_norm_jvp(x, tx):
107107
pred = (out == 0) | jnp.isinf(out)
108108
numerator = jnp.where(pred, 0, x)
109109
denominator = jnp.where(pred, 1, out * x.size)
110-
t_out = jnp.dot(numerator / denominator, tx)
111-
return out, t_out
110+
t_out = jnp.dot(numerator / denominator, jnp.conj(tx))
111+
return out, jnp.real(t_out)
112112

113113

114114
def adjoint_rms_seminorm(x: Tuple[PyTree, PyTree, PyTree, PyTree]) -> Scalar:

diffrax/nonlinear_solver/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,6 @@ def jac(fn: Callable, x: PyTree, args: PyTree) -> LU_Jacobian:
106106
if not jnp.issubdtype(flat, jnp.inexact):
107107
# Handle integer arguments
108108
flat = flat.astype(jnp.float32)
109-
return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat))
109+
return jsp.linalg.lu_factor(
110+
jax.jacfwd(curried, holomorphic=jnp.iscomplexobj(flat))(flat)
111+
)

diffrax/step_size_controller/adaptive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def adapt_step_size(
424424
# ε_n = atol + norm(y) * rtol with y on the nth step
425425
# r_n = norm(y_error) with y_error on the nth step
426426
# δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
427-
# step and y on the mth step
427+
# step and y on the mth step
428428
# β_1 = pcoeff + icoeff + dcoeff
429429
# β_2 = -(pcoeff + 2 * dcoeff)
430430
# β_3 = dcoeff

examples/hessian.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@
7878
"id": "a3ec6532-5b0a-4e4c-af33-bef58c0a7319",
7979
"metadata": {},
8080
"source": [
81-
"Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse mode autodifferentiation. This argument is needed to switch to a different implementation that is compatible with higher-order autodiff. (In this case: for the loop-over-stages in the Runge--Kutta solver.)\n",
81+
"Note the use of the `scan_kind` argument to `Tsit5`. By default, Diffrax internally uses constructs that are optimised specifically for first-order reverse-mode autodifferentiation. This argument is needed to switch to a different implementation that is compatible with higher-order autodiff. (In this case: for the loop-over-stages in the Runge--Kutta solver.)\n",
8282
"\n",
83-
"In similar fashion, if using `saveat=SaveAt(steps=True)` then you will need to pass `adjoint=DirectAdjoint()`. (In this case: for the loop-over-saving output.)"
83+
"In similar fashion, if using `saveat=SaveAt(ts=...)` (or a handful of other esoteric cases) then you will need to pass `adjoint=DirectAdjoint()`. (In this case: for the loop-over-saving output.)"
8484
]
8585
}
8686
],

test/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ def implicit_tol(solver):
4444
return solver
4545

4646

47-
def random_pytree(key, treedef):
47+
def random_pytree(key, treedef, dtype=None):
4848
keys = jrandom.split(key, treedef.num_leaves)
4949
leaves = []
5050
for key in keys:
5151
dimkey, sizekey, valuekey = jrandom.split(key, 3)
5252
num_dims = jrandom.randint(dimkey, (), 0, 5)
5353
dim_sizes = jrandom.randint(sizekey, (num_dims,), 0, 5)
54-
value = jrandom.normal(valuekey, dim_sizes)
54+
value = jrandom.normal(valuekey, dim_sizes, dtype=dtype)
5555
leaves.append(value)
5656
return jtu.tree_unflatten(treedef, leaves)
5757

test/test_integrate.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _all_pairs(*args):
4141

4242

4343
@pytest.mark.parametrize(
44-
"solver,t_dtype,treedef,stepsize_controller",
44+
"solver,t_dtype,y_dtype,treedef,stepsize_controller",
4545
_all_pairs(
4646
dict(
4747
default=diffrax.Euler(),
@@ -58,21 +58,32 @@ def _all_pairs(*args):
5858
),
5959
),
6060
dict(default=jnp.float32, opts=(int, float, jnp.int32)),
61+
dict(default=jnp.float32, opts=(jnp.complex64,)),
6162
dict(default=treedefs[0], opts=treedefs[1:]),
6263
dict(
6364
default=diffrax.ConstantStepSize(),
6465
opts=(diffrax.PIDController(rtol=1e-3, atol=1e-6),),
6566
),
6667
),
6768
)
68-
def test_basic(solver, t_dtype, treedef, stepsize_controller, getkey):
69+
def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
6970
if not isinstance(solver, diffrax.AbstractAdaptiveSolver) and isinstance(
7071
stepsize_controller, diffrax.PIDController
7172
):
7273
return
7374

74-
def f(t, y, args):
75-
return jtu.tree_map(operator.neg, y)
75+
if jnp.iscomplexobj(y_dtype):
76+
77+
def f(t, y, args):
78+
return jtu.tree_map(lambda _y: operator.mul(-1j, _y), y)
79+
80+
if isinstance(solver, diffrax.AbstractImplicitSolver):
81+
return
82+
83+
else:
84+
85+
def f(t, y, args):
86+
return jtu.tree_map(operator.neg, y)
7687

7788
if t_dtype is int:
7889
t0 = 0
@@ -92,7 +103,7 @@ def f(t, y, args):
92103
dt0 = jnp.array(0.01)
93104
else:
94105
raise ValueError
95-
y0 = random_pytree(getkey(), treedef)
106+
y0 = random_pytree(getkey(), treedef, dtype=y_dtype)
96107
try:
97108
sol = diffrax.diffeqsolve(
98109
diffrax.ODETerm(f),
@@ -113,7 +124,10 @@ def f(t, y, args):
113124
else:
114125
raise
115126
y1 = sol.ys
116-
true_y1 = jtu.tree_map(lambda x: (x * math.exp(-1))[None], y0)
127+
if jnp.iscomplexobj(y_dtype):
128+
true_y1 = jtu.tree_map(lambda x: (x * jnp.exp(-1j))[None], y0)
129+
else:
130+
true_y1 = jtu.tree_map(lambda x: (x * math.exp(-1))[None], y0)
117131
assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2)
118132

119133

0 commit comments

Comments
 (0)