Skip to content

Commit

Permalink
fix(framework:skip) Check consumer_id and producer_id when saving…
Browse files Browse the repository at this point in the history
… TaskRes (#4313)
  • Loading branch information
mohammadnaseri authored Oct 9, 2024
1 parent 96e548f commit 334e534
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
12 changes: 12 additions & 0 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_task_ins(
# Return TaskIns
return task_ins_list

# pylint: disable=R0911
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
"""Store one TaskRes."""
# Validate task
Expand All @@ -129,6 +130,17 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
task_ins_id = task_res.task.ancestry[0]
task_ins = self.task_ins_store.get(UUID(task_ins_id))

# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
if (
task_ins
and task_res
and not (
task_ins.task.consumer.anonymous or task_res.task.producer.anonymous
)
and task_ins.task.consumer.node_id != task_res.task.producer.node_id
):
return None

if task_ins is None:
log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
return None
Expand Down
10 changes: 10 additions & 0 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,16 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
)
return None

# Ensure that the consumer_id of taskIns matches the producer_id of taskRes.
if (
task_ins
and task_res
and not (task_ins["consumer_anonymous"] or task_res.task.producer.anonymous)
and convert_sint64_to_uint64(task_ins["consumer_node_id"])
!= task_res.task.producer.node_id
):
return None

# Fail if the TaskRes TTL exceeds the
# expiration time of the TaskIns it replies to.
# Condition: TaskIns.created_at + TaskIns.ttl ≥
Expand Down
34 changes: 30 additions & 4 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Tests all state implemenations have to conform to."""
# pylint: disable=invalid-name, disable=R0904,R0913
# pylint: disable=invalid-name, too-many-lines, R0904, R0913

import tempfile
import time
Expand Down Expand Up @@ -149,7 +149,7 @@ def test_store_and_delete_tasks(self) -> None:

# Insert one TaskRes and retrive it to mark it as delivered
task_res_0 = create_task_res(
producer_node_id=100,
producer_node_id=consumer_node_id,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
Expand All @@ -160,7 +160,7 @@ def test_store_and_delete_tasks(self) -> None:

# Insert one TaskRes, but don't retrive it
task_res_1: TaskRes = create_task_res(
producer_node_id=100,
producer_node_id=consumer_node_id,
anonymous=False,
ancestry=[str(task_id_1)],
run_id=run_id,
Expand Down Expand Up @@ -662,7 +662,7 @@ def test_node_unavailable_error(self) -> None:

# Create and store TaskRes
task_res_0 = create_task_res(
producer_node_id=100,
producer_node_id=node_id_0,
anonymous=False,
ancestry=[str(task_id_0)],
run_id=run_id,
Expand Down Expand Up @@ -871,6 +871,32 @@ def test_get_task_res_return_if_not_expired(self) -> None:
# Assert
assert len(task_res_list) != 0

def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None:
"""Test store_task_res to fail if there is a mismatch between the
consumer_node_id of taskIns and the producer_node_id of taskRes."""
# Prepare
consumer_node_id = 1
state = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})
task_ins = create_task_ins(
consumer_node_id=consumer_node_id, anonymous=False, run_id=run_id
)

task_id = state.store_task_ins(task_ins=task_ins)

task_res = create_task_res(
producer_node_id=100, # different than consumer_node_id
anonymous=False,
ancestry=[str(task_id)],
run_id=run_id,
)

# Execute
task_res_uuid = state.store_task_res(task_res=task_res)

# Assert
assert task_res_uuid is None


def create_task_ins(
consumer_node_id: int,
Expand Down

0 comments on commit 334e534

Please sign in to comment.