Skip to content

Commit

Permalink
[JAX] [XLA:Python] Move JAX configuration objects into C++.
Browse files Browse the repository at this point in the history
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code.

There are two main ways we can get a speedup:
* Python thread-local state is based around a dictionary and isn't terribly fast.
* we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand.

PiperOrigin-RevId: 693114411
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Nov 4, 2024
1 parent 38b4d00 commit ab47d46
Show file tree
Hide file tree
Showing 7 changed files with 406 additions and 297 deletions.
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def _update_debug_special_global(_):
jax_jit.global_state().post_hook = None

def _update_debug_special_thread_local(_):
if (getattr(config._thread_local_state, "jax_debug_nans", False) or
getattr(config._thread_local_state, "jax_debug_infs", False)):
if (config.debug_nans.get_local() == True or
config.debug_infs.get_local() == True):
jax_jit.thread_local_state().post_hook = _nan_check_posthook
else:
jax_jit.thread_local_state().post_hook = None
Expand Down
7 changes: 3 additions & 4 deletions jax/_src/compute_on.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(self):
@contextmanager
def extend_compute_type(c_type: str):
compute_on_context.stack.append(c_type)
config.update_thread_local_jit_state(
compute_on_context_manager=tuple(compute_on_context.stack))
config.compute_on_context_manager.set_local(
tuple(compute_on_context.stack))
try:
if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1:
raise NotImplementedError(
Expand All @@ -39,8 +39,7 @@ def extend_compute_type(c_type: str):
yield compute_on_context.stack[-1]
finally:
compute_on_context.stack.pop()
config.update_thread_local_jit_state(
compute_on_context_manager=tuple(compute_on_context.stack))
config.compute_on_context_manager.set_local(tuple(compute_on_context.stack))

def current_compute_type() -> str | None:
return compute_on_context.stack[-1] if compute_on_context.stack else None
Expand Down
Loading

0 comments on commit ab47d46

Please sign in to comment.