Skip to content

Commit

Permalink
fix(framework:skip) Support overriding run config from a TOML (#4080)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Aug 28, 2024
1 parent 2dd161a commit 7ae2770
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/py/flwr/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,23 +185,26 @@ def parse_config_args(
if config is None:
return overrides

# Handle if .toml file is passed
if len(config) == 1 and config[0].endswith(".toml"):
with Path(config[0]).open("rb") as config_file:
overrides = flatten_dict(tomli.load(config_file))
return overrides

# Regular expression to capture key-value pairs with possible quoted values
pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)")

for config_line in config:
if config_line:
matches = pattern.findall(config_line)
# .toml files aren't allowed alongside other configs
if config_line.endswith(".toml"):
raise ValueError(
"TOML files cannot be passed alongside key-value pairs."
)

if (
len(matches) == 1
and "=" not in matches[0][0]
and matches[0][0].endswith(".toml")
):
with Path(matches[0][0]).open("rb") as config_file:
overrides = flatten_dict(tomli.load(config_file))
else:
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
overrides.update(tomli.loads(toml_str))
matches = pattern.findall(config_line)
toml_str = "\n".join(f"{k} = {v}" for k, v in matches)
overrides.update(tomli.loads(toml_str))

return overrides

Expand Down
48 changes: 48 additions & 0 deletions src/py/flwr/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Test util functions handling Flower config."""

import os
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch
Expand Down Expand Up @@ -254,3 +255,50 @@ def test_parse_config_args_overrides() -> None:
"key5": True,
"key6": "value6",
}


def test_parse_config_args_from_toml_file() -> None:
"""Test if a toml passed to --run-config it is loaded and fused correctly."""
# Will be saved as a temp .toml file
toml_config = """
num-server-rounds = 10
momentum = 0.1
verbose = true
"""
# This is the UserConfig that would be extracted from pyproject.toml
initial_run_config: UserConfig = {
"num-server-rounds": 5,
"momentum": 0.2,
"dataset": "my-fancy-dataset",
"verbose": False,
}
expected_config = {
"num-server-rounds": 10,
"momentum": 0.1,
"dataset": "my-fancy-dataset",
"verbose": True,
}

# Create a temporary directory using a context manager
with tempfile.TemporaryDirectory() as temp_dir:
# Create a temporary TOML file within that directory
toml_config_file = os.path.join(temp_dir, "extra_config.toml")

# Write the data to the TOML file
with open(toml_config_file, "w", encoding="utf-8") as toml_file:
toml_file.write(textwrap.dedent(toml_config))

# Parse config (this mimics what `--run-config path/to/config.toml` does)
config_from_toml = parse_config_args([toml_config_file])
# Fuse
config = fuse_dicts(initial_run_config, config_from_toml)

# Assert
assert config == expected_config


def test_parse_config_args_passing_toml_and_key_value() -> None:
"""Test that passing a toml and key-value configs aren't allowed."""
config = ["my-other-config.toml", "lr=0.1", "epochs=99"]
with pytest.raises(ValueError):
parse_config_args(config)

0 comments on commit 7ae2770

Please sign in to comment.