Skip to content

Commit

Permalink
Add logic to websocket reconnection to re-subscribe to registered tri…
Browse files Browse the repository at this point in the history
…ggers.
  • Loading branch information
nlioc4 committed Feb 20, 2024
1 parent 23dcf92 commit 8fbff23
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions auraxium/event/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
_EventT = TypeVar('_EventT', bound=Event)
_EventT2 = TypeVar('_EventT2', bound=Event)
_CallbackT = Union[Callable[[_EventT], None],
Callable[[_EventT], Coroutine[Any, Any, None]]]
Callable[[_EventT], Coroutine[Any, Any, None]]]

_log = logging.getLogger('auraxium.ess')

Expand Down Expand Up @@ -170,6 +170,14 @@ def remove_trigger(self, trigger: Union[Trigger, str], *,
_log.info('All triggers have been removed, closing websocket')
self.loop.create_task(self.close())

def _subscribe_all(self):
"""Add subscription messages for every registered trigger.
This will add a subscription message for every trigger currently registered with the client.
Useful for resubscribing to all events after a disconnect.
"""
self._send_queue.extend([trigger.generate_subscription() for trigger in self.triggers])

async def close(self) -> None:
"""Gracefully shut down the client.
Expand All @@ -186,7 +194,7 @@ async def connect(self) -> None:
This will continuously loop until :meth:`EventClient.close` is
called.
If the WebSocket connection encounters and error, it will be
If the WebSocket connection encounters an error, it will be
automatically restarted.
Any event payloads received will be passed to
Expand Down Expand Up @@ -262,9 +270,12 @@ async def _connection_handler(self) -> None:
# NOTE: The following "async for" loop will cleanly restart the
# connection should it go down. Invoking "continue" manually may be
# used to manually force a reconnect if needed.

connection_failed = False
async for websocket in websockets.client.connect(str(url)):
_log.info('Connected to %s', url)
if connection_failed:
self._subscribe_all()
connection_failed = False
self.websocket = websocket

try:
Expand All @@ -273,6 +284,7 @@ async def _connection_handler(self) -> None:

except websockets.exceptions.ConnectionClosed:
_log.info('Connection closed, restarting...')
connection_failed = True
continue

if not self._open:
Expand Down Expand Up @@ -312,22 +324,22 @@ async def _handle_websocket(self, timeout: float = 0.1) -> None:
def trigger(self, event: Type[_EventT], *, name: Optional[str] = None,
**kwargs: Any) -> Callable[[_CallbackT[_EventT]], None]:
# Single event variant (checks callback argument type)
... # pragma: no cover
... # pragma: no cover

@overload
def trigger(self, event: Type[_EventT],
arg1: Type[_EventT], *args: Type[_EventT2],
name: Optional[str] = None, **kwargs: Any) -> Callable[
[_CallbackT[Union[_EventT, _EventT2]]], None]:
[_CallbackT[Union[_EventT, _EventT2]]], None]:
# Two event variant (checks callback argument type)
... # pragma: no cover
... # pragma: no cover

@overload
def trigger(self, event: Union[str, Type[Event]],
*args: Union[str, Type[Event]], name: Optional[str] = None,
**kwargs: Any) -> Callable[[_CallbackT[Event]], None]:
# Generic fallback variant (callback argument type not checked)
... # pragma: no cover
... # pragma: no cover

def trigger(self, event: Union[str, Type[Event]],
*args: Union[str, Type[Event]], name: Optional[str] = None,
Expand Down

0 comments on commit 8fbff23

Please sign in to comment.