From 7ae277075009458c1e4a9d11268b77bdad675945 Mon Sep 17 00:00:00 2001 From: Javier Date: Wed, 28 Aug 2024 18:52:16 +0200 Subject: [PATCH] fix(framework:skip) Support overriding run config from a `TOML` (#4080) --- src/py/flwr/common/config.py | 25 +++++++++------- src/py/flwr/common/config_test.py | 48 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/common/config.py b/src/py/flwr/common/config.py index eec7cfb726b..42039fa959a 100644 --- a/src/py/flwr/common/config.py +++ b/src/py/flwr/common/config.py @@ -185,23 +185,26 @@ def parse_config_args( if config is None: return overrides + # Handle if .toml file is passed + if len(config) == 1 and config[0].endswith(".toml"): + with Path(config[0]).open("rb") as config_file: + overrides = flatten_dict(tomli.load(config_file)) + return overrides + # Regular expression to capture key-value pairs with possible quoted values pattern = re.compile(r"(\S+?)=(\'[^\']*\'|\"[^\"]*\"|\S+)") for config_line in config: if config_line: - matches = pattern.findall(config_line) + # .toml files aren't allowed alongside other configs + if config_line.endswith(".toml"): + raise ValueError( + "TOML files cannot be passed alongside key-value pairs." + ) - if ( - len(matches) == 1 - and "=" not in matches[0][0] - and matches[0][0].endswith(".toml") - ): - with Path(matches[0][0]).open("rb") as config_file: - overrides = flatten_dict(tomli.load(config_file)) - else: - toml_str = "\n".join(f"{k} = {v}" for k, v in matches) - overrides.update(tomli.loads(toml_str)) + matches = pattern.findall(config_line) + toml_str = "\n".join(f"{k} = {v}" for k, v in matches) + overrides.update(tomli.loads(toml_str)) return overrides diff --git a/src/py/flwr/common/config_test.py b/src/py/flwr/common/config_test.py index 712e07264d3..34bc691cc95 100644 --- a/src/py/flwr/common/config_test.py +++ b/src/py/flwr/common/config_test.py @@ -15,6 +15,7 @@ """Test util functions handling Flower config.""" import os +import tempfile import textwrap from pathlib import Path from unittest.mock import patch @@ -254,3 +255,50 @@ def test_parse_config_args_overrides() -> None: "key5": True, "key6": "value6", } + + +def test_parse_config_args_from_toml_file() -> None: + """Test if a toml passed to --run-config it is loaded and fused correctly.""" + # Will be saved as a temp .toml file + toml_config = """ + num-server-rounds = 10 + momentum = 0.1 + verbose = true + """ + # This is the UserConfig that would be extracted from pyproject.toml + initial_run_config: UserConfig = { + "num-server-rounds": 5, + "momentum": 0.2, + "dataset": "my-fancy-dataset", + "verbose": False, + } + expected_config = { + "num-server-rounds": 10, + "momentum": 0.1, + "dataset": "my-fancy-dataset", + "verbose": True, + } + + # Create a temporary directory using a context manager + with tempfile.TemporaryDirectory() as temp_dir: + # Create a temporary TOML file within that directory + toml_config_file = os.path.join(temp_dir, "extra_config.toml") + + # Write the data to the TOML file + with open(toml_config_file, "w", encoding="utf-8") as toml_file: + toml_file.write(textwrap.dedent(toml_config)) + + # Parse config (this mimics what `--run-config path/to/config.toml` does) + config_from_toml = parse_config_args([toml_config_file]) + # Fuse + config = fuse_dicts(initial_run_config, config_from_toml) + + # Assert + assert config == expected_config + + +def test_parse_config_args_passing_toml_and_key_value() -> None: + """Test that passing a toml and key-value configs aren't allowed.""" + config = ["my-other-config.toml", "lr=0.1", "epochs=99"] + with pytest.raises(ValueError): + parse_config_args(config)