I have a function that takes significantly longer to execute when running on Ray than when running it directly. I extracted the function from the Ray worker and ran it on its own, and found that its execution time is affected by the duration of a preceding time.sleep call. The function performs one step of network training and has already been compiled into an executable file that can be loaded directly; the input data has also been saved. The code to reproduce the problem follows the notes below.
Note:

- The executable file was compiled in an NVIDIA A100 80GB environment, and only one GPU is used.
- All files are here.
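For context, here is a hedged sketch of how `loaded_exe` and `arrays.pkl` might have been produced before running the reproduction script below. The actual compilation code is not part of this issue; `train_step`, the input shape, and the `serialize_executable` / `runtime_executable` calls are assumptions that mirror the `deserialize_executable` call used in the script.

```python
# Hypothetical sketch: producing loaded_exe and arrays.pkl.
# train_step stands in for the real network-training step, whose source
# is not shown in this issue.
import pickle

import jax
import jax.numpy as jnp
from jax.lib import xla_bridge as xb

def train_step(x):
    return x * 2.0  # placeholder computation

x = jnp.ones((8, 8))
compiled = jax.jit(train_step).lower(x).compile()

backend = xb.get_backend("cuda")
with open("loaded_exe", "wb") as f:
    # serialize_executable is assumed to be the counterpart of the
    # deserialize_executable call in the reproduction script.
    f.write(backend.serialize_executable(compiled.runtime_executable()))
with open("arrays.pkl", "wb") as f:
    # The pickled format must match what execute_sharded_on_local_devices
    # expects (a per-argument, per-device list of buffers).
    pickle.dump([[jax.device_put(x)]], f)
```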
```python
import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import pickle
import time

import jax
jax.config.update("jax_platform_name", "gpu")
from jax.lib import xla_bridge as xb

backend = xb.get_backend("cuda")

# Load the pre-compiled executable.
with open("loaded_exe", "rb") as f:
    a = f.read()
compiled = backend.deserialize_executable(a)

# Load the saved inputs (the same pickle is loaded ten times).
input_arr = []
for i in range(10):
    with open("arrays.pkl", "rb") as f:
        input_array0 = pickle.load(f)
    input_arr.append(input_array0)

total_time = 0.0
for i in range(1, 3000):
    # Uncommenting the following line affects the execution time of
    # execute_sharded_on_local_devices; different time.sleep arguments
    # (e.g. 0.3, 0.8) have different effects.
    # time.sleep(0.8)
    _t = time.time()
    out_ = compiled.execute_sharded_on_local_devices(input_arr[i % 10])
    out_[0][0].block_until_ready()
    tmp = time.time() - _t
    total_time += tmp
    print(f"{i=}: avg time: {total_time/i}, cur time: {tmp}")
print(out_[0])
```
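For reference, the Ray side of the comparison presumably looked something like the sketch below; the original Ray worker code is not included in this issue, so the task body (`run_steps`) is a hypothetical reconstruction that runs the same load-and-execute loop inside a GPU-pinned Ray task.

```python
# Hypothetical sketch of timing the same step inside a Ray task.
import time

import ray

ray.init(num_gpus=1)

@ray.remote(num_gpus=1)
def run_steps(n: int) -> float:
    import pickle
    from jax.lib import xla_bridge as xb

    backend = xb.get_backend("cuda")
    with open("loaded_exe", "rb") as f:
        compiled = backend.deserialize_executable(f.read())
    with open("arrays.pkl", "rb") as f:
        inputs = pickle.load(f)

    t0 = time.time()
    for _ in range(n):
        out = compiled.execute_sharded_on_local_devices(inputs)
        out[0][0].block_until_ready()
    return (time.time() - t0) / n

print("avg step time on Ray:", ray.get(run_steps.remote(100)))
```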
System info (python version, jaxlib version, accelerator, etc.)