Skip to content

Commit

Permalink
Merge branch 'main' into add-fab-hash-install
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Oct 10, 2024
2 parents 7ae9b60 + c684cf4 commit 96921ed
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 15 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,39 @@ jobs:
if grep -q "ERROR" flwr_output.log; then
exit 1
fi
build_and_install:
runs-on: ubuntu-22.04
timeout-minutes: 10
needs: wheel
strategy:
matrix:
framework: ["numpy"]
python-version: ["3.9", "3.10", "3.11"]

name: |
Build & Install /
Python ${{ matrix.python-version }} /
${{ matrix.framework }}
steps:
- uses: actions/checkout@v4
- name: Bootstrap
uses: ./.github/actions/bootstrap
with:
python-version: ${{ matrix.python-version }}
poetry-skip: 'true'
- name: Install Flower from repo
if: ${{ github.repository != 'adap/flower' || github.event.pull_request.head.repo.fork || github.actor == 'dependabot[bot]' }}
run: |
python -m pip install .
- name: Install Flower wheel from artifact store
if: ${{ github.repository == 'adap/flower' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]' }}
run: |
python -m pip install https://${{ env.ARTIFACT_BUCKET }}/py/${{ needs.wheel.outputs.dir }}/${{ needs.wheel.outputs.short_sha }}/${{ needs.wheel.outputs.whl_path }}
- name: Create project, build, and install it
run: |
flwr new tmp-${{ matrix.framework }} --framework ${{ matrix.framework }} --username gh_ci
cd tmp-${{ matrix.framework }}
flwr build
flwr install *.fab
4 changes: 4 additions & 0 deletions dev/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ echo "- docformatter: start"
python -m docformatter -c -r src/py/flwr e2e -e src/py/flwr/proto
echo "- docformatter: done"

echo "- docsig: start"
docsig src/py/flwr
echo "- docsig: done"

echo "- ruff: start"
python -m ruff check src/py/flwr
echo "- ruff: done"
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ pre-commit = "==3.5.0"
sphinx-substitution-extensions = "2022.02.16"
sphinxext-opengraph = "==0.9.1"
docstrfmt = { git = "https://github.com/charlesbvll/docstrfmt.git", branch = "patch-1" }
docsig = "==0.64.0"

[tool.docstrfmt]
extend_exclude = [
Expand Down Expand Up @@ -224,3 +225,7 @@ convention = "numpy"

[tool.ruff.per-file-ignores]
"src/py/flwr/server/strategy/*.py" = ["E501"]

[tool.docsig]
ignore-no-params = true
exclude = 'src/py/flwr/proto/.*|src/py/flwr/.*_test\.py|src/py/flwr/cli/new/templates/.*\.tpl'
10 changes: 10 additions & 0 deletions src/py/flwr/cli/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def load_and_validate(
) -> tuple[Optional[dict[str, Any]], list[str], list[str]]:
"""Load and validate pyproject.toml as dict.
Parameters
----------
path : Optional[Path] (default: None)
The path of the Flower App config file to load. By default it
will try to use `pyproject.toml` inside the current directory.
check_module: bool (default: True)
Whether the validity of the Python module should be checked.
This requires the project to be installed in the currently
running environment. True by default.
Returns
-------
Tuple[Optional[config], List[str], List[str]]
Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ class `flwr.client.Client` (default: None)
- 'grpc-bidi': gRPC, bidirectional streaming
- 'grpc-rere': gRPC, request-response (experimental)
- 'rest': HTTP (experimental)
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
Tuple containing the elliptic curve private key and public key for
authentication from the cryptography library.
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
Used to establish an authenticated connection with the server.
max_retries: Optional[int] (default: None)
The maximum number of times the client will try to connect to the
server before giving up in case of a connection error. If set to None,
Expand Down Expand Up @@ -249,6 +254,11 @@ class `flwr.client.Client` (default: None)
- 'grpc-bidi': gRPC, bidirectional streaming
- 'grpc-rere': gRPC, request-response (experimental)
- 'rest': HTTP (experimental)
authentication_keys : Optional[Tuple[PrivateKey, PublicKey]] (default: None)
Tuple containing the elliptic curve private key and public key for
authentication from the cryptography library.
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
Used to establish an authenticated connection with the server.
max_retries: Optional[int] (default: None)
The maximum number of times the client will try to connect to the
server before giving up in case of a connection error. If set to None,
Expand Down
3 changes: 3 additions & 0 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915
authentication from the cryptography library.
Source: https://cryptography.io/en/latest/hazmat/primitives/asymmetric/ec/
Used to establish an authenticated connection with the server.
adapter_cls: Optional[Union[type[FleetStub], type[GrpcAdapter]]] (default: None)
A GrpcStub Class that can be used to send messages. By default the FleetStub
will be used.
Returns
-------
Expand Down
5 changes: 5 additions & 0 deletions src/py/flwr/common/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,11 @@ def create_error_reply(self, error: Error, ttl: float | None = None) -> Message:
follows the equation:
ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at)
Returns
-------
message : Message
A Message containing only the relevant error and metadata.
"""
# If no TTL passed, use default for message creation (will update after
# message creation)
Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/common/recordset_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def parametersrecord_to_parameters(
keep_input : bool
A boolean indicating whether entries in the record should be deleted from the
input dictionary immediately after adding them to the record.
Returns
-------
parameters : Parameters
The parameters in the legacy format Parameters.
"""
parameters = Parameters(tensors=[], tensor_type="")

Expand Down Expand Up @@ -94,6 +99,11 @@ def parameters_to_parametersrecord(
A boolean indicating whether parameters should be deleted from the input
Parameters object (i.e. a list of serialized NumPy arrays) immediately after
adding them to the record.
Returns
-------
ParametersRecord
The ParametersRecord containing the provided parameters.
"""
tensor_type = parameters.tensor_type

Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/common/retry_invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def exponential(
Factor by which the delay is multiplied after each retry.
max_delay: Optional[float] (default: None)
The maximum delay duration between two consecutive retries.
Returns
-------
Generator[float, None, None]
A generator for the delay between 2 retries.
"""
delay = base_delay if max_delay is None else min(base_delay, max_delay)
while True:
Expand All @@ -56,6 +61,11 @@ def constant(
----------
interval: Union[float, Iterable[float]] (default: 1)
A constant value to yield or an iterable of such values.
Returns
-------
Generator[float, None, None]
A generator for the delay between 2 retries.
"""
if not isinstance(interval, Iterable):
interval = itertools.repeat(interval)
Expand All @@ -73,6 +83,11 @@ def full_jitter(max_value: float) -> float:
----------
max_value : float
The upper limit for the randomized value.
Returns
-------
float
A random float that is less than max_value.
"""
return random.uniform(0, max_value)

Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/client_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def register(self, client: ClientProxy) -> bool:
Parameters
----------
client : flwr.server.client_proxy.ClientProxy
The ClientProxy of the Client to register.
Returns
-------
Expand All @@ -64,6 +65,7 @@ def unregister(self, client: ClientProxy) -> None:
Parameters
----------
client : flwr.server.client_proxy.ClientProxy
The ClientProxy of the Client to unregister.
"""

@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion src/py/flwr/server/superlink/fleet/grpc_bidi/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
Parameters
----------
servicer_and_add_fn : Tuple
servicer_and_add_fn : tuple
A tuple holding a servicer implementation and a matching
add_Servicer_to_server function.
server_address : str
Expand Down Expand Up @@ -214,6 +214,8 @@ def generic_create_grpc_server( # pylint: disable=too-many-arguments
* CA certificate.
* server certificate.
* server private key.
interceptors : Optional[Sequence[grpc.ServerInterceptor]] (default: None)
A list of gRPC interceptors.
Returns
-------
Expand Down
12 changes: 12 additions & 0 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_task_ins(
# Return TaskIns
return task_ins_list

# pylint: disable=R0911
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
"""Store one TaskRes."""
# Validate task
Expand All @@ -129,6 +130,17 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
task_ins_id = task_res.task.ancestry[0]
task_ins = self.task_ins_store.get(UUID(task_ins_id))

# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
if (
task_ins
and task_res
and not (
task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
)
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
):
return None

if task_ins is None:
log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
return None
Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def initialize(self, log_queries: bool = False) -> list[tuple[str]]:
----------
log_queries : bool
Log each query which is executed.
Returns
-------
list[tuple[str]]
The list of all tables in the DB.
"""
self.conn = sqlite3.connect(self.database_path)
self.conn.execute("PRAGMA foreign_keys = ON;")
Expand Down Expand Up @@ -390,6 +395,16 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
)
return None

# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
if (
task_ins
and task_res
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
!= task_res.task.producer.node_id
):
return None

# Fail if the TaskRes TTL exceeds the
# expiration time of the TaskIns it replies to.
# Condition: TaskIns.created_at + TaskIns.ttl ≥
Expand Down
34 changes: 30 additions & 4 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests all state implemenations have to conform to."""
# pylint: disable=invalid-name, disable=R0904,R0913
# pylint: disable=invalid-name, too-many-lines, R0904, R0913

import tempfile
import time
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_store_and_delete_tasks(self) -> None:

# Insert one TaskRes and retrive it to mark it as delivered
task_res_0 = create_task_res(
producer_node_id=100,
producer_node_id=consumer_node_id,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
Expand All @@ -160,7 +160,7 @@ def test_store_and_delete_tasks(self) -> None:

# Insert one TaskRes, but don't retrive it
task_res_1: TaskRes = create_task_res(
producer_node_id=100,
producer_node_id=consumer_node_id,
anonymous=False,
ancestry=[str(task_id_1)],
run_id=run_id,
Expand Down Expand Up @@ -662,7 +662,7 @@ def test_node_unavailable_error(self) -> None:

# Create and store TaskRes
task_res_0 = create_task_res(
producer_node_id=100,
producer_node_id=node_id_0,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
Expand Down Expand Up @@ -871,6 +871,32 @@ def test_get_task_res_return_if_not_expired(self) -> None:
# Assert
assert len(task_res_list) != 0

def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None:
"""Test store_task_res to fail if there is a mismatch between the
consumer_node_id of taskIns and the producer_node_id of taskRes."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
)

task_id = state.store_task_ins(task_ins=task_ins)

task_res = create_task_res(
producer_node_id=100, # different than consumer_node_id
anonymous=False,
ancestry=[str(task_id)],
run_id=run_id,
)

# Execute
task_res_uuid = state.store_task_res(task_res=task_res)

# Assert
assert task_res_uuid is None


def create_task_ins(
consumer_node_id: int,
Expand Down
10 changes: 0 additions & 10 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ def convert_uint64_values_in_dict_to_sint64(
A dictionary where the values are integers to be converted.
keys : list[str]
A list of keys in the dictionary whose values need to be converted.
Returns
-------
None
This function does not return a value. It modifies `data_dict` in place.
"""
for key in keys:
if key in data_dict:
Expand All @@ -122,11 +117,6 @@ def convert_sint64_values_in_dict_to_uint64(
A dictionary where the values are integers to be converted.
keys : list[str]
A list of keys in the dictionary whose values need to be converted.
Returns
-------
None
This function does not return a value. It modifies `data_dict` in place.
"""
for key in keys:
if key in data_dict:
Expand Down

0 comments on commit 96921ed

Please sign in to comment.