Skip to content

Commit

Permalink
Add JWT auth support
Browse files Browse the repository at this point in the history
Add support to receive a Bearer token in the API requests and
decode it.
  • Loading branch information
anarute committed Jul 11, 2023
1 parent ed0c0e3 commit db97ead
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ EMPLOYEES_GROUP="staff"

USE_EXTERNAL_AUTHENTICATION=false
EXTERNAL_AUTHENTICATION_USER_HEADER=""
JWT_ALGORITHM=
JWT_SECRET=

CALENDAR_URL=""
CALENDAR_ID=""
Expand Down
24 changes: 24 additions & 0 deletions api/auth/auth_bearer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from .auth_handler import decode_token


class BearerToken(HTTPBearer):
def __init__(self, auto_error: bool = True):
super(BearerToken, self).__init__(auto_error=auto_error)

async def __call__(self, request: Request):
credentials: HTTPAuthorizationCredentials = await super(BearerToken, self).__call__(request)
if credentials:
if not credentials.scheme == "Bearer":
raise HTTPException(status_code=401, detail="Invalid authentication scheme.")
if not self.verify_token(credentials.credentials):
raise HTTPException(status_code=401, detail="Invalid or expired token.")
return credentials.credentials
else:
raise HTTPException(status_code=401, detail="Invalid authorization code.")

def verify_token(self, token: str) -> bool:
payload = decode_token(token)
return True if payload else False
16 changes: 16 additions & 0 deletions api/auth/auth_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import time
import jwt
from decouple import config


JWT_SECRET = config("JWT_SECRET")
JWT_ALGORITHM = config("JWT_ALGORITHM")


def token_response(token: str):
return {"access_token": token}


def decode_token(token: str) -> dict:
decoded_token = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
return decoded_token if decoded_token["expires"] >= time.time() else None
2 changes: 2 additions & 0 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ authors = [
readme = "README.md"

dependencies = [
"PyJWT == 2.7.0",
"python-decouple == 3.8",
"pydantic == 1.10.7",
"psycopg2-binary == 2.9.6",
"alembic == 1.10.4",
Expand Down
5 changes: 3 additions & 2 deletions api/routers/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
from models.project import Project
from schemas.project import Project as ProjectSchema
from db.db_connection import get_db
from auth.auth_bearer import BearerToken

router = APIRouter(prefix="/projects", tags=["projects"])


@router.get("/", response_model=list[ProjectSchema])
@router.get("/", dependencies=[Depends(BearerToken())], response_model=list[ProjectSchema])
async def get_projects(db: Session = Depends(get_db), skip: int = 0, limit: int = 100):
return db.query(Project).offset(skip).limit(limit).all()


@router.get("/{project_id}", response_model=ProjectSchema)
@router.get("/{project_id}", dependencies=[Depends(BearerToken())], response_model=ProjectSchema)
async def get_project(project_id: int, db: Session = Depends(get_db)):
return db.query(Project).filter(Project.id == project_id).first()

0 comments on commit db97ead

Please sign in to comment.