diff --git a/warp/tests/test_async.py b/warp/tests/test_async.py index f5d541f51..ea9a2c34f 100644 --- a/warp/tests/test_async.py +++ b/warp/tests/test_async.py @@ -21,7 +21,9 @@ def __init__(self, use_graph=True, stream=None): def __enter__(self): if self.use_graph: - wp.capture_begin(stream=self.stream) + # preload module before graph capture + wp.load_module(device=wp.get_device()) + wp.capture_begin(stream=self.stream, force_module_load=False) def __exit__(self, exc_type, exc_value, traceback): if self.use_graph: diff --git a/warp/tests/test_reload.py b/warp/tests/test_reload.py index d5550f172..9f7312448 100644 --- a/warp/tests/test_reload.py +++ b/warp/tests/test_reload.py @@ -197,8 +197,11 @@ def foo(a: wp.array(dtype=int)): with wp.ScopedDevice(device): a = wp.zeros(1, dtype=int) + # preload module before graph capture + wp.load_module(device=device) + # capture a launch - with wp.ScopedCapture() as capture: + with wp.ScopedCapture(force_module_load=False) as capture: wp.launch(foo, dim=1, inputs=[a]) # unload the module