diff --git a/equinox/_jit.py b/equinox/_jit.py index ae67ec00..c6586609 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -165,10 +165,8 @@ def __call__(self, /, *args, **kwargs): if len(e.args) != 1 or not isinstance(e.args[0], str): raise # No idea if this ever happens. But if it does, just bail. (msg,) = e.args - prefix = "INTERNAL: Generated function failed: CpuCallback error: EqxRuntimeError: " # noqa: E501 - is_eqx_error = msg.startswith(prefix) - if is_eqx_error: - msg = msg.removeprefix(prefix) + if "EqxRuntimeError: " in msg: + _, msg = msg.split("EqxRuntimeError: ", 1) msg, _ = msg.rsplit("\n\nAt:\n", 1) msg = msg + _eqx_on_error_msg e.args = (msg,)