-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from arshad-yaseen/chore/own-api-key
chore: bring own api key
- Loading branch information
Showing
28 changed files
with
814 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import { ServerResponse } from "@/server/utils" | ||
|
||
import { kvdel, kvgetdec, kvsetenc } from "@/lib/kv" | ||
import { getCurrentUser } from "@/lib/session" | ||
|
||
export async function POST(req: Request): Promise<Response> { | ||
try { | ||
const body = await req.json() | ||
const { apiKey } = body | ||
|
||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
if (!apiKey) { | ||
return ServerResponse.badRequest("Missing API key") | ||
} | ||
|
||
await kvsetenc(user.id, "api_key", apiKey) | ||
|
||
return ServerResponse.success({ | ||
body: { message: "API key saved" }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.internalServerError( | ||
error instanceof Error ? error.message : String(error) | ||
) | ||
} | ||
} | ||
|
||
export async function GET(): Promise<Response> { | ||
try { | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
const apiKeyWithModel = (await kvgetdec(user.id, "api_key")) || "" | ||
const apiKey = apiKeyWithModel.split("::")[0] | ||
|
||
if (!apiKey) { | ||
return ServerResponse.notFound("API key not found") | ||
} | ||
|
||
return ServerResponse.success({ | ||
body: { apiKey }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.internalServerError( | ||
error instanceof Error ? error.message : String(error) | ||
) | ||
} | ||
} | ||
|
||
export async function DELETE(): Promise<Response> { | ||
try { | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
await kvdel(user.id, "api_key") | ||
return ServerResponse.success({ | ||
body: { message: "API key deleted" }, | ||
}) | ||
} catch (error) { | ||
return ServerResponse.error( | ||
(error as any).message || String(error), | ||
(error as any).status ?? 500 | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,126 @@ | ||
import { ServerResponse } from "@/server/utils" | ||
import { OpenAIBody } from "@/types" | ||
import { isCorrectApiKey } from "@/utils/openai" | ||
import { OpenAIStream, StreamingTextResponse } from "ai" | ||
import OpenAI from "openai" | ||
|
||
import { env } from "@/env.mjs" | ||
import { kvget, kvset } from "@/lib/kv" | ||
import { openai } from "@/lib/openai" | ||
import { models } from "@/config/ai" | ||
import { free_credits } from "@/config/subscriptions" | ||
import { kvget, kvgetdec, kvset } from "@/lib/kv" | ||
import { getCurrentUser } from "@/lib/session" | ||
|
||
if (!env.OPENAI_API_KEY) { | ||
throw new Error("Missing env var from OpenAI") | ||
} | ||
import { getUserSubscriptionPlan } from "@/lib/subscription" | ||
|
||
export async function POST(req: Request): Promise<Response> { | ||
// Parse the request body. | ||
const body: OpenAIBody = (await req.json()) as OpenAIBody | ||
const { sessionUser: user } = await getCurrentUser() | ||
try { | ||
const body = await req.json() | ||
const { | ||
openai_body, | ||
type = "chat", | ||
api_key, | ||
stream_response = true, | ||
}: { | ||
openai_body: OpenAIBody | ||
type: "chat" | "vision" | ||
api_key: string | ||
stream_response: boolean | ||
} = body | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
const { sessionUser: user } = await getCurrentUser() | ||
|
||
if (!body.messages) { | ||
return ServerResponse.error("Missing messages in request body") | ||
} | ||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
const payload: OpenAI.ChatCompletionCreateParams = { | ||
...body, | ||
model: env.OPENAI_MODEL, | ||
stream: true, | ||
} | ||
// Get the User provided API Key and API key compatible OpenAI model from KV store. | ||
const api_key_with_model_from_kv = await kvgetdec(user.id, "api_key") | ||
const api_key_from_kv = api_key_with_model_from_kv?.split("::")[0] | ||
const model_from_kv = api_key_with_model_from_kv?.split("::")[1] | ||
|
||
// The count of the number of times the user has used the AI. | ||
const user_ai_run_count = await kvget(user?.id!, "ai_run_count") | ||
const { isPro } = await getUserSubscriptionPlan(user?.id!) | ||
|
||
// Check if the user has exceeded the free credits limit. | ||
// user_ai_run_count !== undefined -- if user generating openai chat first time and cookie not set | ||
if ( | ||
user_ai_run_count !== undefined && | ||
Number(user_ai_run_count) >= free_credits && | ||
!isPro && | ||
!api_key_from_kv && | ||
!api_key | ||
) { | ||
return ServerResponse.error( | ||
"You have exceeded the free credits limit, please upgrade to pro plan to continue using the AI.", | ||
402 | ||
) | ||
} | ||
|
||
if (!openai_body) { | ||
return ServerResponse.badRequest("Missing openai_body") | ||
} else if (!openai_body.messages) { | ||
return ServerResponse.badRequest("Missing openai_body.messages") | ||
} | ||
|
||
if (!user?.id) { | ||
return ServerResponse.unauthorized() | ||
} | ||
|
||
const response = await openai.chat.completions.create(payload) | ||
const user_ai_run_count = await kvget(user.id, "ai_run_count") | ||
let OPENAI_API_KEY | ||
|
||
// Count the number of times the user has used the AI. | ||
kvset( | ||
user.id, | ||
"ai_run_count", | ||
user_ai_run_count ? Number(user_ai_run_count) + 1 : 1 | ||
) | ||
if (isPro) { | ||
OPENAI_API_KEY = env.OPENAI_API_KEY | ||
} else if (api_key) { | ||
OPENAI_API_KEY = api_key | ||
} else if (api_key_from_kv) { | ||
OPENAI_API_KEY = api_key_from_kv | ||
} else if (user_ai_run_count === null) { | ||
// if user generating openai chat first time and cookie not set | ||
OPENAI_API_KEY = env.OPENAI_API_KEY | ||
} | ||
|
||
const stream = OpenAIStream(response) | ||
return new StreamingTextResponse(stream) | ||
if (!OPENAI_API_KEY) { | ||
return ServerResponse.unauthorized("Missing OPENAI_API_KEY") | ||
} | ||
|
||
if (!isCorrectApiKey(OPENAI_API_KEY)) { | ||
return ServerResponse.unauthorized("Invalid OPENAI_API_KEY") | ||
} | ||
|
||
console.log(OPENAI_API_KEY) | ||
|
||
const openai = new OpenAI({ apiKey: OPENAI_API_KEY }) | ||
|
||
const payload: OpenAI.ChatCompletionCreateParams = { | ||
...openai_body, | ||
model: model_from_kv | ||
? model_from_kv | ||
: type === "chat" | ||
? models.chat | ||
: models.vision, | ||
stream: stream_response, | ||
} | ||
|
||
const response = await openai.chat.completions.create(payload) | ||
|
||
// Increment the count of the number of times the user has used the AI. | ||
await kvset( | ||
user?.id!, | ||
"ai_run_count", | ||
!user_ai_run_count ? 1 : Number(user_ai_run_count) + 1 | ||
) | ||
|
||
if (stream_response) { | ||
// @ts-ignore | ||
const stream = OpenAIStream(response) | ||
return new StreamingTextResponse(stream) | ||
} else { | ||
return ServerResponse.success({ body: response }) | ||
} | ||
} catch (error) { | ||
return ServerResponse.error( | ||
(error as any).message || String(error), | ||
(error as any).status ?? 500 | ||
) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
943f014
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Successfully deployed to the following URLs:
markdx – ./
markdx-git-main-arshadpro.vercel.app
markdx-arshadpro.vercel.app
www.markdx.site
markdx.vercel.app
markdx.site