Refactor how download_remote_files handles cache directories. (#200)
A few improvements for local workflows:

* Allow passing the cache directory via either an environment variable
or a CLI flag
* Delete preexisting files before writing out a symlink to fix
`FileExistsError: [WinError 183] Cannot create a file when that file
already exists` on Windows (see the sketch after this list)
* Pull code into a helper function so it can be shared between Azure
file downloading and generic (GCS/Huggingface/etc.) file downloading
* Update a few variable names and docstrings
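
A minimal sketch of the Windows fix, under hypothetical names (`replace_with_symlink`, `link_path`, and `target_path` are illustrative, not from this change):

import os
from pathlib import Path

def replace_with_symlink(link_path: Path, target_path: Path):
    # On Windows, os.symlink raises FileExistsError (WinError 183) when
    # link_path already exists, so remove any stale file or link first.
    # is_symlink() also catches broken links, for which exists() is False.
    if link_path.is_symlink() or link_path.exists():
        os.remove(link_path)
    os.symlink(target_path, link_path)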
ScottTodd authored Apr 23, 2024
1 parent cc020eb commit 2dba6a8
Showing 1 changed file with 93 additions and 44 deletions.
137 changes: 93 additions & 44 deletions iree_tests/download_remote_files.py
@@ -6,13 +6,14 @@

from azure.storage.blob import BlobClient, BlobProperties
from pathlib import Path
from typing import Optional
import argparse
import hashlib
import logging
import mmap
import os
import pyjson5
import re
import os

THIS_DIR = Path(__file__).parent
REPO_ROOT = Path(__file__).parent.parent
@@ -27,20 +28,46 @@ def human_readable_size(size, decimal_places=2):
    return f"{size:.{decimal_places}f} {unit}"


def get_remote_md5(remote_file: str, blob_properties: BlobProperties):
    content_settings = blob_properties.get("content_settings")
def setup_cache_symlink_if_needed(
    cache_dir: Optional[Path], local_dir: Path, file_name: str
):
    """Creates a symlink from local_dir/file_name to cache_dir/file_name."""
    if not cache_dir:
        return

    local_file_path = local_dir / file_name
    cache_file_path = cache_dir / file_name
    if local_file_path.is_symlink():
        if os.path.samefile(str(local_file_path), str(cache_file_path)):
            # Symlink matches, no need to recreate.
            return
        os.remove(local_file_path)
    elif local_file_path.exists():
        logger.warning(
            f" Local file '{local_file_path}' exists but cache_dir is set. Deleting and "
            "replacing with a symlink"
        )
        os.remove(local_file_path)
    os.symlink(cache_file_path, local_file_path)
    logger.info(f" Created symlink for '{local_file_path}' to '{cache_file_path}'")


def get_azure_md5(remote_file: str, azure_blob_properties: BlobProperties):
    """Gets the content_md5 hash for a blob on Azure, if available."""
    content_settings = azure_blob_properties.get("content_settings")
    if not content_settings:
        return None
    remote_md5 = content_settings.get("content_md5")
    if not remote_md5:
    azure_md5 = content_settings.get("content_md5")
    if not azure_md5:
        logger.warning(
            f" Remote file '{remote_file}' on Azure is missing the "
            "'content_md5' property, can't check if local matches remote"
        )
    return remote_md5
    return azure_md5


def get_local_md5(local_file_path: Path):
    """Gets the content_md5 hash for a local file, if it exists."""
    if not local_file_path.exists() or local_file_path.stat().st_size == 0:
        return None

@@ -50,7 +77,15 @@ def get_local_md5(local_file_path: Path):
        return hashlib.md5(file).digest()


def download_azure_remote_file(test_dir: Path, remote_file: str):
def download_azure_remote_file(
    remote_file: str, test_dir: Path, cache_dir: Optional[Path]
):
    """
    Downloads a file from Azure into test_dir.
    If cache_dir is set, downloads there instead, creating a symlink from
    test_dir/file_name to cache_dir/file_name.
    """
    remote_file_name = remote_file.rsplit("/", 1)[-1]
    relative_dir = test_dir.relative_to(THIS_DIR)

@@ -63,10 +98,6 @@ def download_azure_remote_file(test_dir: Path, remote_file: str):
    # account_url: https://sharkpublic.blob.core.windows.net
    # container_name: sharkpublic
    # blob_name: path/to/blob.txt
    #
    # Note: we could also use the generic handler (e.g. wget, 'requests'), but
    # the client library offers other APIs.

    result = re.search(r"(https.+\.net)/([^/]+)/(.+)", remote_file)
    account_url = result.groups()[0]
    container_name = result.groups()[1]
@@ -81,31 +112,20 @@ def download_azure_remote_file(test_dir: Path, remote_file: str):
    ) as blob_client:
        blob_properties = blob_client.get_blob_properties()
        blob_size_str = human_readable_size(blob_properties.size)
        remote_md5 = get_remote_md5(remote_file, blob_properties)

        cache_location = os.getenv("IREE_TEST_FILES", default="")
        if cache_location == "":
            os.environ["IREE_TEST_FILES"] = str(REPO_ROOT)
            cache_location = REPO_ROOT
        if cache_location == REPO_ROOT:
            local_dir_path = test_dir
            local_file_path = test_dir / remote_file_name
        else:
            cache_location = Path(os.path.expanduser(cache_location)).resolve()
            local_dir_path = cache_location / "iree_tests" / relative_dir
            local_file_path = cache_location / "iree_tests" / relative_dir / remote_file_name
        azure_md5 = get_azure_md5(remote_file, blob_properties)

        if cache_dir:
            local_file_path = cache_dir / remote_file_name
        else:
            local_file_path = test_dir / remote_file_name
        local_md5 = get_local_md5(local_file_path)

        if remote_md5 and remote_md5 == local_md5:
        if azure_md5 and azure_md5 == local_md5:
            logger.info(
                f" Skipping '{remote_file_name}' download ({blob_size_str}) "
                "- local MD5 hash matches"
            )
            os.symlink(local_file_path, test_dir / remote_file_name)
            logger.info(
                f" Created symlink for '{local_file_path}' to '{test_dir / remote_file_name}'"
            )
            setup_cache_symlink_if_needed(cache_dir, test_dir, remote_file_name)
            return

        if not local_md5:
@@ -119,24 +139,32 @@ def download_azure_remote_file(test_dir: Path, remote_file: str):
            f"to '{relative_dir}' (local MD5 does not match)"
        )

        if not os.path.isdir(local_dir_path):
            os.makedirs(local_dir_path)
        with open(local_file_path, mode="wb") as local_blob:
            download_stream = blob_client.download_blob(max_concurrency=4)
            local_blob.write(download_stream.readall())
        if str(cache_location) != str(REPO_ROOT):
            os.symlink(local_file_path, test_dir / remote_file_name)
            logger.info(
                f" Created symlink for '{local_file_path}' to '{test_dir / remote_file_name}'"
            )
        setup_cache_symlink_if_needed(cache_dir, test_dir, remote_file_name)
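
# Example (hypothetical values): for a remote file like
#   https://sharkpublic.blob.core.windows.net/sharkpublic/path/to/blob.txt
# this downloads blob.txt into cache_dir when one is set, then links
# test_dir/blob.txt to the cached copy via setup_cache_symlink_if_needed.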


def download_generic_remote_file(
    remote_file: str, test_dir: Path, cache_dir: Optional[Path]
):
    """
    Downloads a file from a generic URL into test_dir.
    If cache_dir is set, downloads there instead, creating a symlink from
    test_dir/file_name to cache_dir/file_name.
    """

def download_generic_remote_file(test_dir: Path, remote_file: str):
    # TODO(scotttodd): use https://pypi.org/project/requests/
    raise NotImplementedError("generic remote file downloads not implemented yet")


def download_for_test_case(test_dir: Path, test_case_json: dict):
def download_files_for_test_case(
    test_case_json: dict, test_dir: Path, cache_dir: Optional[Path]
):
    if "remote_files" not in test_case_json:
        return

    # This is naive (greedy, serial) for now. We could batch downloads that
    # share a source:
    # * Iterate over all files (across all included paths), building a list
@@ -147,14 +175,11 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):
    # * Group files based on source (e.g. Azure container)
    # * Start batched/parallel downloads

    if "remote_files" not in test_case_json:
        return

    for remote_file in test_case_json["remote_files"]:
        if "blob.core.windows.net" in remote_file:
            download_azure_remote_file(test_dir, remote_file)
            download_azure_remote_file(remote_file, test_dir, cache_dir)
        else:
            download_generic_remote_file(test_dir, remote_file)
            download_generic_remote_file(remote_file, test_dir, cache_dir)
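
# One possible shape for the batching TODO above (illustrative only, not
# part of this change):
#   by_source = collections.defaultdict(list)
#   for remote_file in all_remote_files:
#       by_source[urllib.parse.urlparse(remote_file).netloc].append(remote_file)
# then each group could share one client and download in parallel.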


if __name__ == "__main__":
@@ -164,6 +189,12 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):
        default="",
        help="Root directory to search for files to download from (e.g. 'pytorch/models/resnet50')",
    )
    parser.add_argument(
        "--cache-dir",
        default=os.getenv("IREE_TEST_FILES", default=""),
        help="Local cache directory to download into. If set, symlinks will be created pointing to "
        "this location",
    )
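    # Example invocations (hypothetical paths; --cache-dir falls back to the
    # IREE_TEST_FILES environment variable when the flag is omitted):
    #   python download_remote_files.py --cache-dir ~/.cache/iree_test_files
    #   IREE_TEST_FILES=~/.cache/iree_test_files python download_remote_files.py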
    args = parser.parse_args()

    # Adjust logging levels.
@@ -172,6 +203,10 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):
        if log_name.startswith("azure"):
            logging.getLogger(log_name).setLevel(logging.WARNING)

    # Resolve cache location.
    if args.cache_dir:
        args.cache_dir = Path(os.path.expanduser(args.cache_dir)).resolve()

    # TODO(scotttodd): build list of files _then_ download
    # TODO(scotttodd): report size needed for requested files and size available on disk

@@ -184,5 +219,19 @@ def download_for_test_case(test_dir: Path, test_case_json: dict):
        logger.info(f"Processing {test_cases_path.relative_to(THIS_DIR)}")

        test_dir = test_cases_path.parent
        relative_dir = test_dir.relative_to(THIS_DIR)

        # Expand directory structure in the cache matching the test tree.
        if args.cache_dir:
            cache_dir_for_test = args.cache_dir / "iree_tests" / relative_dir
            if not os.path.isdir(cache_dir_for_test):
                os.makedirs(cache_dir_for_test)
        else:
            cache_dir_for_test = None

        for test_case_json in test_cases_json["test_cases"]:
            download_for_test_case(test_dir, test_case_json)
            download_files_for_test_case(
                test_case_json=test_case_json,
                test_dir=test_dir,
                cache_dir=cache_dir_for_test,
            )
