Skip to content

Commit

Permalink
feat(framework) Add run configs (#3725)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbvll authored Jul 11, 2024
1 parent ac98491 commit ea01fd1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
10 changes: 10 additions & 0 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def load(path: Optional[Path] = None) -> Optional[Dict[str, Any]]:
return load_from_string(toml_file.read())


def _validate_run_config(config_dict: Dict[str, Any], errors: List[str]) -> None:
for key, value in config_dict.items():
if isinstance(value, dict):
_validate_run_config(config_dict[key], errors)
elif not isinstance(value, str):
errors.append(f"Config value of key {key} is not of type `str`.")


# pylint: disable=too-many-branches
def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]:
"""Validate pyproject.toml fields."""
Expand All @@ -133,6 +141,8 @@ def validate_fields(config: Dict[str, Any]) -> Tuple[bool, List[str], List[str]]
else:
if "publisher" not in config["flower"]:
errors.append('Property "publisher" missing in [flower]')
if "config" in config["flower"]:
_validate_run_config(config["flower"]["config"], errors)
if "components" not in config["flower"]:
errors.append("Missing [flower.components] section")
else:
Expand Down
33 changes: 25 additions & 8 deletions src/py/flwr/cli/run/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from enum import Enum
from logging import DEBUG
from pathlib import Path
from typing import Optional
from typing import Dict, Optional

import typer
from typing_extensions import Annotated

from flwr.cli import config_utils
from flwr.cli.build import build
from flwr.common.config import parse_config_args
from flwr.common.constant import SUPEREXEC_DEFAULT_ADDRESS
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log
Expand Down Expand Up @@ -58,15 +59,20 @@ def run(
Optional[Path],
typer.Option(help="Path of the Flower project to run"),
] = None,
config_overrides: Annotated[
Optional[str],
typer.Option(
"--config",
"-c",
help="Override configuration key-value pairs",
),
] = None,
) -> None:
"""Run Flower project."""
if use_superexec:
_start_superexec_run(directory)
return

typer.secho("Loading project configuration... ", fg=typer.colors.BLUE)

config, errors, warnings = config_utils.load_and_validate()
pyproject_path = directory / "pyproject.toml" if directory else None
config, errors, warnings = config_utils.load_and_validate(path=pyproject_path)

if config is None:
typer.secho(
Expand All @@ -88,6 +94,12 @@ def run(

typer.secho("Success", fg=typer.colors.GREEN)

if use_superexec:
_start_superexec_run(
parse_config_args(config_overrides, separator=","), directory
)
return

server_app_ref = config["flower"]["components"]["serverapp"]
client_app_ref = config["flower"]["components"]["clientapp"]

Expand Down Expand Up @@ -115,7 +127,9 @@ def run(
)


def _start_superexec_run(directory: Optional[Path]) -> None:
def _start_superexec_run(
override_config: Dict[str, str], directory: Optional[Path]
) -> None:
def on_channel_state_change(channel_connectivity: str) -> None:
"""Log channel connectivity."""
log(DEBUG, channel_connectivity)
Expand All @@ -132,6 +146,9 @@ def on_channel_state_change(channel_connectivity: str) -> None:

fab_path = build(directory)

req = StartRunRequest(fab_file=Path(fab_path).read_bytes())
req = StartRunRequest(
fab_file=Path(fab_path).read_bytes(),
override_config=override_config,
)
res = stub.StartRun(req)
typer.secho(f"🎊 Successfully started run {res.run_id}", fg=typer.colors.GREEN)

0 comments on commit ea01fd1

Please sign in to comment.