diff --git a/onnxruntime/python/onnxruntime_validation.py b/onnxruntime/python/onnxruntime_validation.py index 4f29c7f424845..167f7976aecca 100644 --- a/onnxruntime/python/onnxruntime_validation.py +++ b/onnxruntime/python/onnxruntime_validation.py @@ -99,32 +99,33 @@ def validate_build_package_info(): version = "" cuda_version = "" - if has_ortmodule: - try: - # collect onnxruntime package name, version, and cuda version - from .build_and_package_info import __version__ as version - from .build_and_package_info import package_name + try: + # collect onnxruntime package name, version, and cuda version + from .build_and_package_info import __version__ as version + from .build_and_package_info import package_name + + try: # noqa: SIM105 + from .build_and_package_info import cuda_version + except Exception: + pass - try: # noqa: SIM105 - from .build_and_package_info import cuda_version + if cuda_version: + # collect cuda library build info. the library info may not be available + # when the build environment has none or multiple libraries installed + try: + from .build_and_package_info import cudart_version except Exception: - pass - - if cuda_version: - # collect cuda library build info. the library info may not be available - # when the build environment has none or multiple libraries installed - try: - from .build_and_package_info import cudart_version - except Exception: - warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.") - cudart_version = None - - def print_build_package_info(): - warnings.warn(f"onnxruntime training package info: package_name: {package_name}") - warnings.warn(f"onnxruntime training package info: __version__: {version}") - warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}") - warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}") + warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.") + cudart_version = None + + def print_build_package_info(): + warnings.warn(f"onnxruntime training package info: package_name: {package_name}") + warnings.warn(f"onnxruntime training package info: __version__: {version}") + warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}") + warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}") + # Cudart only available on Linux + if platform.system().lower() == "linux": # collection cuda library info from current environment. from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions @@ -133,13 +134,13 @@ def print_build_package_info(): print_build_package_info() warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info") warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}") - else: - # TODO: rcom - pass + else: + # TODO: rcom + pass - except Exception as e: - warnings.warn("WARNING: failed to collect onnxruntime version and build info") - print(e) + except Exception as e: + warnings.warn("WARNING: failed to collect onnxruntime version and build info") + print(e) if import_ortmodule_exception: raise import_ortmodule_exception diff --git a/setup.py b/setup.py index 277792ec12f15..2c9a8a5600401 100644 --- a/setup.py +++ b/setup.py @@ -758,18 +758,19 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm f.write(f"cuda_version = '{cuda_version}'\n") # cudart_versions are integers - cudart_versions = find_cudart_versions(build_env=True) - if cudart_versions and len(cudart_versions) == 1: - f.write(f"cudart_version = {cudart_versions[0]}\n") - else: - print( - "Error getting cudart version. ", - ( - "did not find any cudart library" - if not cudart_versions or len(cudart_versions) == 0 - else "found multiple cudart libraries" - ), - ) + if platform.system().lower() == "linux": + cudart_versions = find_cudart_versions(build_env=True) + if cudart_versions and len(cudart_versions) == 1: + f.write(f"cudart_version = {cudart_versions[0]}\n") + else: + print( + "Error getting cudart version. ", + ( + "did not find any cudart library" + if not cudart_versions or len(cudart_versions) == 0 + else "found multiple cudart libraries" + ), + ) elif rocm_version: f.write(f"rocm_version = '{rocm_version}'\n")