Skip to content

Commit

Permalink
Add support for Litestar (#13)
Browse files Browse the repository at this point in the history
* Add ApitallyPlugin for Litestar

* Add identify_consumer_callback argument and fix validation error capture

* Add tests

* Fix test coverage

* Update readme

* Clean up

* Fix typing

* Add litestar to text matrix in CI

* Move start_time definition down

* Add filter_openapi_paths argument

* Use contextlib.suppress

* Improve typing for ApitallyClient

* Fix getting correct path from Request object
  • Loading branch information
itssimon authored Mar 5, 2024
1 parent e38abfa commit c52a34a
Show file tree
Hide file tree
Showing 7 changed files with 729 additions and 16 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ jobs:
- django-ninja django
- django-ninja==0.22.* django
- django-ninja==0.18.0 django
- litestar
- litestar==2.6.1
- litestar==2.0.1

steps:
- uses: actions/checkout@v4
Expand Down
25 changes: 24 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ frameworks:
- [Flask](https://docs.apitally.io/frameworks/flask)
- [Django Ninja](https://docs.apitally.io/frameworks/django-ninja)
- [Django REST Framework](https://docs.apitally.io/frameworks/django-rest-framework)
- [Litestar](https://docs.apitally.io/frameworks/litestar)

Learn more about Apitally on our 🌎 [website](https://apitally.io) or check out
the 📚 [documentation](https://docs.apitally.io).
Expand All @@ -50,7 +51,8 @@ example:
pip install apitally[fastapi]
```

The available extras are: `fastapi`, `starlette`, `flask` and `django`.
The available extras are: `fastapi`, `starlette`, `flask`, `django` and
`litestar`.

## Usage

Expand Down Expand Up @@ -112,6 +114,27 @@ APITALLY_MIDDLEWARE = {
}
```

### Litestar

This is an example of how to add the Apitally plugin to a Litestar application.
For further instructions, see our
[setup guide for Litestar](https://docs.apitally.io/frameworks/litestar).

```python
from litestar import Litestar
from apitally.litestar import ApitallyPlugin

app = Litestar(
route_handlers=[...],
plugins=[
ApitallyPlugin(
client_id="your-client-id",
env="dev", # or "prod" etc.
),
]
)
```

## Getting help

If you need help please
Expand Down
7 changes: 4 additions & 3 deletions apitally/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
import threading
import time
from abc import ABC
from collections import Counter
from dataclasses import dataclass
from math import floor
Expand All @@ -26,16 +27,16 @@
TApitallyClient = TypeVar("TApitallyClient", bound="ApitallyClientBase")


class ApitallyClientBase:
class ApitallyClientBase(ABC):
_instance: Optional[ApitallyClientBase] = None
_lock = threading.Lock()

def __new__(cls, *args, **kwargs) -> ApitallyClientBase:
def __new__(cls: Type[TApitallyClient], *args, **kwargs) -> TApitallyClient:
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
return cast(TApitallyClient, cls._instance)

def __init__(self, client_id: str, env: str) -> None:
if hasattr(self, "client_id"):
Expand Down
186 changes: 186 additions & 0 deletions apitally/litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import contextlib
import json
import sys
import time
from importlib.metadata import version
from typing import Callable, Dict, List, Optional

from litestar.app import DEFAULT_OPENAPI_CONFIG, Litestar
from litestar.config.app import AppConfig
from litestar.connection import Request
from litestar.datastructures import Headers
from litestar.enums import ScopeType
from litestar.handlers import HTTPRouteHandler
from litestar.plugins import InitPluginProtocol
from litestar.types import ASGIApp, Message, Receive, Scope, Send

from apitally.client.asyncio import ApitallyClient


__all__ = ["ApitallyPlugin"]


class ApitallyPlugin(InitPluginProtocol):
def __init__(
self,
client_id: str,
env: str = "dev",
app_version: Optional[str] = None,
filter_openapi_paths: bool = True,
identify_consumer_callback: Optional[Callable[[Request], Optional[str]]] = None,
) -> None:
self.client = ApitallyClient(client_id=client_id, env=env)
self.app_version = app_version
self.filter_openapi_paths = filter_openapi_paths
self.identify_consumer_callback = identify_consumer_callback
self.openapi_path: Optional[str] = None

def on_app_init(self, app_config: AppConfig) -> AppConfig:
app_config.on_startup.append(self.on_startup)
app_config.middleware.append(self.middleware_factory)
return app_config

def on_startup(self, app: Litestar) -> None:
openapi_config = app.openapi_config or DEFAULT_OPENAPI_CONFIG
self.openapi_path = openapi_config.openapi_controller.path

app_info = {
"openapi": _get_openapi(app),
"paths": [route for route in _get_routes(app) if not self.filter_path(route["path"])],
"versions": _get_versions(self.app_version),
"client": "python:litestar",
}
self.client.set_app_info(app_info)
self.client.start_sync_loop()

def middleware_factory(self, app: ASGIApp) -> ASGIApp:
async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http" and scope["method"] != "OPTIONS":
request = Request(scope)
response_status = 0
response_time = 0.0
response_headers = Headers()
response_body = b""
start_time = time.perf_counter()

async def send_wrapper(message: Message) -> None:
nonlocal response_time, response_status, response_headers, response_body
if message["type"] == "http.response.start":
response_time = time.perf_counter() - start_time
response_status = message["status"]
response_headers = Headers(message["headers"])
elif message["type"] == "http.response.body" and response_status == 400:
response_body += message["body"]
await send(message)

await app(scope, receive, send_wrapper)
self.add_request(
request=request,
response_status=response_status,
response_time=response_time,
response_headers=response_headers,
response_body=response_body,
)
else:
await app(scope, receive, send) # pragma: no cover

return middleware

def add_request(
self,
request: Request,
response_status: int,
response_time: float,
response_headers: Headers,
response_body: bytes,
) -> None:
if response_status < 100 or not request.route_handler.paths:
return # pragma: no cover
path = self.get_path(request)
if path is None or self.filter_path(path):
return
consumer = self.get_consumer(request)
self.client.request_counter.add_request(
consumer=consumer,
method=request.method,
path=path,
status_code=response_status,
response_time=response_time,
request_size=request.headers.get("Content-Length"),
response_size=response_headers.get("Content-Length"),
)
if response_status == 400 and response_body and len(response_body) < 4096:
with contextlib.suppress(json.JSONDecodeError):
parsed_body = json.loads(response_body)
if (
isinstance(parsed_body, dict)
and "detail" in parsed_body
and isinstance(parsed_body["detail"], str)
and "validation" in parsed_body["detail"].lower()
and "extra" in parsed_body
and isinstance(parsed_body["extra"], list)
):
self.client.validation_error_counter.add_validation_errors(
consumer=consumer,
method=request.method,
path=path,
detail=[
{
"loc": [error.get("source", "body")] + error["key"].split("."),
"msg": error["message"],
"type": "",
}
for error in parsed_body["extra"]
if "key" in error and "message" in error
],
)

def get_path(self, request: Request) -> Optional[str]:
path: List[str] = []
for layer in request.route_handler.ownership_layers:
if isinstance(layer, HTTPRouteHandler):
if len(layer.paths) == 0:
return None # pragma: no cover
path.append(list(layer.paths)[0].lstrip("/"))
else:
path.append(layer.path.lstrip("/"))
return "/" + "/".join(filter(None, path))

def filter_path(self, path: str) -> bool:
if self.filter_openapi_paths and self.openapi_path:
return path == self.openapi_path or path.startswith(self.openapi_path + "/")
return False # pragma: no cover

def get_consumer(self, request: Request) -> Optional[str]:
if hasattr(request.state, "consumer_identifier"):
return str(request.state.consumer_identifier)
if self.identify_consumer_callback is not None:
consumer_identifier = self.identify_consumer_callback(request)
if consumer_identifier is not None:
return str(consumer_identifier)
return None


def _get_openapi(app: Litestar) -> str:
schema = app.openapi_schema.to_schema()
return json.dumps(schema)


def _get_routes(app: Litestar) -> List[Dict[str, str]]:
return [
{"method": method, "path": route.path}
for route in app.routes
for method in route.methods
if route.scope_type == ScopeType.HTTP and method != "OPTIONS"
]


def _get_versions(app_version: Optional[str]) -> Dict[str, str]:
versions = {
"python": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
"apitally": version("apitally"),
"litestar": version("litestar"),
}
if app_version:
versions["app"] = app_version
return versions
Loading

0 comments on commit c52a34a

Please sign in to comment.