From 6b0c7cd6720da34b2d1064845b04a94dea89a6a7 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Thu, 12 Dec 2024 16:56:49 -0500 Subject: [PATCH 1/2] add flag to restore old (jax<0.4.36) tracing behavior --- tests/test_1447_jax_autodiff_slices_ufuncs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_1447_jax_autodiff_slices_ufuncs.py b/tests/test_1447_jax_autodiff_slices_ufuncs.py index 9d3b049e30..ae8d1584b7 100644 --- a/tests/test_1447_jax_autodiff_slices_ufuncs.py +++ b/tests/test_1447_jax_autodiff_slices_ufuncs.py @@ -10,6 +10,7 @@ jax = pytest.importorskip("jax") jax.config.update("jax_platform_name", "cpu") jax.config.update("jax_enable_x64", True) +jax.config.update("jax_data_dependent_tracing_fallback", True) ak.jax.register_and_check() From 0dbf0c33caa0c17c88247639661b41271de8aaa4 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Thu, 12 Dec 2024 17:04:27 -0500 Subject: [PATCH 2/2] enable flag only for jax version of 0.4.36 or higher --- tests/test_1447_jax_autodiff_slices_ufuncs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_1447_jax_autodiff_slices_ufuncs.py b/tests/test_1447_jax_autodiff_slices_ufuncs.py index ae8d1584b7..875757c669 100644 --- a/tests/test_1447_jax_autodiff_slices_ufuncs.py +++ b/tests/test_1447_jax_autodiff_slices_ufuncs.py @@ -4,13 +4,15 @@ import numpy as np import pytest +from packaging.version import parse as parse_version import awkward as ak jax = pytest.importorskip("jax") jax.config.update("jax_platform_name", "cpu") jax.config.update("jax_enable_x64", True) -jax.config.update("jax_data_dependent_tracing_fallback", True) +if parse_version(jax.__version__) >= parse_version("0.4.36"): + jax.config.update("jax_data_dependent_tracing_fallback", True) ak.jax.register_and_check()