diff --git a/gsplat/cuda/_backend.py b/gsplat/cuda/_backend.py index 8fcc441cb..0e4f19e6e 100644 --- a/gsplat/cuda/_backend.py +++ b/gsplat/cuda/_backend.py @@ -33,23 +33,7 @@ def cuda_toolkit_version(): return cuda_version -name = "gsplat_cuda" -build_dir = _get_build_directory(name, verbose=False) -extra_include_paths = [os.path.join(PATH, "csrc/third_party/glm")] -extra_cflags = ["-O3"] -extra_cuda_cflags = ["-O3"] - _C = None -sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list( - glob.glob(os.path.join(PATH, "csrc/*.cpp")) -) -# sources = [ -# os.path.join(PATH, "csrc/ext.cpp"), -# os.path.join(PATH, "csrc/rasterize.cu"), -# os.path.join(PATH, "csrc/bindings.cu"), -# os.path.join(PATH, "csrc/forward.cu"), -# os.path.join(PATH, "csrc/backward.cu"), -# ] try: # try to import the compiled module (via setup.py) @@ -57,6 +41,15 @@ def cuda_toolkit_version(): except ImportError: # if failed, try with JIT compilation if cuda_toolkit_available(): + name = "gsplat_cuda" + build_dir = _get_build_directory(name, verbose=False) + sources = list(glob.glob(os.path.join(PATH, "csrc/*.cu"))) + list( + glob.glob(os.path.join(PATH, "csrc/*.cpp")) + ) + extra_include_paths = [os.path.join(PATH, "csrc/third_party/glm")] + extra_cflags = ["-O3"] + extra_cuda_cflags = ["-O3"] + # If JIT is interrupted it might leave a lock in the build directory. # We dont want it to exist in any case. try: