diff --git a/tests/test_1447_jax_autodiff_slices_ufuncs.py b/tests/test_1447_jax_autodiff_slices_ufuncs.py index 9d3b049e30..875757c669 100644 --- a/tests/test_1447_jax_autodiff_slices_ufuncs.py +++ b/tests/test_1447_jax_autodiff_slices_ufuncs.py @@ -4,12 +4,15 @@ import numpy as np import pytest +from packaging.version import parse as parse_version import awkward as ak jax = pytest.importorskip("jax") jax.config.update("jax_platform_name", "cpu") jax.config.update("jax_enable_x64", True) +if parse_version(jax.__version__) >= parse_version("0.4.36"): + jax.config.update("jax_data_dependent_tracing_fallback", True) ak.jax.register_and_check()