diff --git a/README.md b/README.md index fc0af8b..70cb046 100644 --- a/README.md +++ b/README.md @@ -81,5 +81,5 @@ from flask import Flask from apitally.flask import ApitallyMiddleware app = Flask(__name__) -app.wsgi_app = ApitallyMiddleware(app.wsgi_app, client_id="") +app.wsgi_app = ApitallyMiddleware(app, client_id="") ``` diff --git a/apitally/client/asyncio.py b/apitally/client/asyncio.py index 70ac113..8b48b42 100644 --- a/apitally/client/asyncio.py +++ b/apitally/client/asyncio.py @@ -66,7 +66,7 @@ async def get_keys(self, client: httpx.AsyncClient) -> None: if response_data := await self._get_keys(client): # Response data can be None if backoff gives up self.handle_keys_response(response_data) elif self.key_registry.salt is None: - logger.error("Initial Apitally key sync failed") + logger.error("Initial Apitally API key sync failed") # Exit because the application will not be able to authenticate requests sys.exit(1) diff --git a/apitally/client/base.py b/apitally/client/base.py index d52bbb9..8f5f576 100644 --- a/apitally/client/base.py +++ b/apitally/client/base.py @@ -94,6 +94,7 @@ def handle_keys_response(self, response_data: Dict[str, Any]) -> None: @dataclass(frozen=True) class RequestInfo: + consumer: Optional[str] method: str path: str status_code: int @@ -105,8 +106,15 @@ def __init__(self) -> None: self.response_times: Dict[RequestInfo, Counter[int]] = {} self._lock = threading.Lock() - def log_request(self, method: str, path: str, status_code: int, response_time: float) -> None: - request_info = RequestInfo(method=method.upper(), path=path, status_code=status_code) + def log_request( + self, consumer: Optional[str], method: str, path: str, status_code: int, response_time: float + ) -> None: + request_info = RequestInfo( + consumer=consumer, + method=method.upper(), + path=path, + status_code=status_code, + ) response_time_ms_bin = int(floor(response_time / 0.01) * 10) # In ms, rounded down to nearest 10ms with self._lock: self.request_counts[request_info] += 1 @@ -118,6 +126,7 @@ def get_and_reset_requests(self) -> List[Dict[str, Any]]: for request_info, count in self.request_counts.items(): data.append( { + "consumer": request_info.consumer, "method": request_info.method, "path": request_info.path, "status_code": request_info.status_code, @@ -132,6 +141,7 @@ def get_and_reset_requests(self) -> List[Dict[str, Any]]: @dataclass(frozen=True) class ValidationError: + consumer: Optional[str] method: str path: str loc: Tuple[str, ...] @@ -144,11 +154,14 @@ def __init__(self) -> None: self.error_counts: Counter[ValidationError] = Counter() self._lock = threading.Lock() - def log_validation_errors(self, method: str, path: str, detail: List[Dict[str, Any]]) -> None: + def log_validation_errors( + self, consumer: Optional[str], method: str, path: str, detail: List[Dict[str, Any]] + ) -> None: with self._lock: for error in detail: try: validation_error = ValidationError( + consumer=consumer, method=method.upper(), path=path, loc=tuple(str(loc) for loc in error["loc"]), @@ -165,6 +178,7 @@ def get_and_reset_validation_errors(self) -> List[Dict[str, Any]]: for validation_error, count in self.error_counts.items(): data.append( { + "consumer": validation_error.consumer, "method": validation_error.method, "path": validation_error.path, "loc": validation_error.loc, diff --git a/apitally/client/threading.py b/apitally/client/threading.py index ff43389..4172b21 100644 --- a/apitally/client/threading.py +++ b/apitally/client/threading.py @@ -83,7 +83,7 @@ def get_keys(self, session: requests.Session) -> None: if response_data := self._get_keys(session): # Response data can be None if backoff gives up self.handle_keys_response(response_data) elif self.key_registry.salt is None: - logger.error("Initial Apitally key sync failed") + logger.error("Initial Apitally API key sync failed") # Exit because the application will not be able to authenticate requests sys.exit(1) diff --git a/apitally/django.py b/apitally/django.py index 09c0d13..2531851 100644 --- a/apitally/django.py +++ b/apitally/django.py @@ -11,8 +11,10 @@ from django.core.exceptions import ViewDoesNotExist from django.test import RequestFactory from django.urls import URLPattern, URLResolver, get_resolver, resolve +from django.utils.module_loading import import_string import apitally +from apitally.client.base import KeyInfo from apitally.client.threading import ApitallyClient @@ -31,6 +33,7 @@ class ApitallyMiddlewareConfig: sync_api_keys: bool sync_interval: float openapi_url: Optional[str] + identify_consumer_callback: Optional[Callable[[HttpRequest], Optional[str]]] class ApitallyMiddleware: @@ -67,6 +70,7 @@ def configure( sync_api_keys: bool = False, sync_interval: float = 60, openapi_url: Optional[str] = None, + identify_consumer_callback: Optional[str] = None, ) -> None: cls.config = ApitallyMiddlewareConfig( client_id=client_id, @@ -75,6 +79,9 @@ def configure( sync_api_keys=sync_api_keys, sync_interval=sync_interval, openapi_url=openapi_url, + identify_consumer_callback=import_string(identify_consumer_callback) + if identify_consumer_callback + else None, ) def __call__(self, request: HttpRequest) -> HttpResponse: @@ -82,7 +89,9 @@ def __call__(self, request: HttpRequest) -> HttpResponse: start_time = time.perf_counter() response = self.get_response(request) if request.method is not None and view is not None and view.is_api_view: + consumer = self.get_consumer(request) self.client.request_logger.log_request( + consumer=consumer, method=request.method, path=view.pattern, status_code=response.status_code, @@ -98,11 +107,12 @@ def __call__(self, request: HttpRequest) -> HttpResponse: if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list): # Log Django Ninja / Pydantic validation errors self.client.validation_error_logger.log_validation_errors( + consumer=consumer, method=request.method, path=view.pattern, detail=body["detail"], ) - except json.JSONDecodeError: + except json.JSONDecodeError: # pragma: no cover pass return response @@ -110,6 +120,17 @@ def get_view(self, request: HttpRequest) -> Optional[DjangoViewInfo]: resolver_match = resolve(request.path_info) return next((view for view in self.views if view.pattern == resolver_match.route), None) + def get_consumer(self, request: HttpRequest) -> Optional[str]: + if hasattr(request, "consumer_identifier"): + return str(request.consumer_identifier) + if self.config is not None and self.config.identify_consumer_callback is not None: + consumer_identifier = self.config.identify_consumer_callback(request) + if consumer_identifier is not None: + return str(consumer_identifier) + if hasattr(request, "auth") and isinstance(request.auth, KeyInfo): + return f"key:{request.auth.key_id}" + return None + @dataclass class DjangoViewInfo: diff --git a/apitally/django_ninja.py b/apitally/django_ninja.py index 13fc1e6..426dd9a 100644 --- a/apitally/django_ninja.py +++ b/apitally/django_ninja.py @@ -69,7 +69,7 @@ def _get_api(views: List[DjangoViewInfo]) -> NinjaAPI: return next( (view.func.__self__.api for view in views if view.is_ninja_path_view and hasattr(view.func, "__self__")) ) - except StopIteration: + except StopIteration: # pragma: no cover raise RuntimeError("Could not find NinjaAPI instance") diff --git a/apitally/fastapi.py b/apitally/fastapi.py index 541eb42..9526503 100644 --- a/apitally/fastapi.py +++ b/apitally/fastapi.py @@ -69,6 +69,8 @@ async def __call__(self, request: Request, security_scopes: SecurityScopes) -> O status_code=HTTP_403_FORBIDDEN, detail="Permission denied", ) + if key_info is not None: + request.state.key_info = key_info return key_info diff --git a/apitally/flask.py b/apitally/flask.py index 8c36ae2..88cc261 100644 --- a/apitally/flask.py +++ b/apitally/flask.py @@ -7,11 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple import flask -from flask import g, make_response, request +from flask import Flask, g, make_response, request from werkzeug.exceptions import NotFound from werkzeug.test import Client import apitally +from apitally.client.base import KeyInfo from apitally.client.threading import ApitallyClient @@ -26,23 +27,17 @@ class ApitallyMiddleware: def __init__( self, - app: WSGIApplication, + app: Flask, client_id: str, env: str = "default", app_version: Optional[str] = None, sync_api_keys: bool = False, sync_interval: float = 60, openapi_url: Optional[str] = None, - url_map: Optional[Map] = None, filter_unhandled_paths: bool = True, ) -> None: - url_map = url_map or _get_url_map(app) - if url_map is None: # pragma: no cover - raise ValueError( - "Could not extract url_map from app. Please provide it as an argument to ApitallyMiddleware." - ) self.app = app - self.url_map = url_map + self.wsgi_app = app.wsgi_app self.filter_unhandled_paths = filter_unhandled_paths self.client = ApitallyClient( client_id=client_id, env=env, sync_api_keys=sync_api_keys, sync_interval=sync_interval @@ -54,7 +49,7 @@ def __init__( timer.start() def delayed_send_app_info(self, app_version: Optional[str] = None, openapi_url: Optional[str] = None) -> None: - app_info = _get_app_info(self.app, self.url_map, app_version, openapi_url) + app_info = _get_app_info(self.app, app_version, openapi_url) self.client.send_app_info(app_info=app_info) def __call__(self, environ: WSGIEnvironment, start_response: StartResponse) -> Iterable[bytes]: @@ -66,18 +61,20 @@ def catching_start_response(status: str, headers, exc_info=None): return start_response(status, headers, exc_info) start_time = time.perf_counter() - response = self.app(environ, catching_start_response) - self.log_request( - environ=environ, - status_code=status_code, - response_time=time.perf_counter() - start_time, - ) + with self.app.app_context(): + response = self.wsgi_app(environ, catching_start_response) + self.log_request( + environ=environ, + status_code=status_code, + response_time=time.perf_counter() - start_time, + ) return response def log_request(self, environ: WSGIEnvironment, status_code: int, response_time: float) -> None: rule, is_handled_path = self.get_rule(environ) if is_handled_path or not self.filter_unhandled_paths: self.client.request_logger.log_request( + consumer=self.get_consumer(), method=environ["REQUEST_METHOD"], path=rule, status_code=status_code, @@ -85,14 +82,21 @@ def log_request(self, environ: WSGIEnvironment, status_code: int, response_time: ) def get_rule(self, environ: WSGIEnvironment) -> Tuple[str, bool]: - url_adapter = self.url_map.bind_to_environ(environ) + url_adapter = self.app.url_map.bind_to_environ(environ) try: endpoint, _ = url_adapter.match() - rule = self.url_map._rules_by_endpoint[endpoint][0] + rule = self.app.url_map._rules_by_endpoint[endpoint][0] return rule.rule, True except NotFound: return environ["PATH_INFO"], False + def get_consumer(self) -> Optional[str]: + if "consumer_identifier" in g: + return str(g.consumer_identifier) + if "key_info" in g and isinstance(g.key_info, KeyInfo): + return f"key:{g.key_info.key_id}" + return None + def require_api_key(func=None, *, scopes: Optional[List[str]] = None, custom_header: Optional[str] = None): def decorator(func): @@ -123,13 +127,11 @@ def wrapped_func(*args, **kwargs): return decorator if func is None else decorator(func) -def _get_app_info( - app: WSGIApplication, url_map: Map, app_version: Optional[str] = None, openapi_url: Optional[str] = None -) -> Dict[str, Any]: +def _get_app_info(app: Flask, app_version: Optional[str] = None, openapi_url: Optional[str] = None) -> Dict[str, Any]: app_info: Dict[str, Any] = {} if openapi_url and (openapi := _get_openapi(app, openapi_url)): app_info["openapi"] = openapi - elif paths := _get_paths(url_map): + elif paths := _get_paths(app.url_map): app_info["paths"] = paths app_info["versions"] = _get_versions(app_version) app_info["client"] = "apitally-python" @@ -137,14 +139,6 @@ def _get_app_info( return app_info -def _get_url_map(app: WSGIApplication) -> Optional[Map]: - if hasattr(app, "url_map"): - return app.url_map - elif hasattr(app, "__self__") and hasattr(app.__self__, "url_map"): - return app.__self__.url_map - return None - - def _get_paths(url_map: Map) -> List[Dict[str, str]]: return [ {"path": rule.rule, "method": method} diff --git a/apitally/starlette.py b/apitally/starlette.py index 0c50446..6168f06 100644 --- a/apitally/starlette.py +++ b/apitally/starlette.py @@ -3,7 +3,7 @@ import json import sys import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import starlette from httpx import HTTPStatusError @@ -47,8 +47,10 @@ def __init__( sync_interval: float = 60, openapi_url: Optional[str] = "/openapi.json", filter_unhandled_paths: bool = True, + identify_consumer_callback: Optional[Callable[[Request], Optional[str]]] = None, ) -> None: self.filter_unhandled_paths = filter_unhandled_paths + self.identify_consumer_callback = identify_consumer_callback self.client = ApitallyClient( client_id=client_id, env=env, sync_api_keys=sync_api_keys, sync_interval=sync_interval ) @@ -82,7 +84,9 @@ async def log_request( ) -> None: path_template, is_handled_path = self.get_path_template(request) if is_handled_path or not self.filter_unhandled_paths: + consumer = self.get_consumer(request) self.client.request_logger.log_request( + consumer=consumer, method=request.method, path=path_template, status_code=status_code, @@ -93,31 +97,29 @@ async def log_request( and response is not None and response.headers.get("Content-Type") == "application/json" ): - try: - body = await self.get_response_json(response) - if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list): - # Log FastAPI / Pydantic validation errors - self.client.validation_error_logger.log_validation_errors( - method=request.method, - path=path_template, - detail=body["detail"], - ) - except json.JSONDecodeError: - pass + body = await self.get_response_json(response) + if isinstance(body, dict) and "detail" in body and isinstance(body["detail"], list): + # Log FastAPI / Pydantic validation errors + self.client.validation_error_logger.log_validation_errors( + consumer=consumer, + method=request.method, + path=path_template, + detail=body["detail"], + ) @staticmethod async def get_response_json(response: Response) -> Any: if hasattr(response, "body"): try: return json.loads(response.body) - except json.JSONDecodeError: + except json.JSONDecodeError: # pragma: no cover return None elif hasattr(response, "body_iterator"): try: response_body = [section async for section in response.body_iterator] response.body_iterator = iterate_in_threadpool(iter(response_body)) return json.loads(b"".join(response_body)) - except json.JSONDecodeError: + except json.JSONDecodeError: # pragma: no cover return None @staticmethod @@ -128,6 +130,19 @@ def get_path_template(request: Request) -> Tuple[str, bool]: return route.path, True return request.url.path, False + def get_consumer(self, request: Request) -> Optional[str]: + if hasattr(request.state, "consumer_identifier"): + return str(request.state.consumer_identifier) + if self.identify_consumer_callback is not None: + consumer_identifier = self.identify_consumer_callback(request) + if consumer_identifier is not None: + return str(consumer_identifier) + if hasattr(request.state, "key_info") and isinstance(key_info := request.state.key_info, KeyInfo): + return f"key:{key_info.key_id}" + if "user" in request.scope and isinstance(user := request.scope["user"], APIKeyUser): + return f"key:{user.key_info.key_id}" + return None + class APIKeyAuth(AuthenticationBackend): def __init__(self, custom_header: Optional[str] = None) -> None: @@ -201,7 +216,7 @@ def _get_routes(app: ASGIApp) -> List[BaseRoute]: return app.routes elif hasattr(app, "app"): return _get_routes(app.app) - return [] + return [] # pragma: no cover def _get_versions(app_version: Optional[str]) -> Dict[str, str]: diff --git a/tests/django_ninja_urls.py b/tests/django_ninja_urls.py index b76c5ab..626bbe3 100644 --- a/tests/django_ninja_urls.py +++ b/tests/django_ninja_urls.py @@ -29,6 +29,7 @@ def bar(request: HttpRequest) -> str: @api.put("/baz", auth=APIKeyAuth()) def baz(request: HttpRequest) -> str: + request.consumer_identifier = "baz" # type: ignore[attr-defined] raise ValueError("baz") diff --git a/tests/test_client_asyncio.py b/tests/test_client_asyncio.py index 3f70bec..05656de 100644 --- a/tests/test_client_asyncio.py +++ b/tests/test_client_asyncio.py @@ -23,24 +23,28 @@ async def client(module_mocker: MockerFixture) -> ApitallyClient: client = ApitallyClient(client_id=CLIENT_ID, env=ENV, sync_api_keys=True) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=200, response_time=0.105, ) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=200, response_time=0.227, ) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=422, response_time=0.02, ) client.validation_error_logger.log_validation_errors( + consumer=None, method="GET", path="/test", detail=[ diff --git a/tests/test_client_base.py b/tests/test_client_base.py index d512dcd..0945bd7 100644 --- a/tests/test_client_base.py +++ b/tests/test_client_base.py @@ -6,12 +6,14 @@ def test_request_logger(): requests = RequestLogger() requests.log_request( + consumer=None, method="GET", path="/test", status_code=200, response_time=0.105, ) requests.log_request( + consumer=None, method="GET", path="/test", status_code=200, @@ -35,6 +37,7 @@ def test_validation_error_logger(): validation_errors = ValidationErrorLogger() validation_errors.log_validation_errors( + consumer=None, method="GET", path="/test", detail=[ @@ -51,6 +54,7 @@ def test_validation_error_logger(): ], ) validation_errors.log_validation_errors( + consumer=None, method="GET", path="/test", detail=[ diff --git a/tests/test_client_threading.py b/tests/test_client_threading.py index a74dc19..d4ed279 100644 --- a/tests/test_client_threading.py +++ b/tests/test_client_threading.py @@ -23,24 +23,28 @@ def client() -> ApitallyClient: client = ApitallyClient(client_id=CLIENT_ID, env=ENV, sync_api_keys=True) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=200, response_time=0.105, ) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=200, response_time=0.227, ) client.request_logger.log_request( + consumer=None, method="GET", path="/test", status_code=422, response_time=0.02, ) client.validation_error_logger.log_validation_errors( + consumer=None, method="GET", path="/test", detail=[ diff --git a/tests/test_django_ninja.py b/tests/test_django_ninja.py index 36a9999..e682492 100644 --- a/tests/test_django_ninja.py +++ b/tests/test_django_ninja.py @@ -2,7 +2,7 @@ import json from importlib.util import find_spec -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import pytest from pytest_mock import MockerFixture @@ -12,11 +12,18 @@ pytest.skip("django-ninja is not available", allow_module_level=True) if TYPE_CHECKING: + from django.http import HttpRequest from django.test import Client from apitally.client.base import KeyRegistry +def identify_consumer(request: HttpRequest) -> Optional[str]: + if consumer := request.GET.get("consumer"): + return consumer + return None + + @pytest.fixture(scope="module", autouse=True) def setup(module_mocker: MockerFixture) -> None: import django @@ -44,6 +51,7 @@ def setup(module_mocker: MockerFixture) -> None: "client_id": "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9", "env": "default", "sync_api_keys": True, + "identify_consumer_callback": "tests.test_django_ninja.identify_consumer", }, ) django.setup() @@ -105,8 +113,9 @@ def test_middleware_validation_error(client: Client, mocker: MockerFixture): def test_api_key_auth(client: Client, key_registry: KeyRegistry, mocker: MockerFixture): - mock = mocker.patch("apitally.django_ninja.ApitallyClient.get_instance") - mock.return_value.key_registry = key_registry + client_get_instance_mock = mocker.patch("apitally.django_ninja.ApitallyClient.get_instance") + client_get_instance_mock.return_value.key_registry = key_registry + log_request_mock = mocker.patch("apitally.client.base.RequestLogger.log_request") # Unauthenticated response = client.get("/api/foo/123") @@ -127,20 +136,31 @@ def test_api_key_auth(client: Client, key_registry: KeyRegistry, mocker: MockerF response = client.get("/api/foo", **headers) # type: ignore[arg-type] assert response.status_code == 403 - # Valid API key, no scope required, custom header + # Valid API key, no scope required, custom header, consumer identified by API key headers = {"HTTP_APIKEY": "7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} response = client.get("/api/foo", **headers) # type: ignore[arg-type] assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "key:1" # Valid API key with required scope headers = {"HTTP_AUTHORIZATION": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} response = client.get("/api/foo/123", **headers) # type: ignore[arg-type] assert response.status_code == 200 + # Valid API key with required scope, consumer identified by custom function + response = client.get("/api/foo/123?consumer=foo", **headers) # type: ignore[arg-type] + assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "foo" + # Valid API key without required scope response = client.post("/api/bar", **headers) # type: ignore[arg-type] assert response.status_code == 403 + # Valid API key, consumer identifier from request object + response = client.put("/api/baz", **headers) # type: ignore[arg-type] + assert response.status_code == 500 + assert log_request_mock.call_args.kwargs["consumer"] == "baz" + def test_get_app_info(mocker: MockerFixture): from django.urls import get_resolver diff --git a/tests/test_django_rest_framework.py b/tests/test_django_rest_framework.py index 7678e6b..f1e77bc 100644 --- a/tests/test_django_rest_framework.py +++ b/tests/test_django_rest_framework.py @@ -95,8 +95,9 @@ def test_middleware_requests_error(client: APIClient, mocker: MockerFixture): def test_api_key_auth(client: APIClient, key_registry: KeyRegistry, mocker: MockerFixture): - mock = mocker.patch("apitally.django_rest_framework.ApitallyClient.get_instance") - mock.return_value.key_registry = key_registry + client_get_instance_mock = mocker.patch("apitally.django_rest_framework.ApitallyClient.get_instance") + client_get_instance_mock.return_value.key_registry = key_registry + log_request_mock = mocker.patch("apitally.client.base.RequestLogger.log_request") # Unauthenticated response = client.get("/foo/123/") @@ -116,6 +117,7 @@ def test_api_key_auth(client: APIClient, key_registry: KeyRegistry, mocker: Mock headers = {"HTTP_AUTHORIZATION": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} response = client.get("/foo/", **headers) # type: ignore[arg-type] assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "key:1" # Valid API key with required scope response = client.get("/foo/123/", **headers) # type: ignore[arg-type] diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py index 10f31ba..1a5661e 100644 --- a/tests/test_fastapi.py +++ b/tests/test_fastapi.py @@ -1,7 +1,7 @@ from __future__ import annotations from importlib.util import find_spec -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import pytest from pytest_mock import MockerFixture @@ -15,16 +15,36 @@ from apitally.client.base import KeyRegistry -from apitally.client.base import KeyInfo # import here to avoid pydantic error +# Global imports to avoid NameErrors during FastAPI dependency injection +try: + from fastapi import Request + from apitally.client.base import KeyInfo +except ImportError: + pass -@pytest.fixture() -def app() -> FastAPI: + +CLIENT_ID = "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9" +ENV = "default" + + +@pytest.fixture(scope="module") +def app(module_mocker: MockerFixture) -> FastAPI: from fastapi import Depends, FastAPI, Security - from apitally.fastapi import APIKeyAuth, api_key_auth + from apitally.fastapi import APIKeyAuth, ApitallyMiddleware, api_key_auth + + module_mocker.patch("apitally.client.asyncio.ApitallyClient._instance", None) + module_mocker.patch("apitally.client.asyncio.ApitallyClient.start_sync_loop") + module_mocker.patch("apitally.client.asyncio.ApitallyClient.send_app_info") + + def identify_consumer(request: Request) -> Optional[str]: + if consumer := request.query_params.get("consumer"): + return consumer + return None app = FastAPI() + app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV, identify_consumer_callback=identify_consumer) api_key_auth_custom = APIKeyAuth(custom_header="ApiKey") @app.get("/foo/") @@ -36,7 +56,8 @@ def bar(key: KeyInfo = Security(api_key_auth, scopes=["bar"])): return "bar" @app.get("/baz/", dependencies=[Depends(api_key_auth_custom)]) - def baz(): + def baz(request: Request): + request.state.consumer_identifier = "baz" return "baz" return app @@ -48,8 +69,9 @@ def test_api_key_auth(app: FastAPI, key_registry: KeyRegistry, mocker: MockerFix client = TestClient(app) headers = {"Authorization": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} headers_custom = {"ApiKey": "7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} - mock = mocker.patch("apitally.fastapi.ApitallyClient.get_instance") - mock.return_value.key_registry = key_registry + client_get_instance_mock = mocker.patch("apitally.fastapi.ApitallyClient.get_instance") + client_get_instance_mock.return_value.key_registry = key_registry + log_request_mock = mocker.patch("apitally.client.base.RequestLogger.log_request") # Unauthenticated response = client.get("/foo") @@ -71,13 +93,20 @@ def test_api_key_auth(app: FastAPI, key_registry: KeyRegistry, mocker: MockerFix response = client.get("/baz", headers={"ApiKey": "invalid"}) assert response.status_code == 403 - # Valid API key with required scope + # Valid API key with required scope, consumer identified by API key response = client.get("/foo", headers=headers) assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "key:1" + + # Valid API key with required scope, identify consumer with custom function + response = client.get("/foo?consumer=foo", headers=headers) + assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "foo" - # Valid API key, no scope required, custom header + # Valid API key, no scope required, custom header, consumer identifier from request.state object response = client.get("/baz", headers=headers_custom) assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "baz" # Valid API key without required scope response = client.get("/bar", headers=headers) diff --git a/tests/test_flask.py b/tests/test_flask.py index 3273045..e3267bb 100644 --- a/tests/test_flask.py +++ b/tests/test_flask.py @@ -32,17 +32,17 @@ def app(module_mocker: MockerFixture) -> Flask: module_mocker.patch("apitally.flask.ApitallyMiddleware.delayed_send_app_info") app = Flask("test") - app.wsgi_app = ApitallyMiddleware(app.wsgi_app, client_id=CLIENT_ID, env=ENV) # type: ignore[method-assign] + app.wsgi_app = ApitallyMiddleware(app, client_id=CLIENT_ID, env=ENV) # type: ignore[method-assign] - @app.route("/foo//") + @app.route("/foo/") def foo_bar(bar: int): return f"foo: {bar}" - @app.route("/bar/", methods=["POST"]) + @app.route("/bar", methods=["POST"]) def bar(): return "bar" - @app.route("/baz/", methods=["PUT"]) + @app.route("/baz", methods=["PUT"]) def baz(): raise ValueError("baz") @@ -51,7 +51,7 @@ def baz(): @pytest.fixture(scope="module") def app_with_auth(module_mocker: MockerFixture) -> Flask: - from flask import Flask + from flask import Flask, g, request from apitally.flask import ApitallyMiddleware, require_api_key @@ -61,21 +61,27 @@ def app_with_auth(module_mocker: MockerFixture) -> Flask: module_mocker.patch("apitally.flask.ApitallyMiddleware.delayed_send_app_info") app = Flask("test") - app.wsgi_app = ApitallyMiddleware(app.wsgi_app, client_id=CLIENT_ID, env=ENV) # type: ignore[method-assign] + app.wsgi_app = ApitallyMiddleware(app, client_id=CLIENT_ID, env=ENV) # type: ignore[method-assign] + + @app.before_request + def identify_consumer(): + if consumer := request.args.get("consumer"): + g.consumer_identifier = consumer - @app.route("/foo/") + @app.route("/foo") @require_api_key(scopes=["foo"]) def foo(): return "foo" - @app.route("/bar/") + @app.route("/bar") @require_api_key(custom_header="ApiKey", scopes=["bar"]) def bar(): return "bar" - @app.route("/baz/") + @app.route("/baz") @require_api_key def baz(): + g.consumer_identifier = "baz" return "baz" return app @@ -85,16 +91,16 @@ def test_middleware_requests_ok(app: Flask, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestLogger.log_request") client = app.test_client() - response = client.get("/foo/123/") + response = client.get("/foo/123") assert response.status_code == 200 mock.assert_called_once() assert mock.call_args is not None assert mock.call_args.kwargs["method"] == "GET" - assert mock.call_args.kwargs["path"] == "/foo//" + assert mock.call_args.kwargs["path"] == "/foo/" assert mock.call_args.kwargs["status_code"] == 200 assert mock.call_args.kwargs["response_time"] > 0 - response = client.post("/bar/") + response = client.post("/bar") assert response.status_code == 200 assert mock.call_count == 2 assert mock.call_args is not None @@ -105,12 +111,12 @@ def test_middleware_requests_error(app: Flask, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestLogger.log_request") client = app.test_client() - response = client.put("/baz/") + response = client.put("/baz") assert response.status_code == 500 mock.assert_called_once() assert mock.call_args is not None assert mock.call_args.kwargs["method"] == "PUT" - assert mock.call_args.kwargs["path"] == "/baz/" + assert mock.call_args.kwargs["path"] == "/baz" assert mock.call_args.kwargs["status_code"] == 500 assert mock.call_args.kwargs["response_time"] > 0 @@ -119,7 +125,7 @@ def test_middleware_requests_unhandled(app: Flask, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestLogger.log_request") client = app.test_client() - response = client.post("/xxx/") + response = client.post("/xxx") assert response.status_code == 404 mock.assert_not_called() @@ -128,45 +134,53 @@ def test_require_api_key(app_with_auth: Flask, key_registry: KeyRegistry, mocker client = app_with_auth.test_client() headers = {"Authorization": "ApiKey 7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} headers_custom = {"ApiKey": "7ll40FB.DuHxzQQuGQU4xgvYvTpmnii7K365j9VI"} - mock = mocker.patch("apitally.flask.ApitallyClient.get_instance") - mock.return_value.key_registry = key_registry + client_get_instance_mock = mocker.patch("apitally.flask.ApitallyClient.get_instance") + client_get_instance_mock.return_value.key_registry = key_registry + log_request_mock = mocker.patch("apitally.client.base.RequestLogger.log_request") # Unauthenticated - response = client.get("/foo/") + response = client.get("/foo") assert response.status_code == 401 - response = client.get("/baz/") + response = client.get("/baz") assert response.status_code == 401 # Invalid auth scheme - response = client.get("/foo/", headers={"Authorization": "Bearer invalid"}) + response = client.get("/foo", headers={"Authorization": "Bearer invalid"}) assert response.status_code == 401 # Invalid API key - response = client.get("/foo/", headers={"Authorization": "ApiKey invalid"}) + response = client.get("/foo", headers={"Authorization": "ApiKey invalid"}) assert response.status_code == 403 # Invalid API key, custom header - response = client.get("/bar/", headers={"ApiKey": "invalid"}) + response = client.get("/bar", headers={"ApiKey": "invalid"}) assert response.status_code == 403 # Valid API key with required scope - response = client.get("/foo/", headers=headers) + response = client.get("/foo", headers=headers) + assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "key:1" + + # Valid API key with required scope, identify consumer with custom function + response = client.get("/foo?consumer=foo", headers=headers) assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "foo" # Valid API key, no scope required - response = client.get("/baz/", headers=headers) + response = client.get("/baz", headers=headers) assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "baz" # Valid API key without required scope, custom header - response = client.get("/bar/", headers=headers_custom) + response = client.get("/bar", headers=headers_custom) assert response.status_code == 403 def test_get_app_info(app: Flask): from apitally.flask import _get_app_info - app_info = _get_app_info(app.wsgi_app, app.url_map, app_version="1.2.3", openapi_url="/openapi.json") + app_info = _get_app_info(app, app_version="1.2.3", openapi_url="/openapi.json") assert len(app_info["paths"]) == 3 assert app_info["versions"]["flask"] assert app_info["versions"]["app"] == "1.2.3" diff --git a/tests/test_starlette.py b/tests/test_starlette.py index d076dae..84a8d48 100644 --- a/tests/test_starlette.py +++ b/tests/test_starlette.py @@ -2,7 +2,6 @@ from importlib.util import find_spec from typing import TYPE_CHECKING, Tuple -from unittest.mock import MagicMock import pytest from pytest import FixtureRequest @@ -17,8 +16,6 @@ from apitally.client.base import KeyRegistry -from starlette.background import BackgroundTasks # import here to avoid pydantic error - CLIENT_ID = "76b5cb91-a0a4-4ea0-a894-57d2b9fcb2c9" ENV = "default" @@ -40,7 +37,7 @@ async def app(request: FixtureRequest, module_mocker: MockerFixture) -> Starlett @pytest.fixture(params=["Authorization", "ApiKey"]) -def app_with_auth(request: FixtureRequest) -> Tuple[Starlette, str]: +def app_with_auth(request: FixtureRequest, mocker: MockerFixture) -> Tuple[Starlette, str]: from starlette.applications import Starlette from starlette.authentication import requires from starlette.middleware.authentication import AuthenticationMiddleware @@ -48,7 +45,9 @@ def app_with_auth(request: FixtureRequest) -> Tuple[Starlette, str]: from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route - from apitally.starlette import APIKeyAuth + from apitally.starlette import APIKeyAuth, ApitallyMiddleware + + mocker.patch("apitally.client.asyncio.ApitallyClient._instance", None) @requires(["authenticated", "foo"]) def foo(request: Request): @@ -61,6 +60,7 @@ def bar(request: Request): @requires("authenticated") def baz(request: Request): + request.state.consumer_identifier = "baz" return JSONResponse( { "key_id": int(request.user.identity), @@ -80,25 +80,23 @@ def baz(request: Request): AuthenticationMiddleware, backend=APIKeyAuth() if api_key_header == "Authorization" else APIKeyAuth(custom_header=api_key_header), ) + app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) return (app, api_key_header) def get_starlette_app() -> Starlette: from starlette.applications import Starlette - from starlette.background import BackgroundTask, BackgroundTasks from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route from apitally.starlette import ApitallyMiddleware - background_task_mock = MagicMock() - def foo(request: Request): - return PlainTextResponse("foo", background=BackgroundTasks([BackgroundTask(background_task_mock)])) + return PlainTextResponse("foo") def foo_bar(request: Request): - return PlainTextResponse(f"foo: {request.path_params['bar']}", background=BackgroundTask(background_task_mock)) + return PlainTextResponse(f"foo: {request.path_params['bar']}") def bar(request: Request): return PlainTextResponse("bar") @@ -118,7 +116,6 @@ def val(request: Request): ] app = Starlette(routes=routes) app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) - app.state.background_task_mock = background_task_mock return app @@ -127,20 +124,15 @@ def get_fastapi_app() -> Starlette: from apitally.fastapi import ApitallyMiddleware - background_task_mock = MagicMock() - app = FastAPI(title="Test App", description="A simple test app.", version="1.2.3") app.add_middleware(ApitallyMiddleware, client_id=CLIENT_ID, env=ENV) - app.state.background_task_mock = background_task_mock @app.get("/foo/") - def foo(background_tasks: BackgroundTasks): - background_tasks.add_task(background_task_mock) + def foo(): return "foo" @app.get("/foo/{bar}/") - def foo_bar(bar: str, background_tasks: BackgroundTasks): - background_tasks.add_task(background_task_mock) + def foo_bar(bar: str): return f"foo: {bar}" @app.post("/bar/") @@ -163,11 +155,9 @@ def test_middleware_requests_ok(app: Starlette, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.RequestLogger.log_request") client = TestClient(app) - background_task_mock: MagicMock = app.state.background_task_mock # type: ignore[attr-defined] response = client.get("/foo/") assert response.status_code == 200 - background_task_mock.assert_called_once() mock.assert_called_once() assert mock.call_args is not None assert mock.call_args.kwargs["method"] == "GET" @@ -177,7 +167,6 @@ def test_middleware_requests_ok(app: Starlette, mocker: MockerFixture): response = client.get("/foo/123/") assert response.status_code == 200 - assert background_task_mock.call_count == 2 assert mock.call_count == 2 assert mock.call_args is not None assert mock.call_args.kwargs["path"] == "/foo/{bar}/" @@ -222,6 +211,7 @@ def test_middleware_validation_error(app: Starlette, mocker: MockerFixture): mock = mocker.patch("apitally.client.base.ValidationErrorLogger.log_validation_errors") client = TestClient(app) + # Validation error as foo must be an integer response = client.get("/val?foo=bar") assert response.status_code == 422 @@ -248,8 +238,9 @@ def test_api_key_auth(app_with_auth: Tuple[Starlette, str], key_registry: KeyReg headers_invalid = ( {"Authorization": "ApiKey invalid"} if api_key_header == "Authorization" else {api_key_header: "invalid"} ) - mock = mocker.patch("apitally.starlette.ApitallyClient.get_instance") - mock.return_value.key_registry = key_registry + client_get_instance_mock = mocker.patch("apitally.starlette.ApitallyClient.get_instance") + client_get_instance_mock.return_value.key_registry = key_registry + log_request_mock = mocker.patch("apitally.client.base.RequestLogger.log_request") # Unauthenticated response = client.get("/foo") @@ -258,21 +249,27 @@ def test_api_key_auth(app_with_auth: Tuple[Starlette, str], key_registry: KeyReg response = client.get("/baz") assert response.status_code == 403 + # Invalid auth scheme + response = client.get("/foo", headers={"Authorization": "Bearer invalid"}) + assert response.status_code == 403 + # Invalid API key response = client.get("/foo", headers=headers_invalid) assert response.status_code == 400 - # Valid API key with required scope + # Valid API key with required scope, consumer identified by API key response = client.get("/foo", headers=headers) assert response.status_code == 200 + assert log_request_mock.call_args.kwargs["consumer"] == "key:1" - # Valid API key, no scope required + # Valid API key, no scope required, consumer identifier from request.state object response = client.get("/baz", headers=headers) assert response.status_code == 200 response_data = response.json() assert response_data["key_id"] == 1 assert response_data["key_name"] == "Test key" assert response_data["key_scopes"] == ["authenticated", "foo"] + assert log_request_mock.call_args.kwargs["consumer"] == "baz" # Valid API key without required scope response = client.get("/bar", headers=headers)