Skip to content

Commit

Permalink
implemented a message definition hashing system
Browse files Browse the repository at this point in the history
  • Loading branch information
CameronDevine committed Sep 18, 2024
1 parent 642b809 commit fb34085
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 50 deletions.
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Pigeon is a combination of a `STOMP client <https://pypi.org/project/stomp-py/>`
.. autoclass:: pigeon.BaseMessage

.. autoexception:: pigeon.exceptions.NoSuchTopicException
.. autoexception:: pigeon.exceptions.VersionMismatchException
.. autoexception:: pigeon.exceptions.SignatureException

Indices and tables
==================
Expand Down
6 changes: 3 additions & 3 deletions examples/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time

from pigeon.client import Pigeon
from pigeon.logging import setup_logging
from pigeon.utils import setup_logging
from pigeon.base_msg import BaseMessage


Expand All @@ -15,8 +15,8 @@ class TestMsg(BaseMessage):
host = os.environ.get("ARTEMIS_HOST", "127.0.0.1")
port = int(os.environ.get("ARTEMIS_PORT", 61616))

connection = Pigeon("Publisher", host=host, port=port, logger=logger)
connection.register_topic("test", TestMsg, "1.0")
connection = Pigeon("Publisher", host=host, port=port, logger=logger, load_topics=False)
connection.register_topic("test", TestMsg)
connection.connect(username="admin", password="password")

while True:
Expand Down
6 changes: 3 additions & 3 deletions examples/subscriber.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from pigeon.client import Pigeon
from pigeon.logging import setup_logging
from pigeon.utils import setup_logging
from pigeon.base_msg import BaseMessage

logger = setup_logging("subscriber")
Expand All @@ -18,8 +18,8 @@ def handle_test_message(topic, message):
logger.info(f"Received {topic} message: {message}")


connection = Pigeon("Subscriber", host=host, port=port, logger=logger)
connection.register_topic("test", TestMsg, "1.0")
connection = Pigeon("Subscriber", host=host, port=port, logger=logger, load_topics=False)
connection.register_topic("test", TestMsg)
connection.connect(username="admin", password="password")
connection.subscribe("test", handle_test_message)

Expand Down
31 changes: 14 additions & 17 deletions pigeon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import ValidationError

from . import exceptions
from .utils import call_with_correct_args
from .utils import get_message_hash, call_with_correct_args


def get_str_time_ms():
Expand All @@ -26,10 +26,9 @@ class Pigeon:
in two ways. One is to use the register_topic(), or register_topics()
methods. The other is to have message definitions in a Python package
with an entry point defined in the pigeon.msgs group. This entry point
should provide a tuple containing a mapping of topics to Pydantic models,
and the message version. Topics defined in this manner will be
automatically discovered and loaded at runtime, unless this mechanism is
manually disabled.
should provide a tuple containing a mapping of topics to Pydantic models.
Topics defined in this manner will be automatically discovered and loaded at
runtime, unless this mechanism is manually disabled.
"""

def __init__(
Expand All @@ -53,7 +52,7 @@ def __init__(
self._service = service
self._connection = stomp.Connection12([(host, port)], heartbeats=(10000, 10000))
self._topics = {}
self._msg_versions = {}
self._hashes = {}
if load_topics:
self._load_topics()
self._callbacks: Dict[str, Callable] = {}
Expand All @@ -72,26 +71,24 @@ def _load_topics(self):
for entrypoint in entry_points(group="pigeon.msgs"):
self.register_topics(*entrypoint.load())

def register_topic(self, topic: str, msg_class: Callable, version: str):
def register_topic(self, topic: str, msg_class: Callable):
"""Register message definition for a given topic.
Args:
topic: The topic that this message definition applies to.
msg_class: The Pydantic model definition of the message.
version: The version of the message.
"""
self._topics[topic] = msg_class
self._msg_versions[topic] = version
self._hashes[topic] = get_message_hash(msg_class)

def register_topics(self, topics: Dict[str, Callable], version: str):
def register_topics(self, topics: Dict[str, Callable]):
"""Register a number of message definitions for multiple topics.
Args:
topics: A mapping of topics to Pydantic model message definitions.
version: The version of these messages.
"""
for topic in topics.items():
self.register_topic(*topic, version)
self.register_topic(*topic)

def connect(
self,
Expand Down Expand Up @@ -143,26 +140,26 @@ def send(self, topic: str, **data):
serialized_data = self._topics[topic](**data).serialize()
headers = dict(
service=self._service,
version=self._msg_versions[topic],
hash=self._hashes[topic],
sent_at=get_str_time_ms(),
)
self._connection.send(destination=topic, body=serialized_data, headers=headers)
self._logger.debug(f"Sent data to {topic}: {serialized_data}")

def _ensure_topic_exists(self, topic: str):
if topic not in self._topics or topic not in self._msg_versions:
if topic not in self._topics or topic not in self._hashes:
raise exceptions.NoSuchTopicException(f"Topic {topic} not defined.")

def _handle_message(self, message_frame: Frame):
topic = message_frame.headers["subscription"]
if topic not in self._topics or topic not in self._msg_versions:
if topic not in self._topics or topic not in self._hashes:
self._logger.warning(
f"Received a message on an unregistered topic: {topic}"
)
return
if message_frame.headers.get("version") != self._msg_versions.get(topic):
if message_frame.headers.get("hash") != self._hashes.get(topic):
self._logger.warning(
f"Received a message on topic '{topic}' with an incorrect version {message_frame.headers.get('version')}. Version should be {self._msg_versions.get(topic)}"
f"Received a message on topic '{topic}' with an incorrect hash: {message_frame.headers.get('hash')}. Expected: {self._hashes.get(topic)}"
)
return
try:
Expand Down
8 changes: 8 additions & 0 deletions pigeon/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import inspect
from copy import copy
import hashlib
from typing import Callable

from .exceptions import SignatureException

Expand All @@ -17,6 +19,12 @@ def setup_logging(logger_name: str, log_level: int = logging.INFO):
return logger


def get_message_hash(msg_cls: Callable):
hash = hashlib.sha1()
hash.update(inspect.getsource(msg_cls).encode("utf8"))
return hash.hexdigest()


def call_with_correct_args(func, *args, **kwargs):
args = copy(args)
kwargs = copy(kwargs)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from pigeon import BaseMessage


__version__ = "v1.2.3"


class MockMessage(BaseMessage):
field1: str

Expand All @@ -18,9 +15,10 @@ def pigeon_client():
with patch("pigeon.utils.setup_logging") as mock_logging:
topics = {"topic1": MockMessage}
client = Pigeon(
"test", host="localhost", port=61613, logger=mock_logging.Logger()
"test", host="localhost", port=61613, logger=mock_logging.Logger(),
load_topics=False,
)
client.register_topics(topics, __version__)
client.register_topics(topics)
yield client


Expand Down Expand Up @@ -83,7 +81,7 @@ def test_connect_failure(pigeon_client, username, password):
)
def test_send(pigeon_client, topic, data, expected_serialized_data):

expected_headers = {"service": "test", "version": __version__, "sent_at": "1"}
expected_headers = {"service": "test", "hash": pigeon_client._hashes["topic1"], "sent_at": "1"}
# Arrange
with patch("pigeon.client.time.time_ns", lambda: 1e6):
pigeon_client._topics[topic] = MockMessage
Expand Down
49 changes: 29 additions & 20 deletions tests/test_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_one_arg(pigeon_client):
mock_stomp_message = MagicMock()
mock_stomp_message.headers = {
"subscription": "test.msg",
"version": "v1.2.3",
"hash": "abcd",
}

mock_message = MagicMock()
Expand All @@ -32,7 +32,8 @@ def callback(msg):
assert msg == mock_message.deserialize()

pigeon_client._connection = MagicMock()
pigeon_client.register_topic("test.msg", mock_message, "v1.2.3")
pigeon_client._topics["test.msg"] = mock_message
pigeon_client._hashes["test.msg"] = "abcd"
pigeon_client.subscribe("test.msg", callback)

pigeon_client._handle_message(mock_stomp_message)
Expand All @@ -42,7 +43,7 @@ def test_two_args(pigeon_client):
mock_stomp_message = MagicMock()
mock_stomp_message.headers = {
"subscription": "test.msg",
"version": "v1.2.3",
"hash": "abcde",
}

mock_message = MagicMock()
Expand All @@ -53,7 +54,8 @@ def callback(msg, topic):
assert topic == "test.msg"

pigeon_client._connection = MagicMock()
pigeon_client.register_topic("test.msg", mock_message, "v1.2.3")
pigeon_client._topics["test.msg"] = mock_message
pigeon_client._hashes["test.msg"] = "abcde"
pigeon_client.subscribe("test.msg", callback)

pigeon_client._handle_message(mock_stomp_message)
Expand All @@ -63,7 +65,7 @@ def test_three_args(pigeon_client):
mock_stomp_message = MagicMock()
mock_stomp_message.headers = {
"subscription": "test.msg",
"version": "v1.2.3",
"hash": "123abc",
}

mock_message = MagicMock()
Expand All @@ -75,7 +77,8 @@ def callback(msg, topic, headers):
assert headers == mock_stomp_message.headers

pigeon_client._connection = MagicMock()
pigeon_client.register_topic("test.msg", mock_message, "v1.2.3")
pigeon_client._topics["test.msg"] = mock_message
pigeon_client._hashes["test.msg"] = "123abc"
pigeon_client.subscribe("test.msg", callback)

pigeon_client._handle_message(mock_stomp_message)
Expand All @@ -85,7 +88,7 @@ def test_var_args(pigeon_client):
mock_stomp_message = MagicMock()
mock_stomp_message.headers = {
"subscription": "test.msg",
"version": "v1.2.3",
"hash": "xyz987",
}

mock_message = MagicMock()
Expand All @@ -98,7 +101,8 @@ def callback(*args):
assert args[2] == mock_stomp_message.headers

pigeon_client._connection = MagicMock()
pigeon_client.register_topic("test.msg", mock_message, "v1.2.3")
pigeon_client._topics["test.msg"] = mock_message
pigeon_client._hashes["test.msg"] = "xyz987"
pigeon_client.subscribe("test.msg", callback)

pigeon_client._handle_message(mock_stomp_message)
Expand All @@ -118,25 +122,27 @@ def test_topic_does_not_exist(pigeon_client):
)


def test_version_mismatch(pigeon_client):
mock_message = create_mock_message(subscription="test", version="v0.1.1")
def test_hash_mismatch(pigeon_client):
mock_message = create_mock_message(subscription="test", hash="abc1")

pigeon_client.register_topic("test", lambda x: x, "v0.1.0")
pigeon_client._topics["test"] = None
pigeon_client._hashes["test"] = "abcd"
pigeon_client._handle_message(mock_message)

pigeon_client._logger.warning.assert_called_with(
"Received a message on topic 'test' with an incorrect version v0.1.1. Version should be v0.1.0"
"Received a message on topic 'test' with an incorrect hash: abc1. Expected: abcd"
)


def test_validation_error(pigeon_client):
mock_message = create_mock_message(subscription="test", version="v0.1.0")
mock_message = create_mock_message(subscription="test", hash="abc123")
mock_msg_def = MagicMock()
mock_msg_def.deserialize.side_effect = ValidationError.from_exception_data(
title="Test", line_errors=[]
)

pigeon_client.register_topic("test", mock_msg_def, "v0.1.0")
pigeon_client._topics["test"] = mock_msg_def
pigeon_client._hashes["test"] = "abc123"
pigeon_client._handle_message(mock_message)

pigeon_client._logger.warning.assert_called_with(
Expand All @@ -145,9 +151,10 @@ def test_validation_error(pigeon_client):


def test_no_callback(pigeon_client):
mock_message = create_mock_message(subscription="test", version="v0.1.0")
mock_message = create_mock_message(subscription="test", hash="4321")

pigeon_client.register_topic("test", MagicMock(), "v0.1.0")
pigeon_client._topics["test"] = MagicMock()
pigeon_client._hashes["test"] = "4321"
pigeon_client._handle_message(mock_message)

pigeon_client._logger.warning.assert_called_with(
Expand All @@ -156,10 +163,11 @@ def test_no_callback(pigeon_client):


def test_bad_signature(pigeon_client):
mock_message = create_mock_message(subscription="test", version="v0.1.0")
mock_message = create_mock_message(subscription="test", hash="lmnop")
callback = lambda a, b, c, d: None

pigeon_client.register_topic("test", MagicMock(), "v0.1.0")
pigeon_client._topics["test"] = MagicMock()
pigeon_client._hashes["test"] = "lmnop"
pigeon_client.subscribe("test", callback)
pigeon_client._handle_message(mock_message)

Expand All @@ -169,9 +177,10 @@ def test_bad_signature(pigeon_client):


def test_callback_exception(pigeon_client):
mock_message = create_mock_message(subscription="test", version="v0.1.0")
mock_message = create_mock_message(subscription="test", hash="987654321")

pigeon_client.register_topic("test", MagicMock(), "v0.1.0")
pigeon_client._topics["test"] = MagicMock()
pigeon_client._hashes["test"] = "987654321"
pigeon_client.subscribe(
"test", MagicMock(side_effect=RecursionError("This is a test error."))
)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pigeon.utils import get_message_hash


def test_get_message_hash():
class TestMsg:
attr_one: int
attr_two: str
attr_three: float

assert get_message_hash(TestMsg) == "e6b05f8920682eca0ba8b415c9fa7a7f248ddfce"

0 comments on commit fb34085

Please sign in to comment.