time.sleep affects the execution time of JAX #24941

Open
horse6 opened this issue Nov 18, 2024 · 1 comment
Labels
bug Something isn't working

horse6 commented Nov 18, 2024

Description

I have a function that takes significantly longer to execute when it runs on Ray than when it runs directly. I then took the function out of the Ray worker and ran it directly, and found that its execution time depends on the duration of a time.sleep call made before each step. The function is one step of network training; it has already been compiled into an executable that can be loaded directly, and the input data has also been saved. The reproducible code is as follows:
Note:
The executable was compiled on an NVIDIA A100 80GB, and only one GPU is used.
All files are here.

import os
# Must be set before importing JAX: allocate GPU memory on demand instead of
# preallocating a large pool.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import pickle
import time

import jax
jax.config.update("jax_platform_name", "gpu")
from jax.lib import xla_bridge as xb

backend = xb.get_backend("cuda")

# Load the serialized, already-compiled training step.
with open("loaded_exe", "rb") as f:
    compiled = backend.deserialize_executable(f.read())

# Load 10 copies of the saved inputs (the same file each time).
input_arr = []
for _ in range(10):
    with open("arrays.pkl", "rb") as f:
        input_arr.append(pickle.load(f))

total_time = 0.0
for i in range(1, 3000):
    # Uncommenting the next line changes the measured execution time of
    # execute_sharded_on_local_devices; different sleep durations
    # (e.g. 0.3, 0.8) have different effects.
    # time.sleep(0.8)
    _t = time.perf_counter()
    out_ = compiled.execute_sharded_on_local_devices(input_arr[i % 10])
    out_[0][0].block_until_ready()  # wait for the device to finish the step
    tmp = time.perf_counter() - _t
    total_time += tmp
    print(f"{i=}: avg time: {total_time / i}, cur time: {tmp}")
print(out_[0])
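One way to narrow down where the extra time goes is to time the dispatch call and the wait separately, rather than together as in the loop above. This is a minimal sketch, assuming a hypothetical jitted `step` function (a few chained matmuls) as a stand-in for the deserialized executable and its saved inputs, which are not reproduced here:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for the deserialized training step:
    # a few chained matmuls so the device has real work to do.
    for _ in range(4):
        x = jnp.tanh(x @ x)
    return x


def timed_step(x, sleep_s=0.0):
    """Run one step and return (dispatch_s, wait_s), measured separately."""
    time.sleep(sleep_s)        # idle period before the step, as in the report
    t0 = time.perf_counter()
    out = step(x)              # returns once the work has been enqueued
    t1 = time.perf_counter()
    out.block_until_ready()    # waits for the device to finish
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


if __name__ == "__main__":
    x = jnp.ones((1024, 1024), dtype=jnp.float32) / 1024
    step(x).block_until_ready()  # compile and warm up outside the timed loop
    for sleep_s in (0.0, 0.3, 0.8):
        timings = [timed_step(x, sleep_s) for _ in range(5)]
        dispatch = sum(t[0] for t in timings) / len(timings)
        wait = sum(t[1] for t in timings) / len(timings)
        print(f"sleep={sleep_s}: dispatch {dispatch * 1e3:.3f} ms, wait {wait * 1e3:.3f} ms")
```

If the sleep mainly inflates the wait component, that suggests the extra time is spent on the device side of the step rather than in host-side dispatch.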

System info (python version, jaxlib version, accelerator, etc.)

[image: capture]

horse6 added the bug (Something isn't working) label on Nov 18, 2024

superbobry (Collaborator) commented:

This looks WAI (working as intended) to me. You would want the time.sleep() call after

out_ = compiled.execute_sharded_on_local_devices(input_arr[i%10])

and before

out_[0][0].block_until_ready()

in order for the sleeping to be done in parallel with compute.
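The suggestion relies on JAX's asynchronous dispatch: the call that launches the computation returns before the device has finished, so host-side work placed between the launch and block_until_ready() overlaps with device compute. The sketch below simulates this with only the standard library, assuming a hypothetical single-worker `device` executor as a stand-in for the GPU stream and a hypothetical `fake_kernel` as the step; it illustrates only where the sleep should go and does not model anything specific to the reported slowdown.

```python
import time
from concurrent.futures import ThreadPoolExecutor

# Hypothetical stand-in for a GPU: a single worker thread that runs
# "kernels" in submission order, like a device stream.
device = ThreadPoolExecutor(max_workers=1)


def fake_kernel(duration_s):
    time.sleep(duration_s)  # pretend the device is busy computing
    return duration_s


def run(n_steps, compute_s, sleep_s, overlap):
    """Wall-clock seconds for n_steps, with the host sleep placed either
    before dispatch (overlap=False) or after dispatch but before the wait."""
    start = time.perf_counter()
    for _ in range(n_steps):
        if not overlap:
            time.sleep(sleep_s)  # device is idle during this sleep
        fut = device.submit(fake_kernel, compute_s)  # asynchronous "dispatch"
        if overlap:
            time.sleep(sleep_s)  # device computes during this sleep
        fut.result()  # analogous to block_until_ready()
    return time.perf_counter() - start


if __name__ == "__main__":
    serial = run(5, compute_s=0.1, sleep_s=0.1, overlap=False)
    overlapped = run(5, compute_s=0.1, sleep_s=0.1, overlap=True)
    print(f"sleep before dispatch: {serial:.2f} s")  # roughly 5 * (0.1 + 0.1)
    print(f"sleep between dispatch and wait: {overlapped:.2f} s")  # roughly 5 * 0.1
```

With the sleep before dispatch, each iteration costs about sleep plus compute; with the sleep between dispatch and the wait, it costs about the larger of the two.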
