Skip to content

Commit

Permalink
feat: simpler jwt exp (#12)
Browse files Browse the repository at this point in the history
* docs: added type-checking method docstring
* feat: simpler and more flexible make_jwt utility function
The function now accepts a delta value in hours, with the possibility of also inserting a base time value for such a sum.

* docs: update encode_jwt_token syntax on readme example
  • Loading branch information
ericmiguel committed Jan 29, 2024
1 parent b403f23 commit 6a486df
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 35 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ import missil
from fastapi import FastAPI
from fastapi import Response

from datetime import datetime
from datetime import timezone
from datetime import timedelta

app = FastAPI()

TOKEN_KEY = "Authorization"
Expand All @@ -71,8 +67,8 @@ def set_cookies(response: Response) -> None:
"it": missil.WRITE,
}

token_expiration = datetime.now(timezone.utc) + timedelta(hours=8)
token = missil.encode_jwt_token(sample_user_privileges, SECRET_KEY, token_expiration)
token_expiration_in_hours = 8
token = missil.encode_jwt_token(claims, SECRET_KEY, token_expiration_in_hours)

response.set_cookie(
key=TOKEN_KEY,
Expand Down
19 changes: 14 additions & 5 deletions missil/jwt_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""JWT token utilities."""

from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any

from fastapi import status
Expand Down Expand Up @@ -62,7 +64,11 @@ def decode_jwt_token(


def encode_jwt_token(
claims: dict[str, Any], secret: str, exp: datetime, algorithm: str = "HS256"
claims: dict[str, Any],
secret: str,
exp: int,
base: datetime = datetime.now(timezone.utc),
algorithm: str = "HS256",
) -> str:
"""
Create a JWT token.
Expand All @@ -73,16 +79,19 @@ def encode_jwt_token(
Token user data.
secret : str
Secret key to sign the token.
exp : datetime
Token expiration datetime.
exp : int, optional
Token expiration in hours.
base : datetime, optional
Token expiration base datetime, where the final datetime is given by
base + exp, by default datetime.now(timezone.utc)
algorithm : str, optional
Encode algorithm, by default "HS256"
Returns
-------
str
Encoded JWT token.
_description_
"""
to_encode = claims.copy()
to_encode.update({"exp": exp})
to_encode.update({"exp": base + timedelta(exp)})
return jwt.encode(to_encode, key=secret, algorithm=algorithm)
8 changes: 2 additions & 6 deletions sample/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
"""Missil sample usage."""

from datetime import datetime
from datetime import timedelta
from datetime import timezone

from fastapi import FastAPI
from fastapi import Response

Expand Down Expand Up @@ -45,8 +41,8 @@ def set_cookies(response: Response) -> dict[str, str]:
},
}

token_expiration = datetime.now(timezone.utc) + timedelta(hours=8)
token = missil.encode_jwt_token(claims, SECRET_KEY, token_expiration)
token_expiration_in_hours = 8
token = missil.encode_jwt_token(claims, SECRET_KEY, token_expiration_in_hours)

response.set_cookie(
key=TOKEN_KEY,
Expand Down
67 changes: 49 additions & 18 deletions tests/test_jwt_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging

from datetime import datetime
from datetime import timedelta
from datetime import timezone

import pytest

Expand Down Expand Up @@ -47,13 +49,18 @@ def fake_claims():


@pytest.fixture(scope="module")
def token_valid_expiration():
return datetime(2200, 1, 1, 0, 0, 0, 0)
def timedelta_token_expiration():
return 8 # hours


@pytest.fixture(scope="module")
def token_expired_datetime():
return datetime(1900, 1, 1, 0, 0, 0, 0)
def token_valid_base_expiration():
return datetime(2200, 1, 1, 0, 0, 0, 0, timezone.utc)


@pytest.fixture(scope="module")
def token_expired_base_datetime():
return datetime(1900, 1, 1, 0, 0, 0, 0, timezone.utc)


@pytest.fixture(scope="module")
Expand All @@ -62,9 +69,13 @@ def token_invalid_expiration():


@pytest.fixture(scope="module")
def encoded_jwt_token(claims, secret_key, token_valid_expiration):
def encoded_jwt_token(
claims, secret_key, timedelta_token_expiration, token_valid_base_expiration
):
to_encode = claims.copy()
to_encode.update({"exp": token_valid_expiration})
to_encode.update(
{"exp": token_valid_base_expiration + timedelta(timedelta_token_expiration)}
)
return jwt.encode(to_encode, secret_key, "HS256")


Expand All @@ -76,49 +87,69 @@ def encoded_invalid_claims_jwt_token(claims, secret_key, token_invalid_expiratio


@pytest.fixture(scope="module")
def encoded_expired_jwt_token(claims, secret_key, token_expired_datetime):
def encoded_expired_jwt_token(claims, secret_key, token_expired_base_datetime):
to_encode = claims.copy()
to_encode.update({"exp": token_expired_datetime})
to_encode.update({"exp": token_expired_base_datetime})
return jwt.encode(to_encode, secret_key, "HS256")


@pytest.fixture(scope="module")
def encoded_invalid_jwt_token(claims, secret_key, token_expired_datetime):
def encoded_invalid_jwt_token(claims, secret_key, token_expired_base_datetime):
to_encode = claims.copy()
to_encode.update({"exp": token_expired_datetime})
to_encode.update({"exp": token_expired_base_datetime})
encoded = jwt.encode(to_encode, secret_key, "HS256")
invalidated_token = encoded[:39] + encoded[40:]
return invalidated_token


def test_encode_jwt_token(
claims, secret_key, token_valid_expiration, encoded_jwt_token
claims,
secret_key,
timedelta_token_expiration,
token_valid_base_expiration,
encoded_jwt_token,
):
result = jwt_utilities.encode_jwt_token(claims, secret_key, token_valid_expiration)
result = jwt_utilities.encode_jwt_token(
claims, secret_key, timedelta_token_expiration, token_valid_base_expiration
)
assert result == encoded_jwt_token


def test_encode_expired_jwt_token(
claims, secret_key, token_expired_datetime, encoded_jwt_token
claims,
secret_key,
timedelta_token_expiration,
token_expired_base_datetime,
encoded_jwt_token,
):
result = jwt_utilities.encode_jwt_token(claims, secret_key, token_expired_datetime)
result = jwt_utilities.encode_jwt_token(
claims, secret_key, timedelta_token_expiration, token_expired_base_datetime
)
assert result != encoded_jwt_token


def test_encode_fake_claim_jwt_token(
fake_claims, secret_key, token_valid_expiration, encoded_jwt_token
fake_claims,
secret_key,
timedelta_token_expiration,
token_valid_base_expiration,
encoded_jwt_token,
):
result = jwt_utilities.encode_jwt_token(
fake_claims, secret_key, token_valid_expiration
fake_claims, secret_key, timedelta_token_expiration, token_valid_base_expiration
)
assert result != encoded_jwt_token


def test_encode_jwt_token_fake_key(
claims, fake_secret_key, token_valid_expiration, encoded_jwt_token
claims,
fake_secret_key,
timedelta_token_expiration,
token_valid_base_expiration,
encoded_jwt_token,
):
result = jwt_utilities.encode_jwt_token(
claims, fake_secret_key, token_valid_expiration
claims, fake_secret_key, timedelta_token_expiration, token_valid_base_expiration
)
assert result != encoded_jwt_token

Expand Down

0 comments on commit 6a486df

Please sign in to comment.