Skip to content

Commit

Permalink
add caching support
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Apr 16, 2024
1 parent 07d8938 commit ec8e887
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/test_iree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
runs-on: nodai-amdgpu-w7900-x86-64
env:
VENV_DIR: ${{ github.workspace }}/.venv
IREE_TEST_FILES: ~/iree_tests_cache
steps:
- name: "Checking out repository"
uses: actions/checkout@v4
Expand Down
20 changes: 19 additions & 1 deletion iree_tests/download_remote_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import mmap
import pyjson5
import re
import os

THIS_DIR = Path(__file__).parent
REPO_ROOT = Path(__file__).parent.parent
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -81,14 +83,26 @@ def download_azure_remote_file(test_dir: Path, remote_file: str):
blob_size_str = human_readable_size(blob_properties.size)
remote_md5 = get_remote_md5(remote_file, blob_properties)

local_file_path = test_dir / remote_file_name
cache_location = os.getenv("IREE_TEST_FILES", default="")
if cache_location == "":
os.environ["IREE_TEST_FILES"] = str(REPO_ROOT)
cache_location = REPO_ROOT
if cache_location == REPO_ROOT:
local_dir_path = test_dir
local_file_path = test_dir / remote_file_name
else:
cache_location = os.path.expanduser(cache_location)
local_dir_path = Path(cache_location) / "iree_tests" / relative_dir
local_file_path = Path(cache_location) / "iree_tests" / relative_dir / remote_file_name

local_md5 = get_local_md5(local_file_path)

if remote_md5 and remote_md5 == local_md5:
logger.info(
f" Skipping '{remote_file_name}' download ({blob_size_str}) "
"- local MD5 hash matches"
)
os.symlink(local_file_path, test_dir / remote_file_name)
return

if not local_md5:
Expand All @@ -102,9 +116,13 @@ def download_azure_remote_file(test_dir: Path, remote_file: str):
f"to '{relative_dir}' (local MD5 does not match)"
)

if not os.path.isdir(local_dir_path):
os.makedirs(local_dir_path)
with open(local_file_path, mode="wb") as local_blob:
download_stream = blob_client.download_blob(max_concurrency=4)
local_blob.write(download_stream.readall())
if str(cache_location) != str(REPO_ROOT):
os.symlink(local_file_path, test_dir / remote_file_name)


def download_generic_remote_file(test_dir: Path, remote_file: str):
Expand Down

0 comments on commit ec8e887

Please sign in to comment.