Skip to content

Commit

Permalink
Merge pull request #11 from rhytnen/broker_state_messages
Browse files Browse the repository at this point in the history
Broker state messages
  • Loading branch information
CameronDevine authored Sep 19, 2024
2 parents 642b809 + 962d151 commit 60c0d2e
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pigeon.client import Pigeon
from pigeon.logging import setup_logging
from pigeon.base_msg import BaseMessage
from pigeon.messages import BaseMessage


class TestMsg(BaseMessage):
Expand Down
2 changes: 1 addition & 1 deletion examples/subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pigeon.client import Pigeon
from pigeon.logging import setup_logging
from pigeon.base_msg import BaseMessage
from pigeon.messages import BaseMessage

logger = setup_logging("subscriber")

Expand Down
2 changes: 1 addition & 1 deletion pigeon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .client import Pigeon
from .base_msg import BaseMessage
from .messages import BaseMessage
81 changes: 60 additions & 21 deletions pigeon/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import logging
import os
import socket
import time
from importlib.metadata import entry_points
from typing import Callable, Dict

import stomp
from typing import Callable, Dict
from stomp.utils import Frame
import stomp.exception
from importlib.metadata import entry_points
from pydantic import ValidationError
from stomp.utils import Frame

from . import messages
from . import exceptions
from .utils import call_with_correct_args

Expand All @@ -33,12 +36,12 @@ class Pigeon:
"""

def __init__(
self,
service: str,
host: str = "127.0.0.1",
port: int = 61616,
logger: logging.Logger = None,
load_topics: bool = True,
self,
service: str,
host: str = "127.0.0.1",
port: int = 61616,
logger: logging.Logger = None,
load_topics: bool = True,
):
"""
Args:
Expand All @@ -62,6 +65,22 @@ def __init__(
)
self._logger = logger if logger is not None else self._configure_logging()

self._pid = os.getpid()
self._hostname = socket.gethostname().split('.')[0]
self._name = f"{self._service}_{self._pid}_{self._hostname}"


for topic, callback in messages.topics.items():
self.register_topic(topic, callback, version=messages.msg_version)

def _announce(self, connected=True):
self.send("&_announce_connection", name=self._name, pid=self._pid, hostname=self._hostname,
service=self._service, connected=connected)

def _update_state(self):
self.send("&_update_state", name=self._name, pid=self._pid, hostname=self._hostname,
service=self._service, subscribed_to=list(self._callbacks.keys()))

@staticmethod
def _configure_logging() -> logging.Logger:
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,23 +108,24 @@ def register_topics(self, topics: Dict[str, Callable], version: str):
Args:
topics: A mapping of topics to Pydantic model message definitions.
version: The version of these messages.
"""
for topic in topics.items():
self.register_topic(*topic, version)

def connect(
self,
username: str = None,
password: str = None,
retry_limit: int = 8,
self,
username: str = None,
password: str = None,
retry_limit: int = 8,
):
"""
Connects to the STOMP server using the provided username and password.
Args:
username (str, optional): The username to authenticate with. Defaults to None.
password (str, optional): The password to authenticate with. Defaults to None.
retry_limit (int, optional): Number of times to attempt connection
Raises:
stomp.exception.ConnectFailedException: If the connection to the server fails.
Expand All @@ -127,6 +147,10 @@ def connect(
f"Could not connect to server: {e}"
) from e


self.subscribe("&_request_state", self._update_state)
self._announce()

def send(self, topic: str, **data):
"""
Sends data to the specified topic.
Expand All @@ -141,8 +165,12 @@ def send(self, topic: str, **data):
"""
self._ensure_topic_exists(topic)
serialized_data = self._topics[topic](**data).serialize()

headers = dict(
source = self._name,
service=self._service,
hostname=self._hostname,
pid=self._pid,
version=self._msg_versions[topic],
sent_at=get_str_time_ms(),
)
Expand Down Expand Up @@ -189,16 +217,16 @@ def _handle_message(self, message_frame: Frame):
f"Callback for topic '{topic}' failed with error:", exc_info=True
)

def subscribe(self, topic: str, callback: Callable):
def subscribe(self, topic: str, callback: Callable, send_update=True):
"""
Subscribes to a topic and associates a callback function to handle incoming messages.
Args:
topic (str): The topic to subscribe to.
callback (Callable): The callback function to handle incoming
messages. It may accept up to three arguments. In order, the
arguments are, the recieved message, the topic the message was
recieved on, and the message headers.
arguments are, the received message, the topic the message was
received on, and the message headers.
Raises:
NoSuchTopicException: If the specified topic is not defined.
Expand All @@ -209,16 +237,26 @@ def subscribe(self, topic: str, callback: Callable):
self._connection.subscribe(destination=topic, id=topic)
self._callbacks[topic] = callback
self._logger.info(f"Subscribed to {topic} with {callback}.")
if send_update:
self._update_state()

def subscribe_all(self, callback: Callable):
def subscribe_all(self, callback: Callable, include_core=False):
"""Subscribes to all registered topics.
Args:
callback: The function to call when a message is recieved. It must
callback: The function to call when a message is received. It must
accept two arguments, the topic and the message data.
include_core (bool): If true, subscribe all will subscribe the client to core messages.
"""

# Additional logic here is to avoid subscribe_all changing behavior and always subscribing to core topics.
for topic in self._topics:
self.subscribe(topic, callback)
if topic in messages.topics and not include_core:
continue
if topic is "&_request_state":
continue
self.subscribe(topic, callback, send_update=False)
self._update_state()

def unsubscribe(self, topic: str):
"""Unsubscribes from a given topic.
Expand All @@ -233,6 +271,7 @@ def unsubscribe(self, topic: str):
def disconnect(self):
"""Disconnect from the STOMP message broker."""
if self._connection.is_connected():
self._announce(connected=False)
self._connection.disconnect()
self._logger.info("Disconnected from STOMP server.")

Expand All @@ -242,5 +281,5 @@ def __init__(self, callback: Callable):
self.callback = callback

def on_message(self, frame):
frame.headers["recieved_at"] = get_str_time_ms()
frame.headers["received_at"] = get_str_time_ms()
self.callback(frame)
28 changes: 28 additions & 0 deletions pigeon/base_msg.py → pigeon/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,31 @@ def deserialize(cls, data: str):
An instantiation of the model using the JSON data.
"""
return cls.model_validate_json(data)


class AnnounceConnection(BaseMessage):
name: str
service: str
pid: int
hostname: str
connected: bool


class RequestState(BaseMessage):
...


class UpdateState(BaseMessage):
name: str
service: str
pid: int
hostname: str
subscribed_to: list[str]

msg_version = "1.0.0"

topics = {
"&_announce_connection": AnnounceConnection,
"&_request_state": RequestState,
"&_update_state": UpdateState
}
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pydantic
stomp-py
pyyaml
pyyaml


0 comments on commit 60c0d2e

Please sign in to comment.