Skip to content

Commit

Permalink
Use pure ASGI middleware for Starlette (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
itssimon authored Sep 20, 2024
1 parent 67b0bd9 commit ff12b26
Showing 1 changed file with 65 additions and 68 deletions.
133 changes: 65 additions & 68 deletions apitally/starlette.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
from __future__ import annotations

import asyncio
import contextlib
import json
import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

from httpx import HTTPStatusError
from starlette.concurrency import iterate_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.routing import BaseRoute, Match, Router
from starlette.schemas import EndpointInfo, SchemaGenerator
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from starlette.testclient import TestClient
from starlette.types import ASGIApp
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from apitally.client.asyncio import ApitallyClient
from apitally.client.base import Consumer as ApitallyConsumer
from apitally.common import get_versions


if TYPE_CHECKING:
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response


__all__ = ["ApitallyMiddleware", "ApitallyConsumer"]


class ApitallyMiddleware(BaseHTTPMiddleware):
class ApitallyMiddleware:
def __init__(
self,
app: ASGIApp,
Expand All @@ -40,14 +34,14 @@ def __init__(
filter_unhandled_paths: bool = True,
identify_consumer_callback: Optional[Callable[[Request], Union[str, ApitallyConsumer, None]]] = None,
) -> None:
self.app = app
self.filter_unhandled_paths = filter_unhandled_paths
self.identify_consumer_callback = identify_consumer_callback
self.client = ApitallyClient(client_id=client_id, env=env)
self.client.start_sync_loop()
self._delayed_set_startup_data_task: Optional[asyncio.Task] = None
self.delayed_set_startup_data(app_version, openapi_url)
_register_shutdown_handler(app, self.client.handle_shutdown)
super().__init__(app)

def delayed_set_startup_data(self, app_version: Optional[str] = None, openapi_url: Optional[str] = None) -> None:
self._delayed_set_startup_data_task = asyncio.create_task(
Expand All @@ -61,87 +55,90 @@ async def _delayed_set_startup_data(
data = _get_startup_data(self.app, app_version, openapi_url)
self.client.set_startup_data(data)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
try:
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http" and scope["method"] != "OPTIONS":
request = Request(scope)
response_status = 0
response_time = 0.0
response_headers = Headers()
response_body = b""
start_time = time.perf_counter()
response = await call_next(request)
except BaseException as e:
await self.add_request(
request=request,
response=None,
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
response_time=time.perf_counter() - start_time,
exception=e,
)
raise e from None

async def send_wrapper(message: Message) -> None:
nonlocal response_time, response_status, response_headers, response_body
if message["type"] == "http.response.start":
response_time = time.perf_counter() - start_time
response_status = message["status"]
response_headers = Headers(scope=message)
elif message["type"] == "http.response.body" and response_status == 422:
response_body += message["body"]
await send(message)

try:
await self.app(scope, receive, send_wrapper)
except BaseException as e:
self.add_request(
request=request,
response_status=500,
response_time=time.perf_counter() - start_time,
response_headers=response_headers,
response_body=response_body,
exception=e,
)
raise e from None
else:
self.add_request(
request=request,
response_status=response_status,
response_time=response_time,
response_headers=response_headers,
response_body=response_body,
)
else:
await self.add_request(
request=request,
response=response,
status_code=response.status_code,
response_time=time.perf_counter() - start_time,
)
return response
await self.app(scope, receive, send) # pragma: no cover

async def add_request(
def add_request(
self,
request: Request,
response: Optional[Response],
status_code: int,
response_status: int,
response_time: float,
response_headers: Headers,
response_body: bytes,
exception: Optional[BaseException] = None,
) -> None:
path_template, is_handled_path = self.get_path_template(request)
if (is_handled_path or not self.filter_unhandled_paths) and request.method != "OPTIONS":
if is_handled_path or not self.filter_unhandled_paths:
consumer = self.get_consumer(request)
consumer_identifier = consumer.identifier if consumer else None
self.client.consumer_registry.add_or_update_consumer(consumer)
self.client.request_counter.add_request(
consumer=consumer_identifier,
method=request.method,
path=path_template,
status_code=status_code,
status_code=response_status,
response_time=response_time,
request_size=request.headers.get("Content-Length"),
response_size=response.headers.get("Content-Length") if response is not None else None,
response_size=response_headers.get("Content-Length"),
)
if (
status_code == 422
and response is not None
and response.headers.get("Content-Type") == "application/json"
):
body = await self.get_response_json(response)
if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list):
# Log FastAPI / Pydantic validation errors
self.client.validation_error_counter.add_validation_errors(
consumer=consumer_identifier,
method=request.method,
path=path_template,
detail=body["detail"],
)
if status_code == 500 and exception is not None:
if response_status == 422 and response_body and response_headers.get("Content-Type") == "application/json":
with contextlib.suppress(json.JSONDecodeError):
body = json.loads(response_body)
if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list):
# Log FastAPI / Pydantic validation errors
self.client.validation_error_counter.add_validation_errors(
consumer=consumer_identifier,
method=request.method,
path=path_template,
detail=body["detail"],
)
if response_status == 500 and exception is not None:
self.client.server_error_counter.add_server_error(
consumer=consumer_identifier,
method=request.method,
path=path_template,
exception=exception,
)

@staticmethod
async def get_response_json(response: Response) -> Any:
if hasattr(response, "body"):
try:
return json.loads(response.body)
except json.JSONDecodeError: # pragma: no cover
return None
elif hasattr(response, "body_iterator"):
try:
response_body = [section async for section in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
return json.loads(b"".join(response_body))
except json.JSONDecodeError: # pragma: no cover
return None

@staticmethod
def get_path_template(request: Request) -> Tuple[str, bool]:
for route in request.app.routes:
Expand Down

0 comments on commit ff12b26

Please sign in to comment.