-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
f675017
commit d4358d5
Showing
20 changed files
with
762 additions
and
129 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import { ServerResponse } from "@/server/utils" | ||
|
||
import { kvdel, kvgetdec, kvsetenc } from "@/lib/kv" | ||
import { getCurrentUser } from "@/lib/session" | ||
|
||
export async function POST(req: Request): Promise<Response> { | ||
try { | ||
const body = await req.json() | ||
const { apiKey } = body | ||
|
||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
if (!apiKey) { | ||
return ServerResponse.badRequest("Missing API key") | ||
} | ||
|
||
await kvsetenc(user.id, "api_key", apiKey) | ||
|
||
return ServerResponse.success({ | ||
body: { message: "API key saved" }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.internalServerError( | ||
error instanceof Error ? error.message : String(error) | ||
) | ||
} | ||
} | ||
|
||
export async function GET(): Promise<Response> { | ||
try { | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
const apiKeyWithModel = (await kvgetdec(user.id, "api_key")) || "" | ||
const apiKey = apiKeyWithModel.split("::")[0] | ||
|
||
if (!apiKey) { | ||
return ServerResponse.notFound("API key not found") | ||
} | ||
|
||
return ServerResponse.success({ | ||
body: { apiKey }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.internalServerError( | ||
error instanceof Error ? error.message : String(error) | ||
) | ||
} | ||
} | ||
|
||
export async function DELETE(): Promise<Response> { | ||
try { | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
await kvdel(user.id, "api_key") | ||
return ServerResponse.success({ | ||
body: { message: "API key deleted" }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.error( | ||
(error as any).message || String(error), | ||
(error as any).status ?? 500 | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,108 @@ | ||
import { ServerResponse } from "@/server/utils" | ||
import { OpenAIBody } from "@/types" | ||
import { isCorrectApiKey } from "@/utils/openai" | ||
import { OpenAIStream, StreamingTextResponse } from "ai" | ||
import OpenAI from "openai" | ||
|
||
import { env } from "@/env.mjs" | ||
import { kvget, kvset } from "@/lib/kv" | ||
import { openai } from "@/lib/openai" | ||
import { models } from "@/config/ai" | ||
import { free_credits } from "@/config/subscriptions" | ||
import { kvget, kvgetdec } from "@/lib/kv" | ||
import { getCurrentUser } from "@/lib/session" | ||
|
||
if (!env.OPENAI_API_KEY) { | ||
throw new Error("Missing env var from OpenAI") | ||
} | ||
import { getUserSubscriptionPlan } from "@/lib/subscription" | ||
|
||
export async function POST(req: Request): Promise<Response> { | ||
// Parse the request body. | ||
const body: OpenAIBody = (await req.json()) as OpenAIBody | ||
const { sessionUser: user } = await getCurrentUser() | ||
try { | ||
const body = await req.json() | ||
const { | ||
openai_body, | ||
type = "chat", | ||
api_key, | ||
stream_response = true, | ||
}: { | ||
openai_body: OpenAIBody | ||
type: "chat" | "vision" | ||
api_key: string | ||
stream_response: boolean | ||
} = body | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!body.messages) { | ||
return ServerResponse.error("Missing messages in request body") | ||
} | ||
// The count of the number of times the user has used the AI. | ||
const user_ai_run_count = await kvget(user?.id!, "ai_run_count") | ||
const { isPro } = await getUserSubscriptionPlan(user?.id!) | ||
|
||
const payload: OpenAI.ChatCompletionCreateParams = { | ||
...body, | ||
model: env.OPENAI_MODEL, | ||
stream: true, | ||
} | ||
// Check if the user has exceeded the free credits limit. | ||
// user_ai_run_count !== undefined -- if user generating openai chat first time and cookie not set | ||
if ( | ||
user_ai_run_count !== undefined && | ||
Number(user_ai_run_count) > free_credits && | ||
!isPro | ||
) { | ||
return ServerResponse.error( | ||
"You have exceeded the free credits limit, please upgrade to pro plan to continue using the AI.", | ||
402 | ||
) | ||
} | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
if (!openai_body) { | ||
return ServerResponse.badRequest("Missing openai_body") | ||
} else if (!openai_body.messages) { | ||
return ServerResponse.badRequest("Missing openai_body.messages") | ||
} | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
const response = await openai.chat.completions.create(payload) | ||
const user_ai_run_count = await kvget(user.id, "ai_run_count") | ||
// Get the User provided API Key and API key compatible OpenAI model from KV store. | ||
const api_key_with_model_from_kv = await kvgetdec(user.id, "api_key") | ||
const api_key_from_kv = api_key_with_model_from_kv?.split("::")[0] | ||
const model_from_kv = api_key_with_model_from_kv?.split("::")[1] | ||
|
||
// Count the number of times the user has used the AI. | ||
kvset( | ||
user.id, | ||
"ai_run_count", | ||
user_ai_run_count ? Number(user_ai_run_count) + 1 : 1 | ||
) | ||
let OPENAI_API_KEY | ||
|
||
const stream = OpenAIStream(response) | ||
return new StreamingTextResponse(stream) | ||
if (api_key) { | ||
OPENAI_API_KEY = api_key | ||
} else if (api_key_from_kv) { | ||
OPENAI_API_KEY = api_key_from_kv | ||
} else if (env.OPENAI_API_KEY && isPro) { | ||
OPENAI_API_KEY = env.OPENAI_API_KEY | ||
} | ||
|
||
if (!OPENAI_API_KEY) { | ||
return ServerResponse.unauthorized("Missing OPENAI_API_KEY") | ||
} | ||
|
||
if (!isCorrectApiKey(OPENAI_API_KEY)) { | ||
return ServerResponse.unauthorized("Invalid OPENAI_API_KEY") | ||
} | ||
|
||
const openai = new OpenAI({ apiKey: OPENAI_API_KEY }) | ||
|
||
const payload: OpenAI.ChatCompletionCreateParams = { | ||
...openai_body, | ||
model: model_from_kv || type === "chat" ? models.chat : models.vision, | ||
stream: stream_response, | ||
} | ||
|
||
const response = await openai.chat.completions.create(payload) | ||
|
||
if (stream_response) { | ||
// @ts-ignore | ||
const stream = OpenAIStream(response) | ||
return new StreamingTextResponse(stream) | ||
} else { | ||
return ServerResponse.success({ body: response }) | ||
} | ||
} catch (error) { | ||
return ServerResponse.error( | ||
(error as any).message || String(error), | ||
(error as any).status ?? 500 | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.