diff --git a/README.rst b/README.rst index 27c3f18..a26bf08 100644 --- a/README.rst +++ b/README.rst @@ -339,6 +339,9 @@ Parameters | | | | are greater than this max default value are | | | | | capped at the default value | +----------------------------------+------------------+------------------------------------------------+ +| ``min_version`` | *(Optional)* | | Minimum acceptable version of Ursula. | +| | VersionString | | | ++----------------------------------+------------------+------------------------------------------------+ Returns diff --git a/porter/fields/base.py b/porter/fields/base.py index 68dc27e..1d0c65f 100644 --- a/porter/fields/base.py +++ b/porter/fields/base.py @@ -3,6 +3,7 @@ import click from marshmallow import fields +from packaging.version import parse from porter.fields.exceptions import InvalidInputData @@ -108,3 +109,12 @@ def _deserialize(self, value, attr, data, **kwargs): f"Unexpected object type, {type(result)}; expected {self.expected_type}") return result + + +class VersionString(String): + + def _validate(self, value): + try: + parse(value) + except Exception: + raise InvalidInputData(f"{self.name} must be a correct version.") diff --git a/porter/interfaces.py b/porter/interfaces.py index 0efc634..69aa027 100644 --- a/porter/interfaces.py +++ b/porter/interfaces.py @@ -40,6 +40,7 @@ def get_ursulas( include_ursulas: Optional[List[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Dict: ursulas_info = self.implementer.get_ursulas( quantity=quantity, @@ -47,6 +48,7 @@ def get_ursulas( include_ursulas=include_ursulas, timeout=timeout, duration=duration, + min_version=min_version, ) response_data = {"ursulas": ursulas_info} # list of UrsulaInfo objects @@ -104,6 +106,7 @@ def bucket_sampling( exclude_ursulas: Optional[List[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Dict: ursulas, block_number = self.implementer.bucket_sampling( quantity=quantity, @@ -111,6 +114,7 @@ def bucket_sampling( exclude_ursulas=exclude_ursulas, timeout=timeout, duration=duration, + min_version=min_version, ) response_data = {"ursulas": ursulas, "block_number": block_number} diff --git a/porter/main.py b/porter/main.py index 81e19fc..767a69a 100644 --- a/porter/main.py +++ b/porter/main.py @@ -32,6 +32,7 @@ TreasureMap, ) from nucypher_core.umbral import PublicKey +from packaging.version import Version, parse from prometheus_flask_exporter import PrometheusMetrics import porter @@ -100,6 +101,12 @@ class DecryptOutcome(NamedTuple): ] errors: Dict[ChecksumAddress, str] + class UrsulaVersionTooOld(Exception): + def __init__(self, ursula_address: str, version: str, min_version: str): + super().__init__( + f"Ursula ({ursula_address}) version is too old ({version} < {min_version})" + ) + def __init__( self, eth_endpoint: str, @@ -155,6 +162,16 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): ): BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint) + @staticmethod + def _is_version_greater_or_equal(min_version: Version, version: str) -> bool: + return parse(version) >= min_version + + def _get_ursula_version(self, ursula: Ursula) -> str: + response = self.network_middleware.client.get( + node_or_sprout=ursula, path="status", params={"json": "true"} + ) + return response.json()["version"] + def get_ursulas( self, quantity: int, @@ -162,11 +179,13 @@ def get_ursulas( include_ursulas: Optional[Sequence[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> List[UrsulaInfo]: timeout = self._configure_timeout( "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration) available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir) @@ -184,11 +203,18 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: ursula_address = to_checksum_address(ursula_address) ursula = self.known_nodes[ursula_address] try: - # ensure node is up and reachable - self.network_middleware.ping(ursula) - return Porter.UrsulaInfo(checksum_address=ursula_address, - uri=f"{ursula.rest_interface.formal_uri}", - encrypting_key=ursula.public_keys(DecryptingPower)) + # ensure node is up and reachable and possibly check version + version = self._get_ursula_version(ursula) + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version + ): + raise self.UrsulaVersionTooOld(ursula_address, version, min_version) + + return Porter.UrsulaInfo( + checksum_address=ursula_address, + uri=f"{ursula.rest_interface.formal_uri}", + encrypting_key=ursula.public_keys(DecryptingPower), + ) except Exception as e: self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}") raise @@ -299,11 +325,13 @@ def bucket_sampling( exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Tuple[List[ChecksumAddress], int]: timeout = self._configure_timeout( "bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING: raise ValueError("Bucket sampling is only for TACo Mainnet") @@ -364,7 +392,10 @@ def __init__(self, _reservoir, need_successes: int): self.reservoir = _reservoir self.need_successes = need_successes self.predefined_buckets = self.read_buckets() - self.bucketed_nodes = defaultdict(list) + self.bucketed_nodes = defaultdict( + list + ) # -> + self.selected_nodes = dict() # -> def read_buckets(self) -> Dict: try: @@ -391,6 +422,11 @@ def find_bucket(self, node): return bucket_name return None + def mark_as_not_successful(self, unsuccessful_node: ChecksumAddress): + bucket = self.selected_nodes.get(unsuccessful_node) + if bucket: + self.bucketed_nodes[bucket].remove(unsuccessful_node) + def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: batch = [] batch_size = self.need_successes - _successes @@ -403,6 +439,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP: continue self.bucketed_nodes[bucket].append(selected) + self.selected_nodes[selected] = bucket batch.append(selected) if not batch: return None @@ -417,12 +454,18 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: ursula_address = to_checksum_address(ursula_address) ursula = self.known_nodes[ursula_address] try: - # ensure node is up and reachable - self.network_middleware.ping(ursula) + # ensure node is up and reachable and possibly check version + version = self._get_ursula_version(ursula) + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version + ): + raise self.UrsulaVersionTooOld(ursula_address, version, min_version) + return ursula_address except Exception as e: message = f"Ursula ({ursula_address}) is unreachable: {str(e)}" self.log.debug(message) + value_factory.mark_as_not_successful(ursula_address) raise self.block_until_number_of_known_nodes_is( diff --git a/porter/schema.py b/porter/schema.py index b64c0cc..17b2eac 100644 --- a/porter/schema.py +++ b/porter/schema.py @@ -9,6 +9,7 @@ NonNegativeInteger, PositiveInteger, StringList, + VersionString, ) from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData from porter.fields.retrieve import CapsuleFrag, RetrievalKit @@ -125,6 +126,18 @@ class GetUrsulas(BaseSchema): ), ) + min_version = VersionString( + required=False, + load_only=True, + click=click.option( + "--min-version", + "-mv", + help="Minimum acceptable version of Ursula", + type=click.STRING, + required=False, + ), + ) + # output ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True) @@ -369,6 +382,18 @@ class BucketSampling(BaseSchema): ), ) + min_version = VersionString( + required=False, + load_only=True, + click=click.option( + "--min-version", + "-mv", + help="Minimum acceptable version of Ursula", + type=click.STRING, + required=False, + ), + ) + # output ursulas = marshmallow_fields.List(UrsulaChecksumAddress, dump_only=True) block_number = marshmallow_fields.Int(dump_only=True) diff --git a/tests/conftest.py b/tests/conftest.py index 1763e9e..6c47cc2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,9 +34,11 @@ from tests.constants import ( MOCK_ETH_PROVIDER_URI, TEMPORARY_DOMAIN, + TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID, ) from tests.mock.interfaces import MockBlockchain +from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient from tests.utils.registry import MockRegistrySource, mock_registry_sources # Crash on server error by default @@ -245,6 +247,50 @@ def mock_signer(get_random_checksum_address): return signer +class _MockMiddlewareClient(_TestMiddlewareClient): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ursulas_versions = {} + + def get(self, *args, **kwargs): + if kwargs.get("path") == "status" and kwargs.get("params")["json"]: + node_address = kwargs.get("node_or_sprout").checksum_address + version = self.ursulas_versions.get(node_address, "1.1.1") + return _MockMiddlewareClient.MockResponse({"version": version}, 200) + + real_get = super(_TestMiddlewareClient, self).__getattr__("get") + return real_get(*args, **kwargs) + + +class _MockRestMiddleware(MockRestMiddleware): + """ + Modified middleware to emulate returning status with version. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI) + + def set_ursulas_versions(self, ursulas_versions: dict): + self.client.ursulas_versions = dict(ursulas_versions) + + def clean_ursulas_versions(self): + self.client.ursulas_versions = {} + + +@pytest.fixture(scope="module") +def mock_rest_middleware(): + return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI) + + @pytest.fixture(scope="module") @pytest.mark.usefixtures('testerchain', 'agency') def porter(ursulas, mock_rest_middleware, test_registry): diff --git a/tests/test_bucket_sampling.py b/tests/test_bucket_sampling.py index 4d63e44..7b257d2 100644 --- a/tests/test_bucket_sampling.py +++ b/tests/test_bucket_sampling.py @@ -23,6 +23,16 @@ def json(self): mocker.patch("requests.get", return_value=MockRequestResponse()) +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(monkeypatch): + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + + monkeypatch.setattr(WorkerPool, "_sleep", _sleep) + + def test_bucket_sampling_schema(get_random_checksum_address): # # Input i.e. load @@ -81,6 +91,11 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["timeout"] = 20 BucketSampling().load(updated_data) + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + BucketSampling().load(updated_data) + # list input formatted as ',' separated strings updated_data = dict(required_data) updated_data["exclude_ursulas"] = ",".join(exclude_ursulas) @@ -133,6 +148,18 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["duration"] = -1 BucketSampling().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + BucketSampling().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + BucketSampling().load(updated_data) + # # Output i.e. dump # @@ -210,11 +237,22 @@ def test_bucket_sampling_python_interface( with pytest.raises(WorkerPool.OutOfValues): _, _ = porter.bucket_sampling(quantity=5) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + _, _ = porter.bucket_sampling(quantity=1, timeout=30, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + ursulas_info, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2") + assert ursulas_info[0] == sampled_ursulas[0] + with pytest.raises(WorkerPool.OutOfValues): + porter.bucket_sampling(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10]) @pytest.mark.parametrize("random_seed", [None, 42]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_bucket_sampling_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -310,3 +348,36 @@ def test_bucket_sampling_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) + + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0] == sampled_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) + porter.network_middleware.clean_ursulas_versions() diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index bedaa34..294100b 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -1,6 +1,7 @@ import json import pytest +from nucypher.utilities.concurrency import WorkerPool from nucypher_core.umbral import SecretKey from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData @@ -8,6 +9,16 @@ from porter.schema import GetUrsulas, UrsulaInfoSchema +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(monkeypatch): + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + + monkeypatch.setattr(WorkerPool, "_sleep", _sleep) + + def test_get_ursulas_schema(get_random_checksum_address): # # Input i.e. load @@ -88,6 +99,11 @@ def test_get_ursulas_schema(get_random_checksum_address): assert data["exclude_ursulas"] == [exclude_ursulas[0]] assert data["include_ursulas"] == [include_ursulas[0]] + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + GetUrsulas().load(updated_data) + # invalid include entry updated_data = dict(required_data) updated_data["exclude_ursulas"] = exclude_ursulas @@ -171,6 +187,18 @@ def test_get_ursulas_schema(get_random_checksum_address): updated_data["duration"] = -1 GetUrsulas().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + GetUrsulas().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + GetUrsulas().load(updated_data) + # # Output i.e. dump # @@ -282,10 +310,23 @@ def test_get_ursulas_python_interface( with pytest.raises(ValueError, match="Insufficient nodes"): porter.get_ursulas(quantity=len(ursulas) + 1) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=1, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions( + {ursulas[0].checksum_address: "3.0.0"} + ) + ursulas_info = porter.get_ursulas(quantity=1, min_version="2.2.2") + assert ursulas[0].checksum_address == ursulas_info[0].checksum_address + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -388,3 +429,37 @@ def test_get_ursulas_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + del failed_ursula_params["include_ursulas"] + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) + + porter.network_middleware.set_ursulas_versions({include_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0]["checksum_address"] == include_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) + porter.network_middleware.clean_ursulas_versions()