Skip to content

Commit

Permalink
check linux before calling find_cudart_versions. Also remove if has_o…
Browse files Browse the repository at this point in the history
…rtmodule
  • Loading branch information
jchen351 committed Dec 22, 2024
1 parent ad0cf6b commit fad62cb
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 41 deletions.
59 changes: 30 additions & 29 deletions onnxruntime/python/onnxruntime_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,32 +99,33 @@ def validate_build_package_info():
version = ""
cuda_version = ""

if has_ortmodule:
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name

try: # noqa: SIM105
from .build_and_package_info import cuda_version
except Exception:

Check notice

Code scanning / CodeQL

Empty except Note

'except' clause does nothing but pass and there is no explanatory comment.
pass

try: # noqa: SIM105
from .build_and_package_info import cuda_version
if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except Exception:
pass

if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except Exception:
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")

# Cudart only available on Linux
if platform.system().lower() == "linux":
# collection cuda library info from current environment.
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

Expand All @@ -133,13 +134,13 @@ def print_build_package_info():
print_build_package_info()
warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
else:
# TODO: rcom
pass
else:
# TODO: rcom
pass

except Exception as e:
warnings.warn("WARNING: failed to collect onnxruntime version and build info")
print(e)
except Exception as e:
warnings.warn("WARNING: failed to collect onnxruntime version and build info")
print(e)

if import_ortmodule_exception:
raise import_ortmodule_exception
Expand Down
25 changes: 13 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,18 +758,19 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
f.write(f"cuda_version = '{cuda_version}'\n")

# cudart_versions are integers
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
else:
print(
"Error getting cudart version. ",
(
"did not find any cudart library"
if not cudart_versions or len(cudart_versions) == 0
else "found multiple cudart libraries"
),
)
if platform.system().lower() == "linux":
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
else:
print(
"Error getting cudart version. ",
(
"did not find any cudart library"
if not cudart_versions or len(cudart_versions) == 0
else "found multiple cudart libraries"
),
)
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")

Expand Down

0 comments on commit fad62cb

Please sign in to comment.