Skip to content

Commit

Permalink
Merge pull request #7525 from hawkinsp:callback
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 389032835
  • Loading branch information
jax authors committed Aug 5, 2021
2 parents d2bd017 + a2a6084 commit 73dfd23
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions jax/experimental/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def callback_fun(fun : lu.WrappedFun, in_vals, callback, strip_calls):

@lu.transformation
def callback_subtrace(main, *in_vals, **params):
trace = CallbackTrace(main, core.cur_sublevel())
trace = main.with_cur_sublevel()
in_tracers = [CallbackTracer(trace, val) for val in in_vals]
outs = yield in_tracers, params
out_tracers = map(trace.full_raise, outs)
Expand All @@ -109,9 +109,8 @@ def callback_subtrace(main, *in_vals, **params):

@lu.transformation
def _callback_fun(callback, strip_calls, *in_vals, **params):
with core.new_main(CallbackTrace) as main:
main.callback = callback # NOTE: Is this OK?
main.strip_calls = strip_calls
with core.new_main(CallbackTrace, callback=callback,
strip_calls=strip_calls) as main:
out_vals = yield (main,) + in_vals, params
del main
yield out_vals
Expand Down Expand Up @@ -147,6 +146,11 @@ def full_lower(self):
return self

class CallbackTrace(Trace):
def __init__(self, *args, callback, strip_calls):
super().__init__(*args)
self.callback = callback
self.strip_calls = strip_calls

def pure(self, val):
return CallbackTracer(self, val)

Expand All @@ -160,13 +164,13 @@ def process_primitive(self, primitive, tracers, params):
if primitive in custom_callback_rules:
return custom_callback_rules[primitive](self, *tracers, **params)
vals_in = [t.val for t in tracers]
vals_out = self.main.callback(primitive, vals_in, params) # type: ignore
vals_out = self.callback(primitive, vals_in, params) # type: ignore
if primitive.multiple_results:
return [CallbackTracer(self, val) for val in vals_out]
return CallbackTracer(self, vals_out)

def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
if self.main.strip_calls: # type: ignore
if self.strip_calls: # type: ignore
return f.call_wrapped(*tracers)
vals_in = [t.val for t in tracers]
f = callback_subtrace(f, self.main)
Expand Down Expand Up @@ -206,8 +210,8 @@ def new_body(*vals):
out = body_fun(*vals)
out_carry, y = split_list(out, [num_carry])
return out_carry, y
main = trace.main
new_body = callback_transform(new_body, main.callback, strip_calls=main.strip_calls) # type: ignore
new_body = callback_transform(new_body, trace.callback,
strip_calls=trace.strip_calls) # type: ignore
in_tree = tree_structure(carry_avals + xs_avals)
new_jaxpr, new_consts, _ = lcf._initial_style_jaxpr(
new_body, in_tree, tuple(carry_avals + x_avals))
Expand Down Expand Up @@ -237,9 +241,8 @@ def cond(*carry):
def body(*carry):
return body_fun(*it.chain(body_const_vals, carry))

main = trace.main
new_cond = callback_transform(cond, main.callback, strip_calls=main.strip_calls) # type: ignore
new_body = callback_transform(body, main.callback, strip_calls=main.strip_calls) # type: ignore
new_cond = callback_transform(cond, trace.callback, strip_calls=trace.strip_calls) # type: ignore
new_body = callback_transform(body, trace.callback, strip_calls=trace.strip_calls) # type: ignore
in_tree = tree_structure(init_avals)

new_cond_jaxpr, new_cond_consts, _ = lcf._initial_style_jaxpr(new_cond, in_tree, tuple(init_avals))
Expand All @@ -259,7 +262,7 @@ def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
main = trace.main
vals = [t.val for t in tracers]

new_closed_jaxpr = callback_jaxpr(fun_jaxpr, main.callback, strip_calls=main.strip_calls)
new_closed_jaxpr = callback_jaxpr(fun_jaxpr, trace.callback, strip_calls=trace.strip_calls)
if primitive == cd.custom_jvp_call_jaxpr_p:
thunk_name = 'jvp_jaxpr_thunk'
elif primitive == cd.custom_vjp_call_jaxpr_p:
Expand All @@ -272,7 +275,7 @@ def _custom_derivative_call_jaxpr_callback_rule(primitive, trace, *tracers,
@pe._memoize
def new_thunk():
thunk_jaxpr = core.ClosedJaxpr(*thunk())
closed_jaxpr = callback_jaxpr(thunk_jaxpr, main.callback, main.strip_calls)
closed_jaxpr = callback_jaxpr(thunk_jaxpr, trace.callback, trace.strip_calls)
return closed_jaxpr.jaxpr, closed_jaxpr.literals

params[thunk_name] = new_thunk
Expand Down

0 comments on commit 73dfd23

Please sign in to comment.