Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] [XLA:Python] Move JAX configuration objects into C++.
A noticeable amount of time during JAX tracing is spent getting and setting the value of config.State objects, in particular the thread-local values within that state. If we move that logic into C++, we can speed up that code. There are two main ways we can get a speedup: * Python thread-local state is based around a dictionary and isn't terribly fast. * we can have the C++ jit dispatch path directly access the configuration items it needs to include in its cache key. We spend a considerable amount of time in effect eagerly computing cache keys via update_thread_local_jit_state, although most of that is pointless work. Instead, we can have `jit` simply pull the config items it needs on demand. PiperOrigin-RevId: 693114411
- Loading branch information