Skip to content

Commit

Permalink
Merge pull request #74 from vzotova/min-version
Browse files Browse the repository at this point in the history
Add node version filtering for sampling
  • Loading branch information
KPrasch authored Aug 12, 2024
2 parents 07f7da5 + 702334a commit 66938c7
Show file tree
Hide file tree
Showing 8 changed files with 285 additions and 8 deletions.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ Parameters
| | | | are greater than this max default value are |
| | | | capped at the default value |
+----------------------------------+------------------+------------------------------------------------+
| ``min_version`` | *(Optional)* | | Minimum acceptable version of Ursula. |
| | VersionString | | |
+----------------------------------+------------------+------------------------------------------------+


Returns
Expand Down
10 changes: 10 additions & 0 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import click
from marshmallow import fields
from packaging.version import parse

from porter.fields.exceptions import InvalidInputData

Expand Down Expand Up @@ -108,3 +109,12 @@ def _deserialize(self, value, attr, data, **kwargs):
f"Unexpected object type, {type(result)}; expected {self.expected_type}")

return result


class VersionString(String):

def _validate(self, value):
try:
parse(value)
except Exception:
raise InvalidInputData(f"{self.name} must be a correct version.")
4 changes: 4 additions & 0 deletions porter/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ def get_ursulas(
include_ursulas: Optional[List[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Dict:
ursulas_info = self.implementer.get_ursulas(
quantity=quantity,
exclude_ursulas=exclude_ursulas,
include_ursulas=include_ursulas,
timeout=timeout,
duration=duration,
min_version=min_version,
)

response_data = {"ursulas": ursulas_info} # list of UrsulaInfo objects
Expand Down Expand Up @@ -104,13 +106,15 @@ def bucket_sampling(
exclude_ursulas: Optional[List[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Dict:
ursulas, block_number = self.implementer.bucket_sampling(
quantity=quantity,
random_seed=random_seed,
exclude_ursulas=exclude_ursulas,
timeout=timeout,
duration=duration,
min_version=min_version,
)

response_data = {"ursulas": ursulas, "block_number": block_number}
Expand Down
59 changes: 51 additions & 8 deletions porter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
TreasureMap,
)
from nucypher_core.umbral import PublicKey
from packaging.version import Version, parse
from prometheus_flask_exporter import PrometheusMetrics

import porter
Expand Down Expand Up @@ -100,6 +101,12 @@ class DecryptOutcome(NamedTuple):
]
errors: Dict[ChecksumAddress, str]

class UrsulaVersionTooOld(Exception):
def __init__(self, ursula_address: str, version: str, min_version: str):
super().__init__(
f"Ursula ({ursula_address}) version is too old ({version} < {min_version})"
)

def __init__(
self,
eth_endpoint: str,
Expand Down Expand Up @@ -155,18 +162,30 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str):
):
BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint)

@staticmethod
def _is_version_greater_or_equal(min_version: Version, version: str) -> bool:
return parse(version) >= min_version

def _get_ursula_version(self, ursula: Ursula) -> str:
response = self.network_middleware.client.get(
node_or_sprout=ursula, path="status", params={"json": "true"}
)
return response.json()["version"]

def get_ursulas(
self,
quantity: int,
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
include_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> List[UrsulaInfo]:
timeout = self._configure_timeout(
"sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration)
available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir)
Expand All @@ -184,11 +203,18 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo:
ursula_address = to_checksum_address(ursula_address)
ursula = self.known_nodes[ursula_address]
try:
# ensure node is up and reachable
self.network_middleware.ping(ursula)
return Porter.UrsulaInfo(checksum_address=ursula_address,
uri=f"{ursula.rest_interface.formal_uri}",
encrypting_key=ursula.public_keys(DecryptingPower))
# ensure node is up and reachable and possibly check version
version = self._get_ursula_version(ursula)
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise self.UrsulaVersionTooOld(ursula_address, version, min_version)

return Porter.UrsulaInfo(
checksum_address=ursula_address,
uri=f"{ursula.rest_interface.formal_uri}",
encrypting_key=ursula.public_keys(DecryptingPower),
)
except Exception as e:
self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}")
raise
Expand Down Expand Up @@ -299,11 +325,13 @@ def bucket_sampling(
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
duration: Optional[int] = None,
min_version: Optional[str] = None,
) -> Tuple[List[ChecksumAddress], int]:
timeout = self._configure_timeout(
"bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING:
raise ValueError("Bucket sampling is only for TACo Mainnet")
Expand Down Expand Up @@ -364,7 +392,10 @@ def __init__(self, _reservoir, need_successes: int):
self.reservoir = _reservoir
self.need_successes = need_successes
self.predefined_buckets = self.read_buckets()
self.bucketed_nodes = defaultdict(list)
self.bucketed_nodes = defaultdict(
list
) # <bucket> -> <list of checksum addresses>
self.selected_nodes = dict() # <checksum address> -> <bucket>

def read_buckets(self) -> Dict:
try:
Expand All @@ -391,6 +422,11 @@ def find_bucket(self, node):
return bucket_name
return None

def mark_as_not_successful(self, unsuccessful_node: ChecksumAddress):
bucket = self.selected_nodes.get(unsuccessful_node)
if bucket:
self.bucketed_nodes[bucket].remove(unsuccessful_node)

def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
batch = []
batch_size = self.need_successes - _successes
Expand All @@ -403,6 +439,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP:
continue
self.bucketed_nodes[bucket].append(selected)
self.selected_nodes[selected] = bucket
batch.append(selected)
if not batch:
return None
Expand All @@ -417,12 +454,18 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress:
ursula_address = to_checksum_address(ursula_address)
ursula = self.known_nodes[ursula_address]
try:
# ensure node is up and reachable
self.network_middleware.ping(ursula)
# ensure node is up and reachable and possibly check version
version = self._get_ursula_version(ursula)
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise self.UrsulaVersionTooOld(ursula_address, version, min_version)

return ursula_address
except Exception as e:
message = f"Ursula ({ursula_address}) is unreachable: {str(e)}"
self.log.debug(message)
value_factory.mark_as_not_successful(ursula_address)
raise

self.block_until_number_of_known_nodes_is(
Expand Down
25 changes: 25 additions & 0 deletions porter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
NonNegativeInteger,
PositiveInteger,
StringList,
VersionString,
)
from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData
from porter.fields.retrieve import CapsuleFrag, RetrievalKit
Expand Down Expand Up @@ -125,6 +126,18 @@ class GetUrsulas(BaseSchema):
),
)

min_version = VersionString(
required=False,
load_only=True,
click=click.option(
"--min-version",
"-mv",
help="Minimum acceptable version of Ursula",
type=click.STRING,
required=False,
),
)

# output
ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True)

Expand Down Expand Up @@ -369,6 +382,18 @@ class BucketSampling(BaseSchema):
),
)

min_version = VersionString(
required=False,
load_only=True,
click=click.option(
"--min-version",
"-mv",
help="Minimum acceptable version of Ursula",
type=click.STRING,
required=False,
),
)

# output
ursulas = marshmallow_fields.List(UrsulaChecksumAddress, dump_only=True)
block_number = marshmallow_fields.Int(dump_only=True)
46 changes: 46 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from tests.constants import (
MOCK_ETH_PROVIDER_URI,
TEMPORARY_DOMAIN,
TEST_ETH_PROVIDER_URI,
TESTERCHAIN_CHAIN_ID,
)
from tests.mock.interfaces import MockBlockchain
from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient
from tests.utils.registry import MockRegistrySource, mock_registry_sources

# Crash on server error by default
Expand Down Expand Up @@ -245,6 +247,50 @@ def mock_signer(get_random_checksum_address):
return signer


class _MockMiddlewareClient(_TestMiddlewareClient):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ursulas_versions = {}

def get(self, *args, **kwargs):
if kwargs.get("path") == "status" and kwargs.get("params")["json"]:
node_address = kwargs.get("node_or_sprout").checksum_address
version = self.ursulas_versions.get(node_address, "1.1.1")
return _MockMiddlewareClient.MockResponse({"version": version}, 200)

real_get = super(_TestMiddlewareClient, self).__getattr__("get")
return real_get(*args, **kwargs)


class _MockRestMiddleware(MockRestMiddleware):
"""
Modified middleware to emulate returning status with version.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI)

def set_ursulas_versions(self, ursulas_versions: dict):
self.client.ursulas_versions = dict(ursulas_versions)

def clean_ursulas_versions(self):
self.client.ursulas_versions = {}


@pytest.fixture(scope="module")
def mock_rest_middleware():
return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI)


@pytest.fixture(scope="module")
@pytest.mark.usefixtures('testerchain', 'agency')
def porter(ursulas, mock_rest_middleware, test_registry):
Expand Down
Loading

0 comments on commit 66938c7

Please sign in to comment.