From 1367d9575060f7c2e8abdb1c2e009862020add9a Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 21 Nov 2024 20:38:50 -0800 Subject: [PATCH] refac: enforce api on all routes --- main.py | 19 +++++++++++-------- utils/pipelines/auth.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 33c04997..cff33353 100644 --- a/main.py +++ b/main.py @@ -106,28 +106,31 @@ def get_all_pipelines(): return pipelines + def parse_frontmatter(content): frontmatter = {} - for line in content.split('\n'): - if ':' in line: - key, value = line.split(':', 1) + for line in content.split("\n"): + if ":" in line: + key, value = line.split(":", 1) frontmatter[key.strip().lower()] = value.strip() return frontmatter + def install_frontmatter_requirements(requirements): if requirements: - req_list = [req.strip() for req in requirements.split(',')] + req_list = [req.strip() for req in requirements.split(",")] for req in req_list: print(f"Installing requirement: {req}") subprocess.check_call([sys.executable, "-m", "pip", "install", req]) else: print("No requirements found in frontmatter.") + async def load_module_from_path(module_name, module_path): try: # Read the module content - with open(module_path, 'r') as file: + with open(module_path, "r") as file: content = file.read() # Parse frontmatter @@ -139,8 +142,8 @@ async def load_module_from_path(module_name, module_path): frontmatter = parse_frontmatter(frontmatter_content) # Install requirements if specified - if 'requirements' in frontmatter: - install_frontmatter_requirements(frontmatter['requirements']) + if "requirements" in frontmatter: + install_frontmatter_requirements(frontmatter["requirements"]) # Load the module spec = importlib.util.spec_from_file_location(module_name, module_path) @@ -277,7 +280,7 @@ async def check_url(request: Request, call_next): @app.get("/v1/models") @app.get("/models") -async def get_models(): +async def get_models(user: str = Depends(get_current_user)): """ Returns the available pipelines """ diff --git a/utils/pipelines/auth.py b/utils/pipelines/auth.py index df03ad87..2b098207 100644 --- a/utils/pipelines/auth.py +++ b/utils/pipelines/auth.py @@ -1,6 +1,7 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi import HTTPException, status, Depends + from pydantic import BaseModel from typing import Union, Optional @@ -14,6 +15,10 @@ import requests import uuid + +from config import API_KEY, PIPELINES_DIR + + SESSION_SECRET = os.getenv("SESSION_SECRET", " ") ALGORITHM = "HS256" @@ -62,4 +67,11 @@ def get_current_user( credentials: HTTPAuthorizationCredentials = Depends(bearer_security), ) -> Optional[dict]: token = credentials.credentials + + if token != API_KEY: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + return token