Skip to content

Commit

Permalink
Feat/api-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
ericmiguel authored May 5, 2024
2 parents e21c6bd + 9c068ee commit 9714a9a
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 26 deletions.
40 changes: 29 additions & 11 deletions missil/bearers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""JWT token obtaining via dependency injection."""

from typing import Any

from fastapi import Request
from fastapi import status

Expand Down Expand Up @@ -93,7 +95,20 @@ def decode_jwt(self, token: str) -> dict[str, int]:
decoded_token = decode_jwt_token(
token, self.token_secret_key, algorithm=self.algorithm
)
return decoded_token

def decode_from_cookies(self, request: Request) -> dict[str, Any]:
"""Get token from cookies and decode it."""
token = self.get_token_from_cookies(request)
return self.decode_jwt(token)

def decode_from_header(self, request: Request) -> dict[str, Any]:
"""Get token from headers and decode it."""
token = self.get_token_from_header(request)
return self.decode_jwt(token)

def get_user_permissions(self, decoded_token: dict[str, Any]) -> dict[str, int]:
"""Get user permissions from a decoded token."""
if self.user_permissions_key:
try:
user_permissions: dict[str, int] = decoded_token[
Expand All @@ -110,39 +125,42 @@ def decode_jwt(self, token: str) -> dict[str, int]:
else:
return user_permissions

return decoded_token
raise TokenErrorException(500, "User permissions key not provided.")


class CookieTokenBearer(TokenBearer):
"""Read JWT token from http cookies."""

async def __call__(self, request: Request) -> dict[str, int]:
async def __call__(self, request: Request) -> tuple[dict[str, Any], dict[str, int]]:
"""Fastapi FastAPIDependsFunc will call this method."""
token = self.get_token_from_cookies(request)
return self.decode_jwt(token)
decoded_token = self.decode_from_cookies(request)
user_permissions = self.get_user_permissions(decoded_token)
return decoded_token, user_permissions


class HTTPTokenBearer(TokenBearer):
"""Read JWT token from the request header."""

async def __call__(self, request: Request) -> dict[str, int]:
async def __call__(self, request: Request) -> tuple[dict[str, Any], dict[str, int]]:
"""Fastapi FastAPIDependsFunc will call this method."""
token = self.get_token_from_header(request)
return self.decode_jwt(token)
decoded_token = self.decode_from_header(request)
user_permissions = self.get_user_permissions(decoded_token)
return decoded_token, user_permissions


class FlexibleTokenBearer(TokenBearer):
"""Tries to read the token from the cookies or from request headers."""

async def __call__(self, request: Request) -> dict[str, int]:
async def __call__(self, request: Request) -> tuple[dict[str, Any], dict[str, int]]:
"""Fastapi FastAPIDependsFunc will call this method."""
try:
token = self.get_token_from_cookies(request)
decoded_token = self.decode_from_cookies(request)
except TokenErrorException:
token = self.get_token_from_header(request)
decoded_token = self.decode_from_header(request)
except Exception as e:
raise TokenErrorException(
status.HTTP_417_EXPECTATION_FAILED, "Token not found."
) from e

return self.decode_jwt(token)
user_permissions = self.get_user_permissions(decoded_token)
return decoded_token, user_permissions
33 changes: 24 additions & 9 deletions missil/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,33 @@ def dependency(self) -> Callable[..., Any] | None:
"""Allows Missil to pass a FastAPI dependency that gets correctly evaluated."""

def check_user_permissions(
claims: Annotated[dict[str, int], FastAPIDependsFunc(self.bearer)],
) -> None:
claims: Annotated[
tuple[
dict[str, Any], ## full claims
dict[str, int], ## user permissions
],
FastAPIDependsFunc(self.bearer),
],
) -> dict[str, Any]:
"""
Run JWT claims against an declared endpoint rule.
If claims contains the asked business area and sufficient access level,
the endpoint access is granted to the user.
the endpoint access is granted to the user and the full claims are returned.
Parameters
----------
claims : Annotated[dict[str, int], FastAPIDependsFunc
claims : Annotated[
tuple[
dict[str, Any],
dict[str, int]
],
FastAPIDependsFunc
]
Content decoded from a JWT Token, obtained after FastAPI resolves
the TokenBearer dependency. Missil expects an dict using the
following structure:
the TokenBearer dependency. Missil expects a permission dict like the
following example structure:
```python
{
Expand All @@ -87,18 +100,20 @@ def check_user_permissions(
PermissionErrorException
Insufficient access level.
"""
if self.area not in claims:
if self.area not in claims[1]:
raise PermissionErrorException(
status.HTTP_403_FORBIDDEN, f"'{self.area}' not in user permissions."
)

if not claims[self.area] >= self.level:
if not claims[1][self.area] >= self.level:
raise PermissionErrorException(
status.HTTP_403_FORBIDDEN,
"insufficient access level: "
f"({claims[self.area]}/{self.level}) on {self.area}.",
f"({claims[1][self.area]}/{self.level}) on {self.area}.",
)

return claims[0]

return check_user_permissions

if TYPE_CHECKING:
Expand Down
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ classifiers = [
"Framework :: AsyncIO",
"Framework :: FastAPI",
"Framework :: Pydantic",
"Framework :: Pydantic :: 1",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3 :: Only",
Expand All @@ -47,12 +46,10 @@ log_cli = true
log_cli_level = "INFO"
log_cli_format = "%(message)s"
log_file = "pytest.log"
log_file_level = "DEBUG"
log_file_level = "INFO"
log_file_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
log_file_date_format = "%Y-%m-%d %H:%M:%S"

ignore_decorators = ["@field_validator", "@app*", "@route*"]

[tool.ruff]

# Enable fix behavior by-default when running ruff
Expand Down
11 changes: 11 additions & 0 deletions sample/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
"""Missil sample usage."""

from typing import Annotated
from typing import Any

from fastapi import FastAPI
from fastapi import Response

Expand Down Expand Up @@ -66,6 +69,14 @@ def finances_write() -> dict[str, str]:
return {"msg": "you have permission to perform write actions on finances!"}


@app.get("/user-profile", dependencies=[bas["it"].READ])
def get_user_profile(
user_profile: Annotated[dict[str, Any], bas["it"].READ],
) -> dict[str, Any]:
"""Require read permission on it."""
return user_profile


@finances_read_router.get("/finances/read/router")
def finances_read_route() -> dict[str, str]:
"""Require read permission on finances."""
Expand Down
18 changes: 16 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json

import pytest
from starlette.testclient import TestClient

from missil import decode_jwt_token
from sample.main import app


Expand All @@ -10,8 +13,19 @@ def test_app():
yield client


@pytest.fixture(scope="function")
@pytest.fixture(scope="module")
def bearer_token(test_app):
test_app.get("/set-cookies")
bearer_token = dict(test_app.cookies)["Authorization"].replace(" ", "")
bearer_token = json.loads(test_app.cookies["Authorization"]).replace("Bearer ", "")
print(f"Bearer token: {bearer_token}")
yield bearer_token


@pytest.fixture(scope="module")
def jwt_secret_key():
return "2ef9451be5d149ceaf5be306b5aa03b41a0331218926e12329c5eeba60ed5cf0"


@pytest.fixture(scope="module")
def decoded_token(bearer_token, jwt_secret_key):
return decode_jwt_token(bearer_token, jwt_secret_key)
13 changes: 13 additions & 0 deletions tests/test_sample_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,16 @@ def test_write_access(api_url, response_msg, test_app, bearer_token):

assert response.status_code == 403
assert response.json() == {"detail": response_msg}


@ignore_warnings
@pytest.mark.parametrize(
"api_url",
[
"/user-profile",
],
)
def test_get_current_user(api_url, test_app, bearer_token, decoded_token):
response = test_app.get(api_url, headers={"Authorization": bearer_token})
assert response.status_code == 200
assert response.json() == decoded_token

0 comments on commit 9714a9a

Please sign in to comment.