Skip to content

Commit

Permalink
Merge pull request #19 from dbluhm/refactor/backend-interface
Browse files Browse the repository at this point in the history
refactor: backend interface
  • Loading branch information
dbluhm authored May 10, 2024
2 parents 2c4cb4b + 8b90e44 commit d5d0fd2
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 88 deletions.
24 changes: 16 additions & 8 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
FROM python:3.9-slim-bullseye
FROM python:3.9-slim-bookworm as base
WORKDIR /usr/src/app

ENV POETRY_VERSION=1.4.2

RUN apt-get update && apt-get install -y curl && apt-get clean
RUN pip install "poetry==$POETRY_VERSION"
COPY poetry.lock pyproject.toml README.md ./
RUN mkdir -p socketdock && touch socketdock/__init__.py
RUN poetry config virtualenvs.create false \
&& poetry install --without=dev --no-interaction --no-ansi
ENV POETRY_VERSION=1.5.1
ENV POETRY_HOME=/opt/poetry
RUN curl -sSL https://install.python-poetry.org | python -

ENV PATH="/opt/poetry/bin:$PATH"
RUN poetry config virtualenvs.in-project true

# Setup project
COPY pyproject.toml poetry.lock ./
RUN poetry install --without dev

FROM python:3.9-slim-bookworm
WORKDIR /usr/src/app
COPY --from=base /usr/src/app/.venv /usr/src/app/.venv
ENV PATH="/usr/src/app/.venv/bin:$PATH"

COPY socketdock socketdock

Expand Down
3 changes: 1 addition & 2 deletions docker-compose-local.yaml → demo/docker-compose-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ version: '3'

services:
websocket-gateway:
build: .
build: ..
ports:
- "8765:8765"
volumes:
- ./server:/code
- ./wait-for-tunnel.sh:/wait-for-tunnel.sh:ro,z
entrypoint: /wait-for-tunnel.sh
command: >
Expand Down
16 changes: 16 additions & 0 deletions demo/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
version: '3'

services:
socketdock:
build: ..
ports:
- "8765:8765"
volumes:
- ../socketdock:/usr/src/app/socketdock:z
command: >
--bindip 0.0.0.0
--backend loopback
--message-uri https://example.com
--disconnect-uri https://example.com
--endpoint http://socketdock:8765
--log-level INFO
3 changes: 2 additions & 1 deletion socket_client.py → demo/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ async def hello():
print(f"< {response}", flush=True)


asyncio.run(hello())
if __name__ == "__main__":
asyncio.run(hello())
2 changes: 1 addition & 1 deletion wait-for-tunnel.sh → demo/wait-for-tunnel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ done
WS_ENDPOINT=$(curl --silent "${TUNNEL_ENDPOINT}/start" | python -c "import sys, json; print(json.load(sys.stdin)['url'])" | sed -rn 's#https?://([^/]+).*#\1#p')
echo "fetched hostname and port [$WS_ENDPOINT]"

exec "$@" --externalhostandport ${WS_ENDPOINT}
exec "$@" --externalhostandport ${WS_ENDPOINT}
23 changes: 0 additions & 23 deletions docker-compose.yaml

This file was deleted.

9 changes: 5 additions & 4 deletions socketdock/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import argparse
from sanic import Sanic

from .api import api, backend_var, endpoint_var
from .api import api, backend_var


def config() -> argparse.Namespace:
Expand Down Expand Up @@ -34,16 +34,17 @@ def main():
if args.backend == "loopback":
from .testbackend import TestBackend

backend = TestBackend()
backend = TestBackend(args.endpoint)
elif args.backend == "http":
from .httpbackend import HTTPBackend

backend = HTTPBackend(args.connect_uri, args.message_uri, args.disconnect_uri)
backend = HTTPBackend(
args.endpoint, args.connect_uri, args.message_uri, args.disconnect_uri
)
else:
raise ValueError("Invalid backend type")

backend_var.set(backend)
endpoint_var.set(args.endpoint)

logging.basicConfig(level=args.log_level)

Expand Down
34 changes: 11 additions & 23 deletions socketdock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .backend import Backend

backend_var: ContextVar[Backend] = ContextVar("backend")
endpoint_var: ContextVar[str] = ContextVar("endpoint")

api = Blueprint("api", url_prefix="/")

Expand Down Expand Up @@ -45,7 +44,7 @@ async def status_handler(request: Request):
async def socket_send(request: Request, connectionid: str):
"""Send a message to a connected socket."""
LOGGER.info("Inbound message for %s", connectionid)
LOGGER.info("Existing connections: %s", active_connections.keys())
LOGGER.debug("Existing connections: %s", active_connections.keys())

if connectionid not in active_connections:
return text("FAIL", status=500)
Expand All @@ -62,7 +61,7 @@ async def socket_send(request: Request, connectionid: str):
async def socket_disconnect(request: Request, connectionid: str):
"""Disconnect a socket."""
LOGGER.info("Disconnect %s", connectionid)
LOGGER.info("Existing connections: %s", active_connections.keys())
LOGGER.debug("Existing connections: %s", active_connections.keys())

if connectionid not in active_connections:
return text("FAIL", status=500)
Expand All @@ -78,37 +77,26 @@ async def socket_handler(request: Request, websocket: Websocket):
global lifetime_connections
backend = backend_var.get()
socket_id = None
endpoint = endpoint_var.get()
send = f"{endpoint}/socket/{socket_id}/send"
disconnect = f"{endpoint_var.get()}/socket/{socket_id}/disconnect"
try:
# register user
LOGGER.info("new client connected")
socket_id = websocket.connection.id.hex
socket_id = websocket.ws_proto.id.hex
active_connections[socket_id] = websocket
lifetime_connections += 1
LOGGER.info("Existing connections: %s", active_connections.keys())
LOGGER.info("Added connection: %s", socket_id)
LOGGER.info("Request headers: %s", dict(request.headers.items()))
LOGGER.debug("Existing connections: %s", active_connections.keys())
LOGGER.debug("Added connection: %s", socket_id)
LOGGER.debug("Request headers: %s", dict(request.headers.items()))

await backend.socket_connected(
{
"connection_id": socket_id,
"headers": dict(request.headers.items()),
"send": send,
"disconnect": disconnect,
},
connection_id=socket_id,
headers=dict(request.headers.items()),
)

async for message in websocket:
if message:
await backend.inbound_socket_message(
{
"connection_id": socket_id,
"send": send,
"disconnect": disconnect,
},
message,
connection_id=socket_id,
message=message,
)
else:
LOGGER.warning("empty message received")
Expand All @@ -118,4 +106,4 @@ async def socket_handler(request: Request, websocket: Websocket):
if socket_id:
del active_connections[socket_id]
LOGGER.info("Removed connection: %s", socket_id)
await backend.socket_disconnected({"connection_id": socket_id})
await backend.socket_disconnected(socket_id)
17 changes: 10 additions & 7 deletions socketdock/backend.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
"""Backend interface for SocketDock."""

from abc import ABC, abstractmethod
from typing import Union
from typing import Dict, Union


class Backend(ABC):
"""Backend interface for SocketDock."""

@abstractmethod
async def socket_connected(self, callback_uris: dict):
async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle new socket connections, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
raise NotImplementedError()
57 changes: 48 additions & 9 deletions socketdock/httpbackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""HTTP backend for SocketDock."""

import logging
from typing import Union
from typing import Dict, Union

import aiohttp

Expand All @@ -14,16 +14,46 @@
class HTTPBackend(Backend):
"""HTTP backend for SocketDock."""

def __init__(self, connect_uri: str, message_uri: str, disconnect_uri: str):
def __init__(
self,
socket_base_uri: str,
connect_uri: str,
message_uri: str,
disconnect_uri: str,
):
"""Initialize HTTP backend."""
self._connect_uri = connect_uri
self._message_uri = message_uri
self._disconnect_uri = disconnect_uri
self.socket_base_uri = socket_base_uri

def send_callback(self, connection_id: str) -> str:
"""Return the callback URI for sending a message to a connected socket."""
return f"{self.socket_base_uri}/socket/{connection_id}/send"

def disconnect_callback(self, connection_id: str) -> str:
"""Return the callback URI for disconnecting a connected socket."""
return f"{self.socket_base_uri}/socket/{connection_id}/disconnect"

async def socket_connected(self, callback_uris: dict):
def callback_uris(self, connection_id: str) -> Dict[str, str]:
"""Return labelled callback URIs."""
return {
"send": self.send_callback(connection_id),
"disconnect": self.disconnect_callback(connection_id),
}

async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"headers": headers,
"connection_id": connection_id,
},
}

if self._connect_uri:
Expand All @@ -37,11 +67,16 @@ async def socket_connected(self, callback_uris: dict):
LOGGER.debug("Response: %s", response)

async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"connection_id": connection_id,
},
"message": message.decode("utf-8") if isinstance(message, bytes) else message,
}

Expand All @@ -54,11 +89,15 @@ async def inbound_socket_message(
else:
LOGGER.debug("Response: %s", response)

async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
async with aiohttp.ClientSession() as session:
LOGGER.info("Notifying of disconnect: %s %s", self._disconnect_uri, bundle)
async with session.post(self._disconnect_uri, json=bundle) as resp:
LOGGER.info(
"Notifying of disconnect: %s %s", self._disconnect_uri, connection_id
)
async with session.post(
self._disconnect_uri, json={"connection_id": connection_id}
) as resp:
response = await resp.text()
if resp.status != 200:
LOGGER.error("Error posting to disconnect uri: %s", response)
Expand Down
Loading

0 comments on commit d5d0fd2

Please sign in to comment.