From db97ead3d01c706aec109975b019eb018e0284a1 Mon Sep 17 00:00:00 2001 From: Ana Rute Mendes Date: Mon, 10 Jul 2023 13:20:18 +0200 Subject: [PATCH] Add JWT auth support Add support to receive a Bearer token in the API requests and decode it. --- .env.example | 2 ++ api/auth/auth_bearer.py | 24 ++++++++++++++++++++++++ api/auth/auth_handler.py | 16 ++++++++++++++++ api/pyproject.toml | 2 ++ api/routers/projects.py | 5 +++-- 5 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 api/auth/auth_bearer.py create mode 100644 api/auth/auth_handler.py diff --git a/.env.example b/.env.example index c4ead3d22..155c607bb 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,8 @@ EMPLOYEES_GROUP="staff" USE_EXTERNAL_AUTHENTICATION=false EXTERNAL_AUTHENTICATION_USER_HEADER="" +JWT_ALGORITHM= +JWT_SECRET= CALENDAR_URL="" CALENDAR_ID="" diff --git a/api/auth/auth_bearer.py b/api/auth/auth_bearer.py new file mode 100644 index 000000000..6a1d0821f --- /dev/null +++ b/api/auth/auth_bearer.py @@ -0,0 +1,24 @@ +from fastapi import Request, HTTPException +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials + +from .auth_handler import decode_token + + +class BearerToken(HTTPBearer): + def __init__(self, auto_error: bool = True): + super(BearerToken, self).__init__(auto_error=auto_error) + + async def __call__(self, request: Request): + credentials: HTTPAuthorizationCredentials = await super(BearerToken, self).__call__(request) + if credentials: + if not credentials.scheme == "Bearer": + raise HTTPException(status_code=401, detail="Invalid authentication scheme.") + if not self.verify_token(credentials.credentials): + raise HTTPException(status_code=401, detail="Invalid or expired token.") + return credentials.credentials + else: + raise HTTPException(status_code=401, detail="Invalid authorization code.") + + def verify_token(self, token: str) -> bool: + payload = decode_token(token) + return True if payload else False diff --git a/api/auth/auth_handler.py b/api/auth/auth_handler.py new file mode 100644 index 000000000..b490f72be --- /dev/null +++ b/api/auth/auth_handler.py @@ -0,0 +1,16 @@ +import time +import jwt +from decouple import config + + +JWT_SECRET = config("JWT_SECRET") +JWT_ALGORITHM = config("JWT_ALGORITHM") + + +def token_response(token: str): + return {"access_token": token} + + +def decode_token(token: str) -> dict: + decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM]) + return decoded_token if decoded_token["expires"] >= time.time() else None diff --git a/api/pyproject.toml b/api/pyproject.toml index bbb51207d..2431cc095 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,6 +13,8 @@ authors = [ readme = "README.md" dependencies = [ + "PyJWT == 2.7.0", + "python-decouple == 3.8", "pydantic == 1.10.7", "psycopg2-binary == 2.9.6", "alembic == 1.10.4", diff --git a/api/routers/projects.py b/api/routers/projects.py index c57f4308f..68f2ac04c 100644 --- a/api/routers/projects.py +++ b/api/routers/projects.py @@ -4,15 +4,16 @@ from models.project import Project from schemas.project import Project as ProjectSchema from db.db_connection import get_db +from auth.auth_bearer import BearerToken router = APIRouter(prefix="/projects", tags=["projects"]) -@router.get("/", response_model=list[ProjectSchema]) +@router.get("/", dependencies=[Depends(BearerToken())], response_model=list[ProjectSchema]) async def get_projects(db: Session = Depends(get_db), skip: int = 0, limit: int = 100): return db.query(Project).offset(skip).limit(limit).all() -@router.get("/{project_id}", response_model=ProjectSchema) +@router.get("/{project_id}", dependencies=[Depends(BearerToken())], response_model=ProjectSchema) async def get_project(project_id: int, db: Session = Depends(get_db)): return db.query(Project).filter(Project.id == project_id).first()