Skip to content

Commit d6d09dc

Browse files
Tweaked warnings
1 parent 754dd79 commit d6d09dc

File tree

4 files changed

+5
-9
lines changed

4 files changed

+5
-9
lines changed

diffrax/_integrate.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,6 @@ def diffeqsolve(
927927
"`diffrax.diffeqsolve(..., discrete_terminating_event=...)` is deprecated "
928928
"in favour of the more general `diffrax.diffeqsolve(..., event=...)` "
929929
"interface. This will be removed in some future version of Diffrax.",
930-
category=DeprecationWarning,
931930
stacklevel=2,
932931
)
933932
if event is None:

diffrax/_step_size_controller/adaptive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def _scale(_y0, _y1_candidate, _y_error):
610610
# a grad API boundary as part of a larger model.)
611611
factor = lax.stop_gradient(factor)
612612
factor = eqxi.nondifferentiable(factor)
613-
dt = prev_dt * factor.astype(prev_dt)
613+
dt = prev_dt * factor.astype(jnp.result_type(prev_dt))
614614

615615
# E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0,
616616
# so factor=factormin, and we shrunk our step.

test/test_event.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def event_fn(state, **kwargs):
2424
return state.tprev > 10
2525

2626
event = diffrax.DiscreteTerminatingEvent(event_fn)
27-
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
27+
with pytest.warns(match="discrete_terminating_event"):
2828
sol = diffrax.diffeqsolve(
2929
term,
3030
solver,
@@ -51,7 +51,7 @@ def event_fn(state, **kwargs):
5151
return state.tprev > 10
5252

5353
event = diffrax.DiscreteTerminatingEvent(event_fn)
54-
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
54+
with pytest.warns(match="discrete_terminating_event"):
5555
sol = diffrax.diffeqsolve(
5656
term,
5757
solver,
@@ -82,7 +82,7 @@ def event_fn(state, **kwargs):
8282
@jax.jit
8383
@jax.grad
8484
def run(y0):
85-
with pytest.warns(DeprecationWarning, match="discrete_terminating_event"):
85+
with pytest.warns(match="discrete_terminating_event"):
8686
sol = diffrax.diffeqsolve(
8787
term,
8888
solver,

test/test_term.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,7 @@ def __call__(self, t, y, args):
154154

155155

156156
def test_weaklydiagonal_deprecate():
157-
with pytest.warns(
158-
DeprecationWarning,
159-
match="WeaklyDiagonalControlTerm is pending deprecation",
160-
):
157+
with pytest.warns(match="WeaklyDiagonalControlTerm"):
161158
_ = diffrax.WeaklyDiagonalControlTerm(
162159
lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0)
163160
)

0 commit comments

Comments
 (0)