diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 7ac339aa43c..012f584561a 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -311,3 +311,39 @@ jobs: if grep -q "ERROR" flwr_output.log; then exit 1 fi + + build_and_install: + runs-on: ubuntu-22.04 + timeout-minutes: 10 + needs: wheel + strategy: + matrix: + framework: ["numpy"] + python-version: ["3.9", "3.10", "3.11"] + + name: | + Build & Install / + Python ${{ matrix.python-version }} / + ${{ matrix.framework }} + + steps: + - uses: actions/checkout@v4 + - name: Bootstrap + uses: ./.github/actions/bootstrap + with: + python-version: ${{ matrix.python-version }} + poetry-skip: 'true' + - name: Install Flower from repo + if: ${{ github.repository != 'adap/flower' || github.event.pull_request.head.repo.fork || github.actor == 'dependabot[bot]' }} + run: | + python -m pip install . + - name: Install Flower wheel from artifact store + if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }} + run: | + python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }} + - name: Create project, build, and install it + run: | + flwr new tmp-${{ matrix.framework }} --framework ${{ matrix.framework }} --username gh_ci + cd tmp-${{ matrix.framework }} + flwr build + flwr install *.fab diff --git a/dev/test.sh b/dev/test.sh index 8f8d9dedf6d..b8eeed14bc4 100755 --- a/dev/test.sh +++ b/dev/test.sh @@ -26,6 +26,10 @@ echo "- docformatter: start" python -m docformatter -c -r src/py/flwr e2e -e src/py/flwr/proto echo "- docformatter: done" +echo "- docsig: start" +docsig src/py/flwr +echo "- docsig: done" + echo "- ruff: start" python -m ruff check src/py/flwr echo "- ruff: done" diff --git a/pyproject.toml b/pyproject.toml index 11da893bec7..aa356835d98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,7 @@ pre-commit = "==3.5.0" sphinx-substitution-extensions = "2022.02.16" sphinxext-opengraph = "==0.9.1" docstrfmt = { git = "https://github.com/charlesbvll/docstrfmt.git", branch = "patch-1" } +docsig = "==0.64.0" [tool.docstrfmt] extend_exclude = [ @@ -224,3 +225,7 @@ convention = "numpy" [tool.ruff.per-file-ignores] "src/py/flwr/server/strategy/*.py" = ["E501"] + +[tool.docsig] +ignore-no-params = true +exclude = 'src/py/flwr/proto/.*|src/py/flwr/.*_test\.py|src/py/flwr/cli/new/templates/.*\.tpl' diff --git a/src/py/flwr/cli/config_utils.py b/src/py/flwr/cli/config_utils.py index 79e4973ccf9..73ce779c3b5 100644 --- a/src/py/flwr/cli/config_utils.py +++ b/src/py/flwr/cli/config_utils.py @@ -90,6 +90,16 @@ def load_and_validate( ) -> tuple[Optional[dict[str, Any]], list[str], list[str]]: """Load and validate pyproject.toml as dict. + Parameters + ---------- + path : Optional[Path] (default: None) + The path of the Flower App config file to load. By default it + will try to use `pyproject.toml` inside the current directory. + check_module: bool (default: True) + Whether the validity of the Python module should be checked. + This requires the project to be installed in the currently + running environment. True by default. + Returns ------- Tuple[Optional[config], List[str], List[str]] diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 90c50aba7fa..54ed33aa66b 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -132,6 +132,11 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None) + Tuple containing the elliptic curve private key and public key for + authentication from the cryptography library. + Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/ + Used to establish an authenticated connection with the server. max_retries: Optional[int] (default: None) The maximum number of times the client will try to connect to the server before giving up in case of a connection error. If set to None, @@ -249,6 +254,11 @@ class `flwr.client.Client` (default: None) - 'grpc-bidi': gRPC, bidirectional streaming - 'grpc-rere': gRPC, request-response (experimental) - 'rest': HTTP (experimental) + authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None) + Tuple containing the elliptic curve private key and public key for + authentication from the cryptography library. + Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/ + Used to establish an authenticated connection with the server. max_retries: Optional[int] (default: None) The maximum number of times the client will try to connect to the server before giving up in case of a connection error. If set to None, diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index b4fa2837360..06701376fac 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -120,6 +120,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 authentication from the cryptography library. Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/ Used to establish an authenticated connection with the server. + adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] (default: None) + A GrpcStub Class that can be used to send messages. By default the FleetStub + will be used. Returns ------- diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index ce4fdb3dd82..4e792e8e02a 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -290,6 +290,11 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message: follows the equation: ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) + + Returns + ------- + message : Message + A Message containing only the relevant error and metadata. """ # If no TTL passed, use default for message creation (will update after # message creation) diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py index 35024fcd67d..4641b8f29c9 100644 --- a/src/py/flwr/common/recordset_compat.py +++ b/src/py/flwr/common/recordset_compat.py @@ -59,6 +59,11 @@ def parametersrecord_to_parameters( keep_input : bool A boolean indicating whether entries in the record should be deleted from the input dictionary immediately after adding them to the record. + + Returns + ------- + parameters : Parameters + The parameters in the legacy format Parameters. """ parameters = Parameters(tensors=[], tensor_type="") @@ -94,6 +99,11 @@ def parameters_to_parametersrecord( A boolean indicating whether parameters should be deleted from the input Parameters object (i.e. a list of serialized NumPy arrays) immediately after adding them to the record. + + Returns + ------- + ParametersRecord + The ParametersRecord containing the provided parameters. """ tensor_type = parameters.tensor_type diff --git a/src/py/flwr/common/retry_invoker.py b/src/py/flwr/common/retry_invoker.py index 303d5596f23..9785b0fbd9b 100644 --- a/src/py/flwr/common/retry_invoker.py +++ b/src/py/flwr/common/retry_invoker.py @@ -38,6 +38,11 @@ def exponential( Factor by which the delay is multiplied after each retry. max_delay: Optional[float] (default: None) The maximum delay duration between two consecutive retries. + + Returns + ------- + Generator[float, None, None] + A generator for the delay between 2 retries. """ delay = base_delay if max_delay is None else min(base_delay, max_delay) while True: @@ -56,6 +61,11 @@ def constant( ---------- interval: Union[float, Iterable[float]] (default: 1) A constant value to yield or an iterable of such values. + + Returns + ------- + Generator[float, None, None] + A generator for the delay between 2 retries. """ if not isinstance(interval, Iterable): interval = itertools.repeat(interval) @@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float: ---------- max_value : float The upper limit for the randomized value. + + Returns + ------- + float + A random float that is less than max_value. """ return random.uniform(0, max_value) diff --git a/src/py/flwr/server/client_manager.py b/src/py/flwr/server/client_manager.py index 175bd4a786e..9949e29f8f7 100644 --- a/src/py/flwr/server/client_manager.py +++ b/src/py/flwr/server/client_manager.py @@ -47,6 +47,7 @@ def register(self, client: ClientProxy) -> bool: Parameters ---------- client : flwr.server.client_proxy.ClientProxy + The ClientProxy of the Client to register. Returns ------- @@ -64,6 +65,7 @@ def unregister(self, client: ClientProxy) -> None: Parameters ---------- client : flwr.server.client_proxy.ClientProxy + The ClientProxy of the Client to unregister. """ @abstractmethod diff --git a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py index b161492000f..70283c0e129 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py +++ b/src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py @@ -174,7 +174,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments Parameters ---------- - servicer_and_add_fn : Tuple + servicer_and_add_fn : tuple A tuple holding a servicer implementation and a matching add_Servicer_to_server function. server_address : str @@ -214,6 +214,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments * CA certificate. * server certificate. * server private key. + interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None) + A list of gRPC interceptors. Returns ------- diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index f8ae5a7e95b..2c4519d8c14 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -116,6 +116,7 @@ def get_task_ins( # Return TaskIns return task_ins_list + # pylint: disable=R0911 def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: """Store one TaskRes.""" # Validate task @@ -129,6 +130,17 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: task_ins_id = task_res.task.ancestry[0] task_ins = self.task_ins_store.get(UUID(task_ins_id)) + # Ensure that the consumer_id of taskIns matches the producer_id of taskRes. + if ( + task_ins + and task_res + and not ( + task_ins.task.consumer.anonymous or task_res.task.producer.anonymous + ) + and task_ins.task.consumer.node_id != task_res.task.producer.node_id + ): + return None + if task_ins is None: log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id) return None diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index 502b1e2461b..182b55d6f77 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -151,6 +151,11 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]: ---------- log_queries : bool Log each query which is executed. + + Returns + ------- + list[tuple[str]] + The list of all tables in the DB. """ self.conn = sqlite3.connect(self.database_path) self.conn.execute("PRAGMA foreign_keys = ON;") @@ -390,6 +395,16 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: ) return None + # Ensure that the consumer_id of taskIns matches the producer_id of taskRes. + if ( + task_ins + and task_res + and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous) + and convert_sint64_to_uint64(task_ins["consumer_node_id"]) + != task_res.task.producer.node_id + ): + return None + # Fail if the TaskRes TTL exceeds the # expiration time of the TaskIns it replies to. # Condition: TaskIns.created_at + TaskIns.ttl ≥ diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 2a5eab30b4b..c3e0ac70d56 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """Tests all state implemenations have to conform to.""" -# pylint: disable=invalid-name, disable=R0904,R0913 +# pylint: disable=invalid-name, too-many-lines, R0904, R0913 import tempfile import time @@ -149,7 +149,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes and retrive it to mark it as delivered task_res_0 = create_task_res( - producer_node_id=100, + producer_node_id=consumer_node_id, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -160,7 +160,7 @@ def test_store_and_delete_tasks(self) -> None: # Insert one TaskRes, but don't retrive it task_res_1: TaskRes = create_task_res( - producer_node_id=100, + producer_node_id=consumer_node_id, anonymous=False, ancestry=[str(task_id_1)], run_id=run_id, @@ -662,7 +662,7 @@ def test_node_unavailable_error(self) -> None: # Create and store TaskRes task_res_0 = create_task_res( - producer_node_id=100, + producer_node_id=node_id_0, anonymous=False, ancestry=[str(task_id_0)], run_id=run_id, @@ -871,6 +871,32 @@ def test_get_task_res_return_if_not_expired(self) -> None: # Assert assert len(task_res_list) != 0 + def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: + """Test store_task_res to fail if there is a mismatch between the + consumer_node_id of taskIns and the producer_node_id of taskRes.""" + # Prepare + consumer_node_id = 1 + state = self.state_factory() + run_id = state.create_run(None, None, "9f86d08", {}) + task_ins = create_task_ins( + consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id + ) + + task_id = state.store_task_ins(task_ins=task_ins) + + task_res = create_task_res( + producer_node_id=100, # different than consumer_node_id + anonymous=False, + ancestry=[str(task_id)], + run_id=run_id, + ) + + # Execute + task_res_uuid = state.store_task_res(task_res=task_res) + + # Assert + assert task_res_uuid is None + def create_task_ins( consumer_node_id: int, diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py index 00ba02d2e43..db44719c6a8 100644 --- a/src/py/flwr/server/superlink/state/utils.py +++ b/src/py/flwr/server/superlink/state/utils.py @@ -100,11 +100,6 @@ def convert_uint64_values_in_dict_to_sint64( A dictionary where the values are integers to be converted. keys : list[str] A list of keys in the dictionary whose values need to be converted. - - Returns - ------- - None - This function does not return a value. It modifies `data_dict` in place. """ for key in keys: if key in data_dict: @@ -122,11 +117,6 @@ def convert_sint64_values_in_dict_to_uint64( A dictionary where the values are integers to be converted. keys : list[str] A list of keys in the dictionary whose values need to be converted. - - Returns - ------- - None - This function does not return a value. It modifies `data_dict` in place. """ for key in keys: if key in data_dict: