diff --git a/docs/index.rst b/docs/index.rst index 91bae8a..6c2791d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -13,7 +13,7 @@ Pigeon is a combination of a `STOMP client ` .. autoclass:: pigeon.BaseMessage .. autoexception:: pigeon.exceptions.NoSuchTopicException -.. autoexception:: pigeon.exceptions.VersionMismatchException +.. autoexception:: pigeon.exceptions.SignatureException Indices and tables ================== diff --git a/examples/publisher.py b/examples/publisher.py index 371a3fe..7341af6 100644 --- a/examples/publisher.py +++ b/examples/publisher.py @@ -2,7 +2,7 @@ import time from pigeon.client import Pigeon -from pigeon.logging import setup_logging +from pigeon.utils import setup_logging from pigeon.base_msg import BaseMessage @@ -15,8 +15,8 @@ class TestMsg(BaseMessage): host = os.environ.get("ARTEMIS_HOST", "127.0.0.1") port = int(os.environ.get("ARTEMIS_PORT", 61616)) -connection = Pigeon("Publisher", host=host, port=port, logger=logger) -connection.register_topic("test", TestMsg, "1.0") +connection = Pigeon("Publisher", host=host, port=port, logger=logger, load_topics=False) +connection.register_topic("test", TestMsg) connection.connect(username="admin", password="password") while True: diff --git a/examples/subscriber.py b/examples/subscriber.py index dd0dd55..9f91491 100644 --- a/examples/subscriber.py +++ b/examples/subscriber.py @@ -1,7 +1,7 @@ import os from pigeon.client import Pigeon -from pigeon.logging import setup_logging +from pigeon.utils import setup_logging from pigeon.base_msg import BaseMessage logger = setup_logging("subscriber") @@ -18,8 +18,8 @@ def handle_test_message(topic, message): logger.info(f"Received {topic} message: {message}") -connection = Pigeon("Subscriber", host=host, port=port, logger=logger) -connection.register_topic("test", TestMsg, "1.0") +connection = Pigeon("Subscriber", host=host, port=port, logger=logger, load_topics=False) +connection.register_topic("test", TestMsg) connection.connect(username="admin", password="password") connection.subscribe("test", handle_test_message) diff --git a/pigeon/client.py b/pigeon/client.py index 1dda9d1..06f95d5 100644 --- a/pigeon/client.py +++ b/pigeon/client.py @@ -9,7 +9,7 @@ from pydantic import ValidationError from . import exceptions -from .utils import call_with_correct_args +from .utils import get_message_hash, call_with_correct_args def get_str_time_ms(): @@ -26,10 +26,9 @@ class Pigeon: in two ways. One is to use the register_topic(), or register_topics() methods. The other is to have message definitions in a Python package with an entry point defined in the pigeon.msgs group. This entry point - should provide a tuple containing a mapping of topics to Pydantic models, - and the message version. Topics defined in this manner will be - automatically discovered and loaded at runtime, unless this mechanism is - manually disabled. + should provide a tuple containing a mapping of topics to Pydantic models. + Topics defined in this manner will be automatically discovered and loaded at + runtime, unless this mechanism is manually disabled. """ def __init__( @@ -53,7 +52,7 @@ def __init__( self._service = service self._connection = stomp.Connection12([(host, port)], heartbeats=(10000, 10000)) self._topics = {} - self._msg_versions = {} + self._hashes = {} if load_topics: self._load_topics() self._callbacks: Dict[str, Callable] = {} @@ -72,26 +71,24 @@ def _load_topics(self): for entrypoint in entry_points(group="pigeon.msgs"): self.register_topics(*entrypoint.load()) - def register_topic(self, topic: str, msg_class: Callable, version: str): + def register_topic(self, topic: str, msg_class: Callable): """Register message definition for a given topic. Args: topic: The topic that this message definition applies to. msg_class: The Pydantic model definition of the message. - version: The version of the message. """ self._topics[topic] = msg_class - self._msg_versions[topic] = version + self._hashes[topic] = get_message_hash(msg_class) - def register_topics(self, topics: Dict[str, Callable], version: str): + def register_topics(self, topics: Dict[str, Callable]): """Register a number of message definitions for multiple topics. Args: topics: A mapping of topics to Pydantic model message definitions. - version: The version of these messages. """ for topic in topics.items(): - self.register_topic(*topic, version) + self.register_topic(*topic) def connect( self, @@ -143,26 +140,26 @@ def send(self, topic: str, **data): serialized_data = self._topics[topic](**data).serialize() headers = dict( service=self._service, - version=self._msg_versions[topic], + hash=self._hashes[topic], sent_at=get_str_time_ms(), ) self._connection.send(destination=topic, body=serialized_data, headers=headers) self._logger.debug(f"Sent data to {topic}: {serialized_data}") def _ensure_topic_exists(self, topic: str): - if topic not in self._topics or topic not in self._msg_versions: + if topic not in self._topics or topic not in self._hashes: raise exceptions.NoSuchTopicException(f"Topic {topic} not defined.") def _handle_message(self, message_frame: Frame): topic = message_frame.headers["subscription"] - if topic not in self._topics or topic not in self._msg_versions: + if topic not in self._topics or topic not in self._hashes: self._logger.warning( f"Received a message on an unregistered topic: {topic}" ) return - if message_frame.headers.get("version") != self._msg_versions.get(topic): + if message_frame.headers.get("hash") != self._hashes.get(topic): self._logger.warning( - f"Received a message on topic '{topic}' with an incorrect version {message_frame.headers.get('version')}. Version should be {self._msg_versions.get(topic)}" + f"Received a message on topic '{topic}' with an incorrect hash: {message_frame.headers.get('hash')}. Expected: {self._hashes.get(topic)}" ) return try: diff --git a/pigeon/utils.py b/pigeon/utils.py index 494fa51..725765b 100644 --- a/pigeon/utils.py +++ b/pigeon/utils.py @@ -1,6 +1,8 @@ import logging import inspect from copy import copy +import hashlib +from typing import Callable from .exceptions import SignatureException @@ -17,6 +19,12 @@ def setup_logging(logger_name: str, log_level: int = logging.INFO): return logger +def get_message_hash(msg_cls: Callable): + hash = hashlib.sha1() + hash.update(inspect.getsource(msg_cls).encode("utf8")) + return hash.hexdigest() + + def call_with_correct_args(func, *args, **kwargs): args = copy(args) kwargs = copy(kwargs) diff --git a/tests/test_client.py b/tests/test_client.py index 4890aec..7f2be09 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,9 +6,6 @@ from pigeon import BaseMessage -__version__ = "v1.2.3" - - class MockMessage(BaseMessage): field1: str @@ -18,9 +15,10 @@ def pigeon_client(): with patch("pigeon.utils.setup_logging") as mock_logging: topics = {"topic1": MockMessage} client = Pigeon( - "test", host="localhost", port=61613, logger=mock_logging.Logger() + "test", host="localhost", port=61613, logger=mock_logging.Logger(), + load_topics=False, ) - client.register_topics(topics, __version__) + client.register_topics(topics) yield client @@ -83,7 +81,7 @@ def test_connect_failure(pigeon_client, username, password): ) def test_send(pigeon_client, topic, data, expected_serialized_data): - expected_headers = {"service": "test", "version": __version__, "sent_at": "1"} + expected_headers = {"service": "test", "hash": pigeon_client._hashes["topic1"], "sent_at": "1"} # Arrange with patch("pigeon.client.time.time_ns", lambda: 1e6): pigeon_client._topics[topic] = MockMessage diff --git a/tests/test_message_handler.py b/tests/test_message_handler.py index 5f21def..c793fc9 100644 --- a/tests/test_message_handler.py +++ b/tests/test_message_handler.py @@ -22,7 +22,7 @@ def test_one_arg(pigeon_client): mock_stomp_message = MagicMock() mock_stomp_message.headers = { "subscription": "test.msg", - "version": "v1.2.3", + "hash": "abcd", } mock_message = MagicMock() @@ -32,7 +32,8 @@ def callback(msg): assert msg == mock_message.deserialize() pigeon_client._connection = MagicMock() - pigeon_client.register_topic("test.msg", mock_message, "v1.2.3") + pigeon_client._topics["test.msg"] = mock_message + pigeon_client._hashes["test.msg"] = "abcd" pigeon_client.subscribe("test.msg", callback) pigeon_client._handle_message(mock_stomp_message) @@ -42,7 +43,7 @@ def test_two_args(pigeon_client): mock_stomp_message = MagicMock() mock_stomp_message.headers = { "subscription": "test.msg", - "version": "v1.2.3", + "hash": "abcde", } mock_message = MagicMock() @@ -53,7 +54,8 @@ def callback(msg, topic): assert topic == "test.msg" pigeon_client._connection = MagicMock() - pigeon_client.register_topic("test.msg", mock_message, "v1.2.3") + pigeon_client._topics["test.msg"] = mock_message + pigeon_client._hashes["test.msg"] = "abcde" pigeon_client.subscribe("test.msg", callback) pigeon_client._handle_message(mock_stomp_message) @@ -63,7 +65,7 @@ def test_three_args(pigeon_client): mock_stomp_message = MagicMock() mock_stomp_message.headers = { "subscription": "test.msg", - "version": "v1.2.3", + "hash": "123abc", } mock_message = MagicMock() @@ -75,7 +77,8 @@ def callback(msg, topic, headers): assert headers == mock_stomp_message.headers pigeon_client._connection = MagicMock() - pigeon_client.register_topic("test.msg", mock_message, "v1.2.3") + pigeon_client._topics["test.msg"] = mock_message + pigeon_client._hashes["test.msg"] = "123abc" pigeon_client.subscribe("test.msg", callback) pigeon_client._handle_message(mock_stomp_message) @@ -85,7 +88,7 @@ def test_var_args(pigeon_client): mock_stomp_message = MagicMock() mock_stomp_message.headers = { "subscription": "test.msg", - "version": "v1.2.3", + "hash": "xyz987", } mock_message = MagicMock() @@ -98,7 +101,8 @@ def callback(*args): assert args[2] == mock_stomp_message.headers pigeon_client._connection = MagicMock() - pigeon_client.register_topic("test.msg", mock_message, "v1.2.3") + pigeon_client._topics["test.msg"] = mock_message + pigeon_client._hashes["test.msg"] = "xyz987" pigeon_client.subscribe("test.msg", callback) pigeon_client._handle_message(mock_stomp_message) @@ -118,25 +122,27 @@ def test_topic_does_not_exist(pigeon_client): ) -def test_version_mismatch(pigeon_client): - mock_message = create_mock_message(subscription="test", version="v0.1.1") +def test_hash_mismatch(pigeon_client): + mock_message = create_mock_message(subscription="test", hash="abc1") - pigeon_client.register_topic("test", lambda x: x, "v0.1.0") + pigeon_client._topics["test"] = None + pigeon_client._hashes["test"] = "abcd" pigeon_client._handle_message(mock_message) pigeon_client._logger.warning.assert_called_with( - "Received a message on topic 'test' with an incorrect version v0.1.1. Version should be v0.1.0" + "Received a message on topic 'test' with an incorrect hash: abc1. Expected: abcd" ) def test_validation_error(pigeon_client): - mock_message = create_mock_message(subscription="test", version="v0.1.0") + mock_message = create_mock_message(subscription="test", hash="abc123") mock_msg_def = MagicMock() mock_msg_def.deserialize.side_effect = ValidationError.from_exception_data( title="Test", line_errors=[] ) - pigeon_client.register_topic("test", mock_msg_def, "v0.1.0") + pigeon_client._topics["test"] = mock_msg_def + pigeon_client._hashes["test"] = "abc123" pigeon_client._handle_message(mock_message) pigeon_client._logger.warning.assert_called_with( @@ -145,9 +151,10 @@ def test_validation_error(pigeon_client): def test_no_callback(pigeon_client): - mock_message = create_mock_message(subscription="test", version="v0.1.0") + mock_message = create_mock_message(subscription="test", hash="4321") - pigeon_client.register_topic("test", MagicMock(), "v0.1.0") + pigeon_client._topics["test"] = MagicMock() + pigeon_client._hashes["test"] = "4321" pigeon_client._handle_message(mock_message) pigeon_client._logger.warning.assert_called_with( @@ -156,10 +163,11 @@ def test_no_callback(pigeon_client): def test_bad_signature(pigeon_client): - mock_message = create_mock_message(subscription="test", version="v0.1.0") + mock_message = create_mock_message(subscription="test", hash="lmnop") callback = lambda a, b, c, d: None - pigeon_client.register_topic("test", MagicMock(), "v0.1.0") + pigeon_client._topics["test"] = MagicMock() + pigeon_client._hashes["test"] = "lmnop" pigeon_client.subscribe("test", callback) pigeon_client._handle_message(mock_message) @@ -169,9 +177,10 @@ def test_bad_signature(pigeon_client): def test_callback_exception(pigeon_client): - mock_message = create_mock_message(subscription="test", version="v0.1.0") + mock_message = create_mock_message(subscription="test", hash="987654321") - pigeon_client.register_topic("test", MagicMock(), "v0.1.0") + pigeon_client._topics["test"] = MagicMock() + pigeon_client._hashes["test"] = "987654321" pigeon_client.subscribe( "test", MagicMock(side_effect=RecursionError("This is a test error.")) ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d86f341 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,10 @@ +from pigeon.utils import get_message_hash + + +def test_get_message_hash(): + class TestMsg: + attr_one: int + attr_two: str + attr_three: float + + assert get_message_hash(TestMsg) == "e6b05f8920682eca0ba8b415c9fa7a7f248ddfce"