Skip to content

Commit

Permalink
Allow passing headers to client.subscribe() (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
vrslev authored Aug 23, 2024
1 parent db2b5e5 commit d06cd07
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 11 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ await client.subscribe("DLQ", handle_message_from_dlq, ack="client", on_suppress
await client.subscribe("DLQ", handle_message_from_dlq, ack="auto", on_suppressed_exception=print)
```

You can pass custom headers to `client.subscribe()`:

```python
await client.subscribe("DLQ", handle_message_from_dlq, ack="client", headers={"selector": "location = 'Europe'"}, on_suppressed_exception=print)
```

### Cleaning Up

stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, or `client.begin()`, the necessary frames will be sent to the server.
Expand Down
2 changes: 2 additions & 0 deletions stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,14 @@ async def subscribe(
handler: Callable[[MessageFrame], Coroutine[None, None, None]],
*,
ack: AckMode = "client-individual",
headers: dict[str, str] | None = None,
on_suppressed_exception: Callable[[Exception, MessageFrame], None],
supressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
) -> "Subscription":
subscription = Subscription(
destination=destination,
handler=handler,
headers=headers,
ack=ack,
on_suppressed_exception=on_suppressed_exception,
supressed_exception_classes=supressed_exception_classes,
Expand Down
18 changes: 12 additions & 6 deletions stompman/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,26 @@ def build(
content_type: str | None,
headers: dict[str, str] | None,
) -> Self:
full_headers: SendHeaders = headers or {} # type: ignore[assignment]
full_headers["destination"] = destination
full_headers["content-length"] = str(len(body))
all_headers: SendHeaders = headers or {} # type: ignore[assignment]
all_headers["destination"] = destination
all_headers["content-length"] = str(len(body))
if content_type is not None:
full_headers["content-type"] = content_type
all_headers["content-type"] = content_type
if transaction is not None:
full_headers["transaction"] = transaction
return cls(headers=full_headers, body=body)
all_headers["transaction"] = transaction
return cls(headers=all_headers, body=body)


@dataclass(frozen=True, kw_only=True, slots=True)
class SubscribeFrame:
headers: SubscribeHeaders

@classmethod
def build(cls, *, subscription_id: str, destination: str, ack: AckMode, headers: dict[str, str] | None) -> Self:
all_headers: SubscribeHeaders = headers.copy() if headers else {} # type: ignore[assignment, typeddict-item]
all_headers.update({"id": subscription_id, "destination": destination, "ack": ack})
return cls(headers=all_headers)


@dataclass(frozen=True, kw_only=True, slots=True)
class UnsubscribeFrame:
Expand Down
12 changes: 9 additions & 3 deletions stompman/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class Subscription:
id: str = field(default_factory=lambda: _make_subscription_id(), init=False) # noqa: PLW0108
destination: str
headers: dict[str, str] | None
handler: Callable[[MessageFrame], Coroutine[None, None, None]]
ack: AckMode
on_suppressed_exception: Callable[[Exception, MessageFrame], None]
Expand All @@ -34,7 +35,9 @@ def __post_init__(self) -> None:

async def _subscribe(self) -> None:
await self._connection_manager.write_frame_reconnecting(
SubscribeFrame(headers={"id": self.id, "destination": self.destination, "ack": self.ack})
SubscribeFrame.build(
subscription_id=self.id, destination=self.destination, ack=self.ack, headers=self.headers
)
)
self._active_subscriptions[self.id] = self

Expand Down Expand Up @@ -71,8 +74,11 @@ async def resubscribe_to_active_subscriptions(
) -> None:
for subscription in active_subscriptions.values():
await connection.write_frame(
SubscribeFrame(
headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack}
SubscribeFrame.build(
subscription_id=subscription.id,
destination=subscription.destination,
ack=subscription.ack,
headers=subscription.headers,
)
)

Expand Down
14 changes: 12 additions & 2 deletions tests/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ async def test_client_subscribtions_lifespan_resubscribe(ack: AckMode) -> None:
connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([CONNECTED_FRAME], []))
client = EnrichedClient(connection_class=connection_class)
sub_destination, message_destination, message_body = FAKER.pystr(), FAKER.pystr(), FAKER.binary(length=10)
sub_extra_headers = FAKER.pydict(value_types=[str])

async with client:
subscription = await client.subscribe(
destination=sub_destination,
handler=noop_message_handler,
ack=ack,
headers=sub_extra_headers,
on_suppressed_exception=noop_error_handler,
)
client._connection_manager._clear_active_connection_state()
Expand All @@ -57,11 +59,19 @@ async def test_client_subscribtions_lifespan_resubscribe(ack: AckMode) -> None:
await asyncio.sleep(0)
await asyncio.sleep(0)

subscribe_frame = SubscribeFrame(
headers={
"id": subscription.id,
"destination": sub_destination,
"ack": ack,
**sub_extra_headers, # type: ignore[typeddict-item]
}
)
assert collected_frames == enrich_expected_frames(
SubscribeFrame(headers={"id": subscription.id, "destination": sub_destination, "ack": ack}),
subscribe_frame,
CONNECT_FRAME,
CONNECTED_FRAME,
SubscribeFrame(headers={"id": subscription.id, "destination": sub_destination, "ack": ack}),
subscribe_frame,
SendFrame(
headers={"destination": message_destination, "content-length": str(len(message_body))}, body=message_body
),
Expand Down

0 comments on commit d06cd07

Please sign in to comment.