From 39a790d550a8e84e56c90ca3d3a7e45394f4fd32 Mon Sep 17 00:00:00 2001 From: Varun Shenoy Date: Wed, 13 Sep 2023 13:12:49 -0700 Subject: [PATCH] GCS Folder Fix (#656) * working fix * ran basic tests with my private repo --- truss/contexts/image_builder/cache_warmer.py | 17 ++++++++- .../image_builder/serving_image_builder.py | 3 ++ .../test_serving_image_builder.py | 38 +++++++++++++++++++ 3 files changed, 57 insertions(+), 1 deletion(-) diff --git a/truss/contexts/image_builder/cache_warmer.py b/truss/contexts/image_builder/cache_warmer.py index 4ff8739e2..019b872cb 100644 --- a/truss/contexts/image_builder/cache_warmer.py +++ b/truss/contexts/image_builder/cache_warmer.py @@ -32,14 +32,28 @@ def _download_from_url_using_b10cp( ) +def split_gs_path(gs_path): + # Remove the 'gs://' prefix + path = gs_path.replace("gs://", "") + + # Split on the first slash + parts = path.split("/", 1) + + bucket_name = parts[0] + prefix = parts[1] if len(parts) > 1 else "" + + return bucket_name, prefix + + def download_file( repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json" ): # Check if repo_name starts with "gs://" if "gs://" in repo_name: # Create directory if not exist + bucket_name, _ = split_gs_path(repo_name) repo_name = repo_name.replace("gs://", "") - cache_dir = Path(f"/app/hf_cache/{repo_name}") + cache_dir = Path(f"/app/hf_cache/{bucket_name}") cache_dir.mkdir(parents=True, exist_ok=True) # Connect to GCS storage @@ -47,6 +61,7 @@ def download_file( storage_client = storage.Client.from_service_account_json(key_file) bucket = storage_client.bucket(repo_name) blob = bucket.blob(file_name) + dst_file = Path(f"{cache_dir}/{file_name}") if not dst_file.parent.exists(): dst_file.parent.mkdir(parents=True) diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 7e74e3899..8fd823043 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -114,6 +114,9 @@ def list_bucket_files(bucket_name, data_dir, is_trusted=False): all_objects = [] for blob in blobs: + # leave out folders + if blob.name[-1] == "/": + continue all_objects.append(blob.name) return all_objects diff --git a/truss/tests/contexts/image_builder/test_serving_image_builder.py b/truss/tests/contexts/image_builder/test_serving_image_builder.py index 7d0015a5b..2c3522ff4 100644 --- a/truss/tests/contexts/image_builder/test_serving_image_builder.py +++ b/truss/tests/contexts/image_builder/test_serving_image_builder.py @@ -163,6 +163,44 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files): assert "fake_model-001-of-002.bin" in model_files[model]["files"] +@patch("truss.contexts.image_builder.serving_image_builder.list_bucket_files") +def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files): + mock_list_bucket_files.return_value = [ + "folder_a/folder_b/fake_model-001-of-002.bin", + "folder_a/folder_b/fake_model-002-of-002.bin", + ] + model = "gs://crazy-good-new-model-7b/folder_a/folder_b" + + config = TrussConfig( + python_version="py39", + hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]), + ) + + with TemporaryDirectory() as tmp_dir: + truss_path = Path(tmp_dir) + build_path = truss_path / "build" + build_path.mkdir(parents=True, exist_ok=True) + + model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path) + print(files_to_cache) + + assert ( + "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin" + in files_to_cache + ) + assert ( + "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin" + in files_to_cache + ) + + assert ( + "folder_a/folder_b/fake_model-001-of-002.bin" in model_files[model]["files"] + ) + assert ( + "folder_a/folder_b/fake_model-001-of-002.bin" in model_files[model]["files"] + ) + + @pytest.mark.integration def test_tgi_caching_truss(): with ensure_kill_all():