Skip to content

Commit

Permalink
feat(framework) Add FAB hash to flwr build/install (#4304)
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng authored Oct 10, 2024
1 parent af5befe commit 2b9dbcc
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 51 deletions.
89 changes: 60 additions & 29 deletions src/py/flwr/cli/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,50 @@
# ==============================================================================
"""Flower command line interface `build` command."""

import hashlib
import os
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Annotated, Optional
from typing import Annotated, Any, Optional, Union

import pathspec
import tomli_w
import typer

from flwr.common.constant import FAB_ALLOWED_EXTENSIONS, FAB_DATE, FAB_HASH_TRUNCATION

from .config_utils import load_and_validate
from .utils import get_sha256_hash, is_valid_project_name
from .utils import is_valid_project_name


def write_to_zip(
zipfile_obj: zipfile.ZipFile, filename: str, contents: Union[bytes, str]
) -> zipfile.ZipFile:
"""Set a fixed date and write contents to a zip file."""
zip_info = zipfile.ZipInfo(filename)
zip_info.date_time = FAB_DATE
zipfile_obj.writestr(zip_info, contents)
return zipfile_obj


def get_fab_filename(conf: dict[str, Any], fab_hash: str) -> str:
"""Get the FAB filename based on the given config and FAB hash."""
publisher = conf["tool"]["flwr"]["app"]["publisher"]
name = conf["project"]["name"]
version = conf["project"]["version"].replace(".", "-")
fab_hash_truncated = fab_hash[:FAB_HASH_TRUNCATION]
return f"{publisher}.{name}.{version}.{fab_hash_truncated}.fab"

# pylint: disable=too-many-locals

# pylint: disable=too-many-locals, too-many-statements
def build(
app: Annotated[
Optional[Path],
typer.Option(help="Path of the Flower App to bundle into a FAB"),
] = None,
) -> str:
) -> tuple[str, str]:
"""Build a Flower App into a Flower App Bundle (FAB).
You can run ``flwr build`` without any arguments to bundle the app located in the
Expand Down Expand Up @@ -85,16 +109,8 @@ def build(
# Load .gitignore rules if present
ignore_spec = _load_gitignore(app)

# Set the name of the zip file
fab_filename = (
f"{conf['tool']['flwr']['app']['publisher']}"
f".{conf['project']['name']}"
f".{conf['project']['version'].replace('.', '-')}.fab"
)
list_file_content = ""

allowed_extensions = {".py", ".toml", ".md"}

# Remove the 'federations' field from 'tool.flwr' if it exists
if (
"tool" in conf
Expand All @@ -105,38 +121,53 @@ def build(

toml_contents = tomli_w.dumps(conf)

with zipfile.ZipFile(fab_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
fab_file.writestr("pyproject.toml", toml_contents)
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as temp_file:
temp_filename = temp_file.name

with zipfile.ZipFile(temp_filename, "w", zipfile.ZIP_DEFLATED) as fab_file:
write_to_zip(fab_file, "pyproject.toml", toml_contents)

# Continue with adding other files
for root, _, files in os.walk(app, topdown=True):
files = [
# Continue with adding other files
all_files = [
f
for f in files
if not ignore_spec.match_file(Path(root) / f)
and f != fab_filename
and Path(f).suffix in allowed_extensions
and f != "pyproject.toml" # Exclude the original pyproject.toml
for f in app.rglob("*")
if not ignore_spec.match_file(f)
and f.name != temp_filename
and f.suffix in FAB_ALLOWED_EXTENSIONS
and f.name != "pyproject.toml" # Exclude the original pyproject.toml
]

for file in files:
file_path = Path(root) / file
for file_path in all_files:
# Read the file content manually
with open(file_path, "rb") as f:
file_contents = f.read()

archive_path = file_path.relative_to(app)
fab_file.write(file_path, archive_path)
write_to_zip(fab_file, str(archive_path), file_contents)

# Calculate file info
sha256_hash = get_sha256_hash(file_path)
sha256_hash = hashlib.sha256(file_contents).hexdigest()
file_size_bits = os.path.getsize(file_path) * 8 # size in bits
list_file_content += f"{archive_path},{sha256_hash},{file_size_bits}\n"

# Add CONTENT and CONTENT.jwt to the zip file
fab_file.writestr(".info/CONTENT", list_file_content)
# Add CONTENT and CONTENT.jwt to the zip file
write_to_zip(fab_file, ".info/CONTENT", list_file_content)

# Get hash of FAB file
content = Path(temp_filename).read_bytes()
fab_hash = hashlib.sha256(content).hexdigest()

# Set the name of the zip file
fab_filename = get_fab_filename(conf, fab_hash)

# Once the temporary zip file is created, rename it to the final filename
shutil.move(temp_filename, fab_filename)

typer.secho(
f"🎊 Successfully built {fab_filename}", fg=typer.colors.GREEN, bold=True
)

return fab_filename
return fab_filename, fab_hash


def _load_gitignore(app: Path) -> pathspec.PathSpec:
Expand Down
76 changes: 59 additions & 17 deletions src/py/flwr/cli/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Flower command line interface `install` command."""


import hashlib
import shutil
import subprocess
import tempfile
Expand All @@ -25,7 +25,8 @@

import typer

from flwr.common.config import get_flwr_dir
from flwr.common.config import get_flwr_dir, get_metadata_from_config
from flwr.common.constant import FAB_HASH_TRUNCATION

from .config_utils import load_and_validate
from .utils import get_sha256_hash
Expand Down Expand Up @@ -91,9 +92,11 @@ def install_from_fab(
fab_name: Optional[str]
if isinstance(fab_file, bytes):
fab_file_archive = BytesIO(fab_file)
fab_hash = hashlib.sha256(fab_file).hexdigest()
fab_name = None
elif isinstance(fab_file, Path):
fab_file_archive = fab_file
fab_hash = hashlib.sha256(fab_file.read_bytes()).hexdigest()
fab_name = fab_file.stem
else:
raise ValueError("fab_file must be either a Path or bytes")
Expand Down Expand Up @@ -126,14 +129,16 @@ def install_from_fab(
shutil.rmtree(info_dir)

installed_path = validate_and_install(
tmpdir_path, fab_name, flwr_dir, skip_prompt
tmpdir_path, fab_hash, fab_name, flwr_dir, skip_prompt
)

return installed_path


# pylint: disable=too-many-locals
def validate_and_install(
project_dir: Path,
fab_hash: str,
fab_name: Optional[str],
flwr_dir: Optional[Path],
skip_prompt: bool = False,
Expand All @@ -149,21 +154,12 @@ def validate_and_install(
)
raise typer.Exit(code=1)

publisher = config["tool"]["flwr"]["app"]["publisher"]
project_name = config["project"]["name"]
version = config["project"]["version"]
version, fab_id = get_metadata_from_config(config)
publisher, project_name = fab_id.split("/")
config_metadata = (publisher, project_name, version, fab_hash)

if (
fab_name
and fab_name != f"{publisher}.{project_name}.{version.replace('.', '-')}"
):
typer.secho(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.fab`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
if fab_name:
_validate_fab_and_config_metadata(fab_name, config_metadata)

install_dir: Path = (
(get_flwr_dir() if not flwr_dir else flwr_dir)
Expand Down Expand Up @@ -226,3 +222,49 @@ def _verify_hashes(list_content: str, tmpdir: Path) -> bool:
if not file_path.exists() or get_sha256_hash(file_path) != hash_expected:
return False
return True


def _validate_fab_and_config_metadata(
fab_name: str, config_metadata: tuple[str, str, str, str]
) -> None:
"""Validate metadata from the FAB filename and config."""
publisher, project_name, version, fab_hash = config_metadata

fab_name = fab_name.removesuffix(".fab")

fab_publisher, fab_project_name, fab_version, fab_shorthash = fab_name.split(".")
fab_version = fab_version.replace("-", ".")

# Check FAB filename format
if (
f"{fab_publisher}.{fab_project_name}.{fab_version}"
!= f"{publisher}.{project_name}.{version}"
or len(fab_shorthash) != FAB_HASH_TRUNCATION # Verify hash length
):
typer.secho(
"❌ FAB file has incorrect name. The file name must follow the format "
"`<publisher>.<project_name>.<version>.<8hexchars>.fab`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)

# Verify hash is a valid hexadecimal
try:
_ = int(fab_shorthash, 16)
except Exception as e:
typer.secho(
f"❌ FAB file has an invalid hexadecimal string `{fab_shorthash}`.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1) from e

# Verify shorthash matches
if fab_shorthash != fab_hash[:FAB_HASH_TRUNCATION]:
typer.secho(
"❌ The hash in the FAB file name does not match the hash of the FAB.",
fg=typer.colors.RED,
bold=True,
)
raise typer.Exit(code=1)
10 changes: 5 additions & 5 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
"""Flower command line interface `run` command."""

import hashlib
import json
import subprocess
import sys
Expand Down Expand Up @@ -134,6 +133,7 @@ def run(
_run_without_superexec(app, federation_config, config_overrides, federation)


# pylint: disable=too-many-locals
def _run_with_superexec(
app: Path,
federation_config: dict[str, Any],
Expand Down Expand Up @@ -179,9 +179,9 @@ def _run_with_superexec(
channel.subscribe(on_channel_state_change)
stub = ExecStub(channel)

fab_path = Path(build(app))
content = fab_path.read_bytes()
fab = Fab(hashlib.sha256(content).hexdigest(), content)
fab_path, fab_hash = build(app)
content = Path(fab_path).read_bytes()
fab = Fab(fab_hash, content)

req = StartRunRequest(
fab=fab_to_proto(fab),
Expand All @@ -193,7 +193,7 @@ def _run_with_superexec(
res = stub.StartRun(req)

# Delete FAB file once it has been sent to the SuperExec
fab_path.unlink()
Path(fab_path).unlink()
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)

if stream:
Expand Down

0 comments on commit 2b9dbcc

Please sign in to comment.