From a2a60844cbdb420c973db2f8d023daab00cc3fa3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 5 Aug 2021 17:31:11 -0400 Subject: [PATCH] Use MainTrace payload mechanism in experimental callback tracer. A change that adds jit() decorators on a number of standard library functions was triggering incorrect cache hits for these tests. This is because the payload fields of the MainTrace were not being included in __hash__() and __eq__(). --- jax/experimental/callback.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/jax/experimental/callback.py b/jax/experimental/callback.py index 125d3f5d79e0..e3ec8e7dc4eb 100644 --- a/jax/experimental/callback.py +++ b/jax/experimental/callback.py @@ -100,7 +100,7 @@ def callback_fun(fun : lu.WrappedFun, in_vals, callback, strip_calls): @lu.transformation def callback_subtrace(main, *in_vals, **params): - trace = CallbackTrace(main, core.cur_sublevel()) + trace = main.with_cur_sublevel() in_tracers = [CallbackTracer(trace, val) for val in in_vals] outs = yield in_tracers, params out_tracers = map(trace.full_raise, outs) @@ -109,9 +109,8 @@ def callback_subtrace(main, *in_vals, **params): @lu.transformation def _callback_fun(callback, strip_calls, *in_vals, **params): - with core.new_main(CallbackTrace) as main: - main.callback = callback # NOTE: Is this OK? - main.strip_calls = strip_calls + with core.new_main(CallbackTrace, callback=callback, + strip_calls=strip_calls) as main: out_vals = yield (main,) + in_vals, params del main yield out_vals @@ -147,6 +146,11 @@ def full_lower(self): return self class CallbackTrace(Trace): + def __init__(self, *args, callback, strip_calls): + super().__init__(*args) + self.callback = callback + self.strip_calls = strip_calls + def pure(self, val): return CallbackTracer(self, val) @@ -160,13 +164,13 @@ def process_primitive(self, primitive, tracers, params): if primitive in custom_callback_rules: return custom_callback_rules[primitive](self, *tracers, **params) vals_in = [t.val for t in tracers] - vals_out = self.main.callback(primitive, vals_in, params) # type: ignore + vals_out = self.callback(primitive, vals_in, params) # type: ignore if primitive.multiple_results: return [CallbackTracer(self, val) for val in vals_out] return CallbackTracer(self, vals_out) def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): - if self.main.strip_calls: # type: ignore + if self.strip_calls: # type: ignore return f.call_wrapped(*tracers) vals_in = [t.val for t in tracers] f = callback_subtrace(f, self.main) @@ -206,8 +210,8 @@ def new_body(*vals): out = body_fun(*vals) out_carry, y = split_list(out, [num_carry]) return out_carry, y - main = trace.main - new_body = callback_transform(new_body, main.callback, strip_calls=main.strip_calls) # type: ignore + new_body = callback_transform(new_body, trace.callback, + strip_calls=trace.strip_calls) # type: ignore in_tree = tree_structure(carry_avals + xs_avals) new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr( new_body, in_tree, tuple(carry_avals + x_avals)) @@ -237,9 +241,8 @@ def cond(*carry): def body(*carry): return body_fun(*it.chain(body_const_vals, carry)) - main = trace.main - new_cond = callback_transform(cond, main.callback, strip_calls=main.strip_calls) # type: ignore - new_body = callback_transform(body, main.callback, strip_calls=main.strip_calls) # type: ignore + new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls) # type: ignore + new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls) # type: ignore in_tree = tree_structure(init_avals) new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals)) @@ -259,7 +262,7 @@ def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers, main = trace.main vals = [t.val for t in tracers] - new_closed_jaxpr = callback_jaxpr(fun_jaxpr, main.callback, strip_calls=main.strip_calls) + new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls) if primitive == cd.custom_jvp_call_jaxpr_p: thunk_name = 'jvp_jaxpr_thunk' elif primitive == cd.custom_vjp_call_jaxpr_p: @@ -272,7 +275,7 @@ def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers, @pe._memoize def new_thunk(): thunk_jaxpr = core.ClosedJaxpr(*thunk()) - closed_jaxpr = callback_jaxpr(thunk_jaxpr, main.callback, main.strip_calls) + closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls) return closed_jaxpr.jaxpr, closed_jaxpr.literals params[thunk_name] = new_thunk