Skip to content

Commit

Permalink
Changing mocking in tests for RestMiddleware and testing versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
vzotova committed Aug 8, 2024
1 parent f735076 commit 43947dc
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 13 deletions.
16 changes: 6 additions & 10 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import click
from marshmallow import fields
from packaging.version import Version, parse
from packaging.version import parse

from porter.fields.exceptions import InvalidInputData

Expand Down Expand Up @@ -113,12 +113,8 @@ def _deserialize(self, value, attr, data, **kwargs):

class VersionString(String):

def _serialize(self, value, attr, obj, **kwargs) -> str:
if type(value) is not Version:
raise InvalidInputData(
f"Unexpected object type, {type(value)}; expected Version"
)
return str(value)

def _deserialize(self, value, attr, data, **kwargs) -> list:
return parse(value)
def _validate(self, value):
try:
parse(value)
except Exception:
raise InvalidInputData(f"{self.name} must be a correct version.")
2 changes: 1 addition & 1 deletion porter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str):

@staticmethod
def _is_version_greater_or_equal(min_version: str, version: str) -> bool:
return parse(version) >= min_version
return parse(version) >= parse(min_version)

def _get_ursula_version(self, ursula: Ursula) -> str:
response = self.network_middleware.client.get(
Expand Down
49 changes: 47 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from tests.constants import (
MOCK_ETH_PROVIDER_URI,
TEMPORARY_DOMAIN,
TEST_ETH_PROVIDER_URI,
TESTERCHAIN_CHAIN_ID,
)
from tests.mock.interfaces import MockBlockchain
from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient
from tests.utils.registry import MockRegistrySource, mock_registry_sources

# Crash on server error by default
Expand Down Expand Up @@ -245,9 +247,53 @@ def mock_signer(get_random_checksum_address):
return signer


class _MockMiddlewareClient(_TestMiddlewareClient):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ursulas_versions = {}

def get(self, *args, **kwargs):
if kwargs.get("path") == "status" and kwargs.get("params")["json"]:
node_address = kwargs.get("node_or_sprout").checksum_address
version = self.ursulas_versions.get(node_address, "1.1.1")
return _MockMiddlewareClient.MockResponse({"version": version}, 200)

real_get = super(_TestMiddlewareClient, self).__getattr__("get")
return real_get(*args, **kwargs)


class _MockRestMiddleware(MockRestMiddleware):
"""
Modified middleware to emulate returning status with version.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI)

def set_ursulas_versions(self, ursulas_versions: dict):
self.client.ursulas_versions = ursulas_versions

def clean_ursulas_versions(self):
self.client.ursulas_versions = {}


@pytest.fixture(scope="module")
def mock_rest_middleware():
return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI)


@pytest.fixture(scope="module")
@pytest.mark.usefixtures('testerchain', 'agency')
def porter(ursulas, mock_rest_middleware, test_registry, module_mocker):
def porter(ursulas, mock_rest_middleware, test_registry):
porter = Porter(
domain=TEMPORARY_DOMAIN,
eth_endpoint=MOCK_ETH_PROVIDER_URI,
Expand All @@ -259,7 +305,6 @@ def porter(ursulas, mock_rest_middleware, test_registry, module_mocker):
verify_node_bonding=False,
network_middleware=mock_rest_middleware,
)
module_mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0")
yield porter
porter.stop_learning_loop()

Expand Down
57 changes: 57 additions & 0 deletions tests/test_bucket_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def test_bucket_sampling_schema(get_random_checksum_address):
updated_data["timeout"] = 20
BucketSampling().load(updated_data)

# min version
updated_data = dict(required_data)
updated_data["min_version"] = "1.1.1"
BucketSampling().load(updated_data)

# list input formatted as ',' separated strings
updated_data = dict(required_data)
updated_data["exclude_ursulas"] = ",".join(exclude_ursulas)
Expand Down Expand Up @@ -133,6 +138,18 @@ def test_bucket_sampling_schema(get_random_checksum_address):
updated_data["duration"] = -1
BucketSampling().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "v1x1.1"
BucketSampling().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "1-1-1"
BucketSampling().load(updated_data)

#
# Output i.e. dump
#
Expand Down Expand Up @@ -210,11 +227,22 @@ def test_bucket_sampling_python_interface(
with pytest.raises(WorkerPool.OutOfValues):
_, _ = porter.bucket_sampling(quantity=5)

# no nodes with specified version
# with pytest.raises(WorkerPool.OutOfValues):
# _, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2")
# porter.network_middleware.set_ursulas_versions({sampled_ursulas[1]: "3.0.0"})
# ursulas_info, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2")
# assert ursulas_info[0] == sampled_ursulas[1]
# with pytest.raises(WorkerPool.OutOfValues):
# porter.bucket_sampling(quantity=2, min_version="2.2.2")
# porter.network_middleware.clean_ursulas_versions()


@pytest.mark.parametrize("timeout", [None, 10])
@pytest.mark.parametrize("random_seed", [None, 42])
@pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365])
def test_bucket_sampling_web_interface(
porter,
porter_web_controller,
ursulas,
timeout,
Expand Down Expand Up @@ -310,3 +338,32 @@ def test_bucket_sampling_web_interface(
)
assert response.status_code == 400
assert "Insufficient nodes" in response.text

#
# Failure case: no nodes with specified version
#
failed_ursula_params = dict(get_ursulas_params)
failed_ursula_params["quantity"] = 1
failed_ursula_params["min_version"] = "2.0.0"
# del failed_ursula_params["include_ursulas"]
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text

porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"})
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
# assert response.status_code == 200
response_data = json.loads(response.data)
print(response_data)
ursulas_info = response_data["result"]["ursulas"]
assert ursulas_info[0] == sampled_ursulas[0]

failed_ursula_params["quantity"] = 2
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text
porter.network_middleware.clean_ursulas_versions()
59 changes: 59 additions & 0 deletions tests/test_get_ursulas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import pytest
from nucypher.utilities.concurrency import WorkerPool
from nucypher_core.umbral import SecretKey

from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData
Expand Down Expand Up @@ -88,6 +89,11 @@ def test_get_ursulas_schema(get_random_checksum_address):
assert data["exclude_ursulas"] == [exclude_ursulas[0]]
assert data["include_ursulas"] == [include_ursulas[0]]

# min version
updated_data = dict(required_data)
updated_data["min_version"] = "1.1.1"
GetUrsulas().load(updated_data)

# invalid include entry
updated_data = dict(required_data)
updated_data["exclude_ursulas"] = exclude_ursulas
Expand Down Expand Up @@ -171,6 +177,18 @@ def test_get_ursulas_schema(get_random_checksum_address):
updated_data["duration"] = -1
GetUrsulas().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "v1x1.1"
GetUrsulas().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "1-1-1"
GetUrsulas().load(updated_data)

#
# Output i.e. dump
#
Expand Down Expand Up @@ -282,10 +300,23 @@ def test_get_ursulas_python_interface(
with pytest.raises(ValueError, match="Insufficient nodes"):
porter.get_ursulas(quantity=len(ursulas) + 1)

# no nodes with specified version
with pytest.raises(WorkerPool.OutOfValues):
porter.get_ursulas(quantity=1, min_version="2.2.2")
porter.network_middleware.set_ursulas_versions(
{ursulas[0].checksum_address: "3.0.0"}
)
ursulas_info = porter.get_ursulas(quantity=1, min_version="2.2.2")
assert ursulas[0].checksum_address == ursulas_info[0].checksum_address
with pytest.raises(WorkerPool.OutOfValues):
porter.get_ursulas(quantity=2, min_version="2.2.2")
porter.network_middleware.clean_ursulas_versions()


@pytest.mark.parametrize("timeout", [None, 10, 20])
@pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365])
def test_get_ursulas_web_interface(
porter,
porter_web_controller,
ursulas,
timeout,
Expand Down Expand Up @@ -388,3 +419,31 @@ def test_get_ursulas_web_interface(
)
assert response.status_code == 400
assert "Insufficient nodes" in response.text

#
# Failure case: no nodes with specified version
#
failed_ursula_params = dict(get_ursulas_params)
failed_ursula_params["quantity"] = 1
failed_ursula_params["min_version"] = "2.0.0"
del failed_ursula_params["include_ursulas"]
response = porter_web_controller.get(
"/get_ursulas", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text

porter.network_middleware.set_ursulas_versions({include_ursulas[0]: "3.0.0"})
response = porter_web_controller.get(
"/get_ursulas", data=json.dumps(failed_ursula_params)
)
assert response.status_code == 200
response_data = json.loads(response.data)
ursulas_info = response_data["result"]["ursulas"]
assert ursulas_info[0]["checksum_address"] == include_ursulas[0]

failed_ursula_params["quantity"] = 2
response = porter_web_controller.get(
"/get_ursulas", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text
porter.network_middleware.clean_ursulas_versions()

0 comments on commit 43947dc

Please sign in to comment.