Skip to content

Commit

Permalink
Merge branch 'add-fab-hash-install' into add-fab-hash-to-client-fn
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Oct 10, 2024
2 parents d8f1334 + 7ae9b60 commit a773a4d
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/py/flwr/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from flwr.common.constant import FAB_ALLOWED_EXTENSIONS, FAB_DATE, FAB_HASH_TRUNCATION

from .config_utils import load_and_validate
from .utils import get_sha256_hash, is_valid_project_name
from .utils import is_valid_project_name


def write_to_zip(
Expand Down Expand Up @@ -146,7 +146,7 @@ def build(
write_to_zip(fab_file, str(archive_path), file_contents)

# Calculate file info
sha256_hash = get_sha256_hash(file_path)
sha256_hash = hashlib.sha256(file_contents).hexdigest()
file_size_bits = os.path.getsize(file_path) * 8 # size in bits
list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n"

Expand Down
70 changes: 39 additions & 31 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@

import typer

from flwr.common.config import (
get_flwr_dir,
get_metadata_from_config,
get_metadata_from_fab_filename,
)
from flwr.common.config import get_flwr_dir, get_metadata_from_config
from flwr.common.constant import FAB_HASH_TRUNCATION

from .config_utils import load_and_validate
Expand Down Expand Up @@ -160,34 +156,10 @@ def validate_and_install(

version, fab_id = get_metadata_from_config(config)
publisher, project_name = fab_id.split("/")
config_metadata = (publisher, project_name, version, fab_hash)

if fab_name:
fab_publisher, fab_project_name, fab_version, fab_shorthash = (
get_metadata_from_fab_filename(fab_name)
)
if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):

typer.secho(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

try:
_ = int(fab_shorthash, 16) # Verify hash is a valid hexadecimal
except ValueError as e:
typer.secho(
f"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1) from e
_validate_fab_and_config_metadata(fab_name, config_metadata)

install_dir: Path = (
(get_flwr_dir() if not flwr_dir else flwr_dir)
Expand Down Expand Up @@ -248,3 +220,39 @@ def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
if not file_path.exists() or get_sha256_hash(file_path) != hash_expected:
return False
return True


def _validate_fab_and_config_metadata(
fab_name: str, config_metadata: tuple[str, str, str, str]
) -> None:
"""Validate metadata from the FAB filename and config."""
publisher, project_name, version, fab_hash = config_metadata

fab_name = fab_name.removesuffix(".fab")

fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_name.split(".")

# Check FAB filename format
if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):
raise ValueError(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
)

# Verify hash is a valid hexadecimal
try:
_ = int(fab_shorthash, 16)
except Exception as e:
raise ValueError(
"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`."
) from e

# Verify shorthash matches
if fab_shorthash != fab_hash[:FAB_HASH_TRUNCATION]:
raise ValueError(
"❌ The hash in the FAB file name does not match the hash of the FAB."
)
10 changes: 5 additions & 5 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Flower command line interface `run` command."""

import hashlib
import json
import subprocess
import sys
Expand Down Expand Up @@ -134,6 +133,7 @@ def run(
_run_without_superexec(app, federation_config, config_overrides, federation)


# pylint: disable=too-many-locals
def _run_with_superexec(
app: Path,
federation_config: dict[str, Any],
Expand Down Expand Up @@ -179,9 +179,9 @@ def _run_with_superexec(
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

fab_path = Path(build(app)[0])
content = fab_path.read_bytes()
fab = Fab(hashlib.sha256(content).hexdigest(), content)
fab_path, fab_hash = build(app)
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

req = StartRunRequest(
fab=fab_to_proto(fab),
Expand All @@ -193,7 +193,7 @@ def _run_with_superexec(
res = stub.StartRun(req)

# Delete FAB file once it has been sent to the SuperExec
fab_path.unlink()
Path(fab_path).unlink()
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)

if stream:
Expand Down
12 changes: 0 additions & 12 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,15 +227,3 @@ def get_metadata_from_config(config: dict[str, Any]) -> tuple[str, str]:
config["project"]["version"],
f"{config['tool']['flwr']['app']['publisher']}/{config['project']['name']}",
)


def get_metadata_from_fab_filename(
fab_file: Union[Path, str]
) -> tuple[str, str, str, str]:
"""Extract metadata from the FAB filename."""
if isinstance(fab_file, Path):
fab_file_name = fab_file.stem
elif isinstance(fab_file, str):
fab_file_name = fab_file.removesuffix(".fab")
publisher, project_name, version, shorthash = fab_file_name.split(".")
return publisher, project_name, version.replace("-", "."), shorthash

0 comments on commit a773a4d

Please sign in to comment.