Skip to content

Commit

Permalink
Identify consumers (#6)
Browse files Browse the repository at this point in the history
* Add consumer identification

* Increase test coverage

* Fix test

* Rename identify_consumer_func to ..._callback
  • Loading branch information
itssimon authored Sep 6, 2023
1 parent 1c92428 commit 70c52bb
Show file tree
Hide file tree
Showing 18 changed files with 241 additions and 120 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,5 @@ from flask import Flask
from apitally.flask import ApitallyMiddleware

app = Flask(__name__)
app.wsgi_app = ApitallyMiddleware(app.wsgi_app, client_id="<your-client-id>")
app.wsgi_app = ApitallyMiddleware(app, client_id="<your-client-id>")
```
2 changes: 1 addition & 1 deletion apitally/client/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def get_keys(self, client: httpx.AsyncClient) -> None:
if response_data := await self._get_keys(client): # Response data can be None if backoff gives up
self.handle_keys_response(response_data)
elif self.key_registry.salt is None:
logger.error("Initial Apitally key sync failed")
logger.error("Initial Apitally API key sync failed")
# Exit because the application will not be able to authenticate requests
sys.exit(1)

Expand Down
20 changes: 17 additions & 3 deletions apitally/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def handle_keys_response(self, response_data: Dict[str, Any]) -> None:

@dataclass(frozen=True)
class RequestInfo:
consumer: Optional[str]
method: str
path: str
status_code: int
Expand All @@ -105,8 +106,15 @@ def __init__(self) -> None:
self.response_times: Dict[RequestInfo, Counter[int]] = {}
self._lock = threading.Lock()

def log_request(self, method: str, path: str, status_code: int, response_time: float) -> None:
request_info = RequestInfo(method=method.upper(), path=path, status_code=status_code)
def log_request(
self, consumer: Optional[str], method: str, path: str, status_code: int, response_time: float
) -> None:
request_info = RequestInfo(
consumer=consumer,
method=method.upper(),
path=path,
status_code=status_code,
)
response_time_ms_bin = int(floor(response_time / 0.01) * 10) # In ms, rounded down to nearest 10ms
with self._lock:
self.request_counts[request_info] += 1
Expand All @@ -118,6 +126,7 @@ def get_and_reset_requests(self) -> List[Dict[str, Any]]:
for request_info, count in self.request_counts.items():
data.append(
{
"consumer": request_info.consumer,
"method": request_info.method,
"path": request_info.path,
"status_code": request_info.status_code,
Expand All @@ -132,6 +141,7 @@ def get_and_reset_requests(self) -> List[Dict[str, Any]]:

@dataclass(frozen=True)
class ValidationError:
consumer: Optional[str]
method: str
path: str
loc: Tuple[str, ...]
Expand All @@ -144,11 +154,14 @@ def __init__(self) -> None:
self.error_counts: Counter[ValidationError] = Counter()
self._lock = threading.Lock()

def log_validation_errors(self, method: str, path: str, detail: List[Dict[str, Any]]) -> None:
def log_validation_errors(
self, consumer: Optional[str], method: str, path: str, detail: List[Dict[str, Any]]
) -> None:
with self._lock:
for error in detail:
try:
validation_error = ValidationError(
consumer=consumer,
method=method.upper(),
path=path,
loc=tuple(str(loc) for loc in error["loc"]),
Expand All @@ -165,6 +178,7 @@ def get_and_reset_validation_errors(self) -> List[Dict[str, Any]]:
for validation_error, count in self.error_counts.items():
data.append(
{
"consumer": validation_error.consumer,
"method": validation_error.method,
"path": validation_error.path,
"loc": validation_error.loc,
Expand Down
2 changes: 1 addition & 1 deletion apitally/client/threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def get_keys(self, session: requests.Session) -> None:
if response_data := self._get_keys(session): # Response data can be None if backoff gives up
self.handle_keys_response(response_data)
elif self.key_registry.salt is None:
logger.error("Initial Apitally key sync failed")
logger.error("Initial Apitally API key sync failed")
# Exit because the application will not be able to authenticate requests
sys.exit(1)

Expand Down
23 changes: 22 additions & 1 deletion apitally/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from django.core.exceptions import ViewDoesNotExist
from django.test import RequestFactory
from django.urls import URLPattern, URLResolver, get_resolver, resolve
from django.utils.module_loading import import_string

import apitally
from apitally.client.base import KeyInfo
from apitally.client.threading import ApitallyClient


Expand All @@ -31,6 +33,7 @@ class ApitallyMiddlewareConfig:
sync_api_keys: bool
sync_interval: float
openapi_url: Optional[str]
identify_consumer_callback: Optional[Callable[[HttpRequest], Optional[str]]]


class ApitallyMiddleware:
Expand Down Expand Up @@ -67,6 +70,7 @@ def configure(
sync_api_keys: bool = False,
sync_interval: float = 60,
openapi_url: Optional[str] = None,
identify_consumer_callback: Optional[str] = None,
) -> None:
cls.config = ApitallyMiddlewareConfig(
client_id=client_id,
Expand All @@ -75,14 +79,19 @@ def configure(
sync_api_keys=sync_api_keys,
sync_interval=sync_interval,
openapi_url=openapi_url,
identify_consumer_callback=import_string(identify_consumer_callback)
if identify_consumer_callback
else None,
)

def __call__(self, request: HttpRequest) -> HttpResponse:
view = self.get_view(request)
start_time = time.perf_counter()
response = self.get_response(request)
if request.method is not None and view is not None and view.is_api_view:
consumer = self.get_consumer(request)
self.client.request_logger.log_request(
consumer=consumer,
method=request.method,
path=view.pattern,
status_code=response.status_code,
Expand All @@ -98,18 +107,30 @@ def __call__(self, request: HttpRequest) -> HttpResponse:
if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list):
# Log Django Ninja / Pydantic validation errors
self.client.validation_error_logger.log_validation_errors(
consumer=consumer,
method=request.method,
path=view.pattern,
detail=body["detail"],
)
except json.JSONDecodeError:
except json.JSONDecodeError: # pragma: no cover
pass
return response

def get_view(self, request: HttpRequest) -> Optional[DjangoViewInfo]:
resolver_match = resolve(request.path_info)
return next((view for view in self.views if view.pattern == resolver_match.route), None)

def get_consumer(self, request: HttpRequest) -> Optional[str]:
if hasattr(request, "consumer_identifier"):
return str(request.consumer_identifier)
if self.config is not None and self.config.identify_consumer_callback is not None:
consumer_identifier = self.config.identify_consumer_callback(request)
if consumer_identifier is not None:
return str(consumer_identifier)
if hasattr(request, "auth") and isinstance(request.auth, KeyInfo):
return f"key:{request.auth.key_id}"
return None


@dataclass
class DjangoViewInfo:
Expand Down
2 changes: 1 addition & 1 deletion apitally/django_ninja.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _get_api(views: List[DjangoViewInfo]) -> NinjaAPI:
return next(
(view.func.__self__.api for view in views if view.is_ninja_path_view and hasattr(view.func, "__self__"))
)
except StopIteration:
except StopIteration: # pragma: no cover
raise RuntimeError("Could not find NinjaAPI instance")


Expand Down
2 changes: 2 additions & 0 deletions apitally/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O
status_code=HTTP_403_FORBIDDEN,
detail="Permission denied",
)
if key_info is not None:
request.state.key_info = key_info
return key_info


Expand Down
54 changes: 24 additions & 30 deletions apitally/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

import flask
from flask import g, make_response, request
from flask import Flask, g, make_response, request
from werkzeug.exceptions import NotFound
from werkzeug.test import Client

import apitally
from apitally.client.base import KeyInfo
from apitally.client.threading import ApitallyClient


Expand All @@ -26,23 +27,17 @@
class ApitallyMiddleware:
def __init__(
self,
app: WSGIApplication,
app: Flask,
client_id: str,
env: str = "default",
app_version: Optional[str] = None,
sync_api_keys: bool = False,
sync_interval: float = 60,
openapi_url: Optional[str] = None,
url_map: Optional[Map] = None,
filter_unhandled_paths: bool = True,
) -> None:
url_map = url_map or _get_url_map(app)
if url_map is None: # pragma: no cover
raise ValueError(
"Could not extract url_map from app. Please provide it as an argument to ApitallyMiddleware."
)
self.app = app
self.url_map = url_map
self.wsgi_app = app.wsgi_app
self.filter_unhandled_paths = filter_unhandled_paths
self.client = ApitallyClient(
client_id=client_id, env=env, sync_api_keys=sync_api_keys, sync_interval=sync_interval
Expand All @@ -54,7 +49,7 @@ def __init__(
timer.start()

def delayed_send_app_info(self, app_version: Optional[str] = None, openapi_url: Optional[str] = None) -> None:
app_info = _get_app_info(self.app, self.url_map, app_version, openapi_url)
app_info = _get_app_info(self.app, app_version, openapi_url)
self.client.send_app_info(app_info=app_info)

def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> Iterable[bytes]:
Expand All @@ -66,33 +61,42 @@ def catching_start_response(status: str, headers, exc_info=None):
return start_response(status, headers, exc_info)

start_time = time.perf_counter()
response = self.app(environ, catching_start_response)
self.log_request(
environ=environ,
status_code=status_code,
response_time=time.perf_counter() - start_time,
)
with self.app.app_context():
response = self.wsgi_app(environ, catching_start_response)
self.log_request(
environ=environ,
status_code=status_code,
response_time=time.perf_counter() - start_time,
)
return response

def log_request(self, environ: WSGIEnvironment, status_code: int, response_time: float) -> None:
rule, is_handled_path = self.get_rule(environ)
if is_handled_path or not self.filter_unhandled_paths:
self.client.request_logger.log_request(
consumer=self.get_consumer(),
method=environ["REQUEST_METHOD"],
path=rule,
status_code=status_code,
response_time=response_time,
)

def get_rule(self, environ: WSGIEnvironment) -> Tuple[str, bool]:
url_adapter = self.url_map.bind_to_environ(environ)
url_adapter = self.app.url_map.bind_to_environ(environ)
try:
endpoint, _ = url_adapter.match()
rule = self.url_map._rules_by_endpoint[endpoint][0]
rule = self.app.url_map._rules_by_endpoint[endpoint][0]
return rule.rule, True
except NotFound:
return environ["PATH_INFO"], False

def get_consumer(self) -> Optional[str]:
if "consumer_identifier" in g:
return str(g.consumer_identifier)
if "key_info" in g and isinstance(g.key_info, KeyInfo):
return f"key:{g.key_info.key_id}"
return None


def require_api_key(func=None, *, scopes: Optional[List[str]] = None, custom_header: Optional[str] = None):
def decorator(func):
Expand Down Expand Up @@ -123,28 +127,18 @@ def wrapped_func(*args, **kwargs):
return decorator if func is None else decorator(func)


def _get_app_info(
app: WSGIApplication, url_map: Map, app_version: Optional[str] = None, openapi_url: Optional[str] = None
) -> Dict[str, Any]:
def _get_app_info(app: Flask, app_version: Optional[str] = None, openapi_url: Optional[str] = None) -> Dict[str, Any]:
app_info: Dict[str, Any] = {}
if openapi_url and (openapi := _get_openapi(app, openapi_url)):
app_info["openapi"] = openapi
elif paths := _get_paths(url_map):
elif paths := _get_paths(app.url_map):
app_info["paths"] = paths
app_info["versions"] = _get_versions(app_version)
app_info["client"] = "apitally-python"
app_info["framework"] = "flask"
return app_info


def _get_url_map(app: WSGIApplication) -> Optional[Map]:
if hasattr(app, "url_map"):
return app.url_map
elif hasattr(app, "__self__") and hasattr(app.__self__, "url_map"):
return app.__self__.url_map
return None


def _get_paths(url_map: Map) -> List[Dict[str, str]]:
return [
{"path": rule.rule, "method": method}
Expand Down
Loading

0 comments on commit 70c52bb

Please sign in to comment.