From 60575156719c1364131cacc5074f32d98d51f903 Mon Sep 17 00:00:00 2001 From: harshithapv <54084812+harshithapv@users.noreply.github.com> Date: Wed, 30 Jun 2021 20:04:16 -0700 Subject: [PATCH] Propagate ROCM version to onnxruntime wheel package (#8247) (#8250) Co-authored-by: Thiago Crepaldi --- setup.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 49412488331b0..fb70bdacbd2ce 100644 --- a/setup.py +++ b/setup.py @@ -424,7 +424,7 @@ def get_torch_version(): install_requires = f.read().splitlines() if enable_training: - def save_build_and_package_info(package_name, version_number, cuda_version): + def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version): sys.path.append(path.join(path.dirname(__file__), 'onnxruntime', 'python')) from onnxruntime_collect_build_info import find_cudart_versions @@ -446,11 +446,10 @@ def save_build_and_package_info(package_name, version_number, cuda_version): "did not find any cudart library" if not cudart_versions or len(cudart_versions) == 0 else "found multiple cudart libraries") - else: - # TODO: rocm - pass + elif rocm_version: + f.write("rocm_version = '{}'\n".format(rocm_version)) - save_build_and_package_info(package_name, version_number, cuda_version) + save_build_and_package_info(package_name, version_number, cuda_version, rocm_version) # Setup setup(