Skip to content

Commit

Permalink
Merge pull request #1 from AllenInstitute/variable_args
Browse files Browse the repository at this point in the history
Add support for a variable number of callback arguments
  • Loading branch information
CameronDevine authored Jun 27, 2024
2 parents 49b60ed + 9a2d102 commit 0098952
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 55 deletions.
67 changes: 54 additions & 13 deletions pigeon/__main__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,71 @@
from .client import Pigeon
import argparse
import yaml
from functools import partial


class Listener:
def __init__(self):
def __init__(self, disp_headers):
self.message_received = False
self.disp_headers = disp_headers

def callback(self, topic, msg):
def callback(self, msg, topic, headers):
print(f"Recieved message on topic '{topic}':")
print(msg)
if self.disp_headers:
print("With headers:")
for key, val in headers.items():
print(f"{key}={val}")
self.message_received = True


def main():
parser = argparse.ArgumentParser(prog="Pigeon CLI")
parser.add_argument("--host", type=str, default="127.0.0.1", help="The message broker to connect to.")
parser.add_argument("--port", type=int, default=61616, help="The port to use for the connection.")
parser.add_argument("--username", type=str, help="The username to use when connecting to the STOMP server.")
parser.add_argument("--password", type=str, help="The password to use when connecting to the STOMP server.")
parser.add_argument("-p", "--publish", type=str, help="The topic to publish a message to.")
parser.add_argument("-d", "--data", type=str, help="The YAML/JSON formatted data to publish.")
parser.add_argument("-s", "--subscribe", type=str, action="append", default=[], help="The topic to subscribe to.")
parser.add_argument("--one", action="store_true", help="Exit after receiving one message.")
parser.add_argument("-a", "--all", action="store_true", help="Subscribe to all registered topics.")
parser.add_argument("-l", "--list", action="store_true", help="List registered topics and exit.")
parser.add_argument(
"--host",
type=str,
default="127.0.0.1",
help="The message broker to connect to.",
)
parser.add_argument(
"--port", type=int, default=61616, help="The port to use for the connection."
)
parser.add_argument(
"--username",
type=str,
help="The username to use when connecting to the STOMP server.",
)
parser.add_argument(
"--password",
type=str,
help="The password to use when connecting to the STOMP server.",
)
parser.add_argument(
"-p", "--publish", type=str, help="The topic to publish a message to."
)
parser.add_argument(
"-d", "--data", type=str, help="The YAML/JSON formatted data to publish."
)
parser.add_argument(
"-s",
"--subscribe",
type=str,
action="append",
default=[],
help="The topic to subscribe to.",
)
parser.add_argument(
"-a", "--all", action="store_true", help="Subscribe to all registered topics."
)
parser.add_argument(
"--one", action="store_true", help="Exit after receiving one message."
)
parser.add_argument(
"-l", "--list", action="store_true", help="List registered topics and exit."
)
parser.add_argument(
"--headers", action="store_true", help="Display headers of received messages."
)

args = parser.parse_args()

Expand Down Expand Up @@ -53,7 +94,7 @@ def main():
connection.send(args.publish, **yaml.safe_load(args.data))

if args.subscribe or args.all:
listener = Listener()
listener = Listener(args.headers)

if args.all:
connection.subscribe_all(listener.callback)
Expand Down
37 changes: 25 additions & 12 deletions pigeon/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
import time
import time

import stomp
from typing import Callable, Dict
Expand All @@ -8,11 +8,16 @@
from importlib.metadata import entry_points

from . import exceptions
from .utils import call_with_correct_args


def get_str_time_ms():
return str(int(time.time_ns() / 1e6))


class Pigeon:
"""A STOMP client with message definitions via Pydantic
This class is a STOMP message client which will automatically serialize and
deserialize message data using Pydantic models. Before sending or receiving
messages, topics must be "registered", or in other words, have a Pydantic
Expand Down Expand Up @@ -65,7 +70,7 @@ def _load_topics(self):

def register_topic(self, topic: str, msg_class: Callable, version: str):
"""Register message definition for a given topic.
Args:
topic: The topic that this message definition applies to.
msg_class: The Pydantic model definition of the message.
Expand All @@ -76,7 +81,7 @@ def register_topic(self, topic: str, msg_class: Callable, version: str):

def register_topics(self, topics: Dict[str, Callable], version: str):
"""Register a number of message definitions for multiple topics.
Args:
topics: A mapping of topics to Pydantic model message definitions.
version: The version of these messages.
Expand Down Expand Up @@ -132,7 +137,11 @@ def send(self, topic: str, **data):
"""
self._ensure_topic_exists(topic)
serialized_data = self._topics[topic](**data).serialize()
headers = dict(service=self._service, version=self._msg_versions[topic])
headers = dict(
service=self._service,
version=self._msg_versions[topic],
sent_at=get_str_time_ms(),
)
self._connection.send(destination=topic, body=serialized_data, headers=headers)
self._logger.debug(f"Sent data to {topic}: {serialized_data}")

Expand All @@ -148,7 +157,9 @@ def _handle_message(self, message_frame: Frame):
if message_frame.headers.get("version") != self._msg_versions.get(topic):
raise exceptions.VersionMismatchException
message_data = self._topics[topic].deserialize(message_frame.body)
self._callbacks[topic](topic, message_data)
call_with_correct_args(
self._callbacks[topic], message_data, topic, message_frame.headers
)

def subscribe(self, topic: str, callback: Callable):
"""
Expand All @@ -157,8 +168,9 @@ def subscribe(self, topic: str, callback: Callable):
Args:
topic (str): The topic to subscribe to.
callback (Callable): The callback function to handle incoming
messages. It must accept two arguments, the topic and the
message data.
messages. It may accept up to three arguments. In order, the
arguments are, the recieved message, the topic the message was
recieved on, and the message headers.
Raises:
NoSuchTopicException: If the specified topic is not defined.
Expand All @@ -168,11 +180,11 @@ def subscribe(self, topic: str, callback: Callable):
if topic not in self._callbacks:
self._connection.subscribe(destination=topic, id=topic)
self._callbacks[topic] = callback
self._logger.info(f"Subscribed to {topic} with {callback.__name__}.")
self._logger.info(f"Subscribed to {topic} with {callback}.")

def subscribe_all(self, callback: Callable):
"""Subscribes to all registered topics.
Args:
callback: The function to call when a message is recieved. It must
accept two arguments, the topic and the message data.
Expand All @@ -182,7 +194,7 @@ def subscribe_all(self, callback: Callable):

def unsubscribe(self, topic: str):
"""Unsubscribes from a given topic.
Args:
topic: The topic to unsubscribe from."""
self._ensure_topic_exists(topic)
Expand All @@ -197,9 +209,10 @@ def disconnect(self):
self._logger.info("Disconnected from STOMP server.")


class TEMCommsListener(stomp.ConnectionListener):
class TEMCommsListener(stomp.ConnectionListener):
def __init__(self, callback: Callable):
self.callback = callback

def on_message(self, frame):
frame.headers["recieved_at"] = get_str_time_ms()
self.callback(frame)
13 changes: 0 additions & 13 deletions pigeon/logging.py

This file was deleted.

47 changes: 47 additions & 0 deletions pigeon/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
import inspect
from copy import copy


def setup_logging(logger_name: str, log_level: int = logging.INFO):
logger = logging.getLogger(logger_name)
handler = logging.StreamHandler()
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(log_level)
return logger


def call_with_correct_args(func, *args, **kwargs):
args = copy(args)
kwargs = copy(kwargs)
params = inspect.signature(func).parameters

if True not in [
param.kind == inspect._ParameterKind.VAR_POSITIONAL for param in params.values()
]:
num_args = len(
[
None
for param in params.values()
if param.default == param.empty and param.kind != param.VAR_KEYWORD
]
)
if num_args > len(args):
raise TypeError(
f"Function '{func}' requires {num_args} arguments, but only {len(args)} are available."
)
args = args[:num_args]

if True not in [
param.kind == inspect._ParameterKind.VAR_KEYWORD for param in params.values()
]:
allowed_keys = [key for key, val in params.items() if val.default != val.empty]
for key in list(kwargs.keys()):
if key not in allowed_keys:
del kwargs[key]

return func(*args, **kwargs)
83 changes: 83 additions & 0 deletions tests/test_call_with_correct_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from pigeon.utils import call_with_correct_args
import pytest


def test_not_enough_args():
def test_func(a, b, c, d):
return a, b, c, d

with pytest.raises(TypeError):
call_with_correct_args(test_func, 1, 2, 3)


def test_equal_args():
def test_func(a, b, c, d):
return a, b, c, d

assert call_with_correct_args(test_func, 1, 2, 3, 4) == (1, 2, 3, 4)


def test_args():
def test_func(a, b, c, d):
return a, b, c, d

assert call_with_correct_args(test_func, 1, 2, 3, 4, 5) == (1, 2, 3, 4)


def test_not_enough_kwargs():
def test_func(a=1, b=2, c=3):
return a, b, c

assert call_with_correct_args(test_func, a=10, b=11) == (10, 11, 3)


def test_no_args():
def test_func():
return True

assert call_with_correct_args(test_func, 1, 2, 3)


def test_both():
def test_func(a, b, c, d=1, e=2):
return a, b, c, d, e

assert call_with_correct_args(test_func, 1, 2, 3, 4, 5, d=10, e=11, f=12) == (
1,
2,
3,
10,
11,
)


def test_var_args():
def test_func(a, b, *args):
return a, b, args

assert call_with_correct_args(test_func, 1, 2, 3, 4) == (1, 2, (3, 4))


def test_var_kwargs():
def test_func(a=1, b=2, **kwargs):
return a, b, kwargs

assert call_with_correct_args(test_func, 1, 2, 3, a=10, c=11, d=12) == (
10,
2,
{"c": 11, "d": 12},
)


def test_both_var():
def test_func(a, b, *args, c=1, d=2, **kwargs):
return a, b, c, d, args, kwargs

assert call_with_correct_args(test_func, 1, 2, 3, 4, e=1, c=12, f=13) == (
1,
2,
12,
2,
(3, 4),
{"e": 1, "f": 13},
)
Loading

0 comments on commit 0098952

Please sign in to comment.