diff --git a/src/py/flwr/server/server_app.py b/src/py/flwr/server/server_app.py index f19a9d91986..e9cb4ddcaf0 100644 --- a/src/py/flwr/server/server_app.py +++ b/src/py/flwr/server/server_app.py @@ -17,8 +17,11 @@ from typing import Callable, Optional -from flwr.common import Context, RecordSet -from flwr.common.logger import warn_preview_feature +from flwr.common import Context +from flwr.common.logger import ( + warn_deprecated_feature_with_example, + warn_preview_feature, +) from flwr.server.strategy import Strategy from .client_manager import ClientManager @@ -26,7 +29,20 @@ from .driver import Driver from .server import Server from .server_config import ServerConfig -from .typing import ServerAppCallable +from .typing import ServerAppCallable, ServerFn + +SERVER_FN_USAGE_EXAMPLE = """ + + def server_fn(context: Context): + server_config = ServerConfig(num_rounds=3) + strategy = FedAvg() + return ServerAppComponents( + strategy=strategy, + server_config=server_config, + ) + + app = ServerApp(server_fn=server_fn) +""" class ServerApp: @@ -36,13 +52,15 @@ class ServerApp: -------- Use the `ServerApp` with an existing `Strategy`: - >>> server_config = ServerConfig(num_rounds=3) - >>> strategy = FedAvg() + >>> def server_fn(context: Context): + >>> server_config = ServerConfig(num_rounds=3) + >>> strategy = FedAvg() + >>> return ServerAppComponents( + >>> strategy=strategy, + >>> server_config=server_config, + >>> ) >>> - >>> app = ServerApp( - >>> server_config=server_config, - >>> strategy=strategy, - >>> ) + >>> app = ServerApp(server_fn=server_fn) Use the `ServerApp` with a custom main function: @@ -53,23 +71,52 @@ class ServerApp: >>> print("ServerApp running") """ + # pylint: disable=too-many-arguments def __init__( self, server: Optional[Server] = None, config: Optional[ServerConfig] = None, strategy: Optional[Strategy] = None, client_manager: Optional[ClientManager] = None, + server_fn: Optional[ServerFn] = None, ) -> None: + if any([server, config, strategy, client_manager]): + warn_deprecated_feature_with_example( + deprecation_message="Passing either `server`, `config`, `strategy` or " + "`client_manager` directly to the ServerApp " + "constructor is deprecated.", + example_message="Pass `ServerApp` arguments wrapped " + "in a `flwr.server.ServerAppComponents` object that gets " + "returned by a function passed as the `server_fn` argument " + "to the `ServerApp` constructor. For example: ", + code_example=SERVER_FN_USAGE_EXAMPLE, + ) + + if server_fn: + raise ValueError( + "Passing `server_fn` is incompatible with passing the " + "other arguments (now deprecated) to ServerApp. " + "Use `server_fn` exclusively." + ) + self._server = server self._config = config self._strategy = strategy self._client_manager = client_manager + self._server_fn = server_fn self._main: Optional[ServerAppCallable] = None def __call__(self, driver: Driver, context: Context) -> None: """Execute `ServerApp`.""" # Compatibility mode if not self._main: + if self._server_fn: + # Execute server_fn() + components = self._server_fn(context) + self._server = components.server + self._config = components.config + self._strategy = components.strategy + self._client_manager = components.client_manager start_driver( server=self._server, config=self._config, @@ -80,7 +127,6 @@ def __call__(self, driver: Driver, context: Context) -> None: return # New execution mode - context = Context(state=RecordSet(), run_config={}) self._main(driver, context) def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]: diff --git a/src/py/flwr/server/typing.py b/src/py/flwr/server/typing.py index 01143af7439..cdb1c0db4fe 100644 --- a/src/py/flwr/server/typing.py +++ b/src/py/flwr/server/typing.py @@ -20,6 +20,8 @@ from flwr.common import Context from .driver import Driver +from .serverapp_components import ServerAppComponents ServerAppCallable = Callable[[Driver, Context], None] Workflow = Callable[[Driver, Context], None] +ServerFn = Callable[[Context], ServerAppComponents]