Skip to content

Commit

Permalink
feat(framework) Introduce server_fn to setup ServerApp (#3773)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
  • Loading branch information
jafermarq and danieljanes authored Jul 11, 2024
1 parent ea01fd1 commit 67bbd4d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
66 changes: 56 additions & 10 deletions src/py/flwr/server/server_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,32 @@

from typing import Callable, Optional

from flwr.common import Context, RecordSet
from flwr.common.logger import warn_preview_feature
from flwr.common import Context
from flwr.common.logger import (
warn_deprecated_feature_with_example,
warn_preview_feature,
)
from flwr.server.strategy import Strategy

from .client_manager import ClientManager
from .compat import start_driver
from .driver import Driver
from .server import Server
from .server_config import ServerConfig
from .typing import ServerAppCallable
from .typing import ServerAppCallable, ServerFn

SERVER_FN_USAGE_EXAMPLE = """
def server_fn(context: Context):
server_config = ServerConfig(num_rounds=3)
strategy = FedAvg()
return ServerAppComponents(
strategy=strategy,
server_config=server_config,
)
app = ServerApp(server_fn=server_fn)
"""


class ServerApp:
Expand All @@ -36,13 +52,15 @@ class ServerApp:
--------
Use the `ServerApp` with an existing `Strategy`:
>>> server_config = ServerConfig(num_rounds=3)
>>> strategy = FedAvg()
>>> def server_fn(context: Context):
>>> server_config = ServerConfig(num_rounds=3)
>>> strategy = FedAvg()
>>> return ServerAppComponents(
>>> strategy=strategy,
>>> server_config=server_config,
>>> )
>>>
>>> app = ServerApp(
>>> server_config=server_config,
>>> strategy=strategy,
>>> )
>>> app = ServerApp(server_fn=server_fn)
Use the `ServerApp` with a custom main function:
Expand All @@ -53,23 +71,52 @@ class ServerApp:
>>> print("ServerApp running")
"""

# pylint: disable=too-many-arguments
def __init__(
self,
server: Optional[Server] = None,
config: Optional[ServerConfig] = None,
strategy: Optional[Strategy] = None,
client_manager: Optional[ClientManager] = None,
server_fn: Optional[ServerFn] = None,
) -> None:
if any([server, config, strategy, client_manager]):
warn_deprecated_feature_with_example(
deprecation_message="Passing either `server`, `config`, `strategy` or "
"`client_manager` directly to the ServerApp "
"constructor is deprecated.",
example_message="Pass `ServerApp` arguments wrapped "
"in a `flwr.server.ServerAppComponents` object that gets "
"returned by a function passed as the `server_fn` argument "
"to the `ServerApp` constructor. For example: ",
code_example=SERVER_FN_USAGE_EXAMPLE,
)

if server_fn:
raise ValueError(
"Passing `server_fn` is incompatible with passing the "
"other arguments (now deprecated) to ServerApp. "
"Use `server_fn` exclusively."
)

self._server = server
self._config = config
self._strategy = strategy
self._client_manager = client_manager
self._server_fn = server_fn
self._main: Optional[ServerAppCallable] = None

def __call__(self, driver: Driver, context: Context) -> None:
"""Execute `ServerApp`."""
# Compatibility mode
if not self._main:
if self._server_fn:
# Execute server_fn()
components = self._server_fn(context)
self._server = components.server
self._config = components.config
self._strategy = components.strategy
self._client_manager = components.client_manager
start_driver(
server=self._server,
config=self._config,
Expand All @@ -80,7 +127,6 @@ def __call__(self, driver: Driver, context: Context) -> None:
return

# New execution mode
context = Context(state=RecordSet(), run_config={})
self._main(driver, context)

def main(self) -> Callable[[ServerAppCallable], ServerAppCallable]:
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from flwr.common import Context

from .driver import Driver
from .serverapp_components import ServerAppComponents

ServerAppCallable = Callable[[Driver, Context], None]
Workflow = Callable[[Driver, Context], None]
ServerFn = Callable[[Context], ServerAppComponents]

0 comments on commit 67bbd4d

Please sign in to comment.