Skip to content

Commit

Permalink
GCS Folder Fix (#656)
Browse files Browse the repository at this point in the history
* working fix

* ran basic tests with my private repo
  • Loading branch information
Varun Shenoy authored Sep 13, 2023
1 parent 97b2415 commit 39a790d
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
17 changes: 16 additions & 1 deletion truss/contexts/image_builder/cache_warmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,36 @@ def _download_from_url_using_b10cp(
)


def split_gs_path(gs_path):
    """Split a Google Cloud Storage path into bucket name and object prefix.

    Args:
        gs_path: A path such as ``gs://bucket/folder_a/folder_b``. A path
            without the ``gs://`` scheme is split the same way.

    Returns:
        A ``(bucket_name, prefix)`` tuple. ``prefix`` is everything after the
        first ``/`` following the bucket name, or ``""`` when there is none.
    """
    scheme = "gs://"

    # Strip the scheme only when it leads the path. ``str.replace`` would also
    # delete any later "gs://" that happens to appear inside the object name.
    path = gs_path[len(scheme):] if gs_path.startswith(scheme) else gs_path

    # Split on the first slash only, so nested folders stay in the prefix.
    bucket_name, _, prefix = path.partition("/")

    return bucket_name, prefix


def download_file(
repo_name, file_name, revision_name=None, key_file="/app/data/service_account.json"
):
# Check if repo_name starts with "gs://"
if "gs://" in repo_name:
# Create directory if not exist
bucket_name, _ = split_gs_path(repo_name)
repo_name = repo_name.replace("gs://", "")
cache_dir = Path(f"/app/hf_cache/{repo_name}")
cache_dir = Path(f"/app/hf_cache/{bucket_name}")
cache_dir.mkdir(parents=True, exist_ok=True)

# Connect to GCS storage
try:
storage_client = storage.Client.from_service_account_json(key_file)
bucket = storage_client.bucket(repo_name)
blob = bucket.blob(file_name)

dst_file = Path(f"{cache_dir}/{file_name}")
if not dst_file.parent.exists():
dst_file.parent.mkdir(parents=True)
Expand Down
3 changes: 3 additions & 0 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def list_bucket_files(bucket_name, data_dir, is_trusted=False):

all_objects = []
for blob in blobs:
# leave out folders
if blob.name[-1] == "/":
continue
all_objects.append(blob.name)

return all_objects
Expand Down
38 changes: 38 additions & 0 deletions truss/tests/contexts/image_builder/test_serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,44 @@ def test_correct_gcs_files_accessed_for_caching(mock_list_bucket_files):
assert "fake_model-001-of-002.bin" in model_files[model]["files"]


@patch("truss.contexts.image_builder.serving_image_builder.list_bucket_files")
def test_correct_nested_gcs_files_accessed_for_caching(mock_list_bucket_files):
    """Files under a nested GCS folder are cached below the bucket directory.

    The bucket listing is mocked to return two shards that live in
    ``folder_a/folder_b``; both must show up in the cache targets and in the
    per-model file list.
    """
    mock_list_bucket_files.return_value = [
        "folder_a/folder_b/fake_model-001-of-002.bin",
        "folder_a/folder_b/fake_model-002-of-002.bin",
    ]
    model = "gs://crazy-good-new-model-7b/folder_a/folder_b"

    config = TrussConfig(
        python_version="py39",
        hf_cache=HuggingFaceCache(models=[HuggingFaceModel(repo_id=model)]),
    )

    with TemporaryDirectory() as tmp_dir:
        truss_path = Path(tmp_dir)
        build_path = truss_path / "build"
        build_path.mkdir(parents=True, exist_ok=True)

        model_files, files_to_cache = get_files_to_cache(config, truss_path, build_path)

        # Cache destinations keep the nested folder structure under the bucket.
        assert (
            "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-001-of-002.bin"
            in files_to_cache
        )
        assert (
            "/app/hf_cache/crazy-good-new-model-7b/folder_a/folder_b/fake_model-002-of-002.bin"
            in files_to_cache
        )

        # Both shards must be recorded for the model; the second check
        # previously repeated the first shard and never covered the second.
        assert (
            "folder_a/folder_b/fake_model-001-of-002.bin" in model_files[model]["files"]
        )
        assert (
            "folder_a/folder_b/fake_model-002-of-002.bin" in model_files[model]["files"]
        )


@pytest.mark.integration
def test_tgi_caching_truss():
with ensure_kill_all():
Expand Down

0 comments on commit 39a790d

Please sign in to comment.