diff --git a/src/app/(main)/repos/[id]/@menu/default.tsx b/src/app/(main)/repos/[id]/@menu/default.tsx index 2d25320f15ea..2c96eec9fd7d 100644 --- a/src/app/(main)/repos/[id]/@menu/default.tsx +++ b/src/app/(main)/repos/[id]/@menu/default.tsx @@ -1,6 +1,7 @@ import { notFound } from 'next/navigation'; import { Flexbox } from 'react-layout-kit'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import Head from './Head'; @@ -14,7 +15,7 @@ type Props = { params: Params }; const MenuPage = async ({ params }: Props) => { const id = params.id; - const item = await KnowledgeBaseModel.findById(params.id); + const item = await KnowledgeBaseModel.findById(serverDB, params.id); if (!item) return notFound(); diff --git a/src/app/(main)/repos/[id]/page.tsx b/src/app/(main)/repos/[id]/page.tsx index 24662ac70071..fce88618a3df 100644 --- a/src/app/(main)/repos/[id]/page.tsx +++ b/src/app/(main)/repos/[id]/page.tsx @@ -1,5 +1,6 @@ import { redirect } from 'next/navigation'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import FileManager from '@/features/FileManager'; @@ -10,7 +11,7 @@ interface Params { type Props = { params: Params }; export default async ({ params }: Props) => { - const item = await KnowledgeBaseModel.findById(params.id); + const item = await KnowledgeBaseModel.findById(serverDB, params.id); if (!item) return redirect('/repos'); diff --git a/src/config/db.ts b/src/config/db.ts index 6d6361c22f96..4c02c3bf0f43 100644 --- a/src/config/db.ts +++ b/src/config/db.ts @@ -11,8 +11,6 @@ export const getServerDBConfig = () => { DATABASE_TEST_URL: process.env.DATABASE_TEST_URL, DATABASE_URL: process.env.DATABASE_URL, - DISABLE_REMOVE_GLOBAL_FILE: process.env.DISABLE_REMOVE_GLOBAL_FILE === '1', - KEY_VAULTS_SECRET: process.env.KEY_VAULTS_SECRET, NEXT_PUBLIC_ENABLED_SERVER_SERVICE: process.env.NEXT_PUBLIC_SERVICE_MODE === 'server', @@ -22,8 +20,6 @@ export const getServerDBConfig = () => { DATABASE_TEST_URL: z.string().optional(), DATABASE_URL: z.string().optional(), - DISABLE_REMOVE_GLOBAL_FILE: z.boolean().optional(), - KEY_VAULTS_SECRET: z.string().optional(), }, }); diff --git a/src/database/server/models/__tests__/_test_template.ts b/src/database/server/models/__tests__/_test_template.ts index 3a368d6f8cbe..13e77cfbd815 100644 --- a/src/database/server/models/__tests__/_test_template.ts +++ b/src/database/server/models/__tests__/_test_template.ts @@ -16,7 +16,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'session-group-model-test-user-id'; -const sessionGroupModel = new SessionGroupModel(userId); +const sessionGroupModel = new SessionGroupModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -74,7 +74,7 @@ describe('SessionGroupModel', () => { await sessionGroupModel.create({ name: 'Test Group 1' }); await sessionGroupModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new SessionGroupModel('user2'); + const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await sessionGroupModel.deleteAll(); diff --git a/src/database/server/models/__tests__/agent.test.ts b/src/database/server/models/__tests__/agent.test.ts index f683106749c1..7c9b69c3c857 100644 --- a/src/database/server/models/__tests__/agent.test.ts +++ b/src/database/server/models/__tests__/agent.test.ts @@ -25,7 +25,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'agent-model-test-user-id'; -const agentModel = new AgentModel(userId); +const agentModel = new AgentModel(serverDB, userId); const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' }; const fileList = [ diff --git a/src/database/server/models/__tests__/asyncTask.test.ts b/src/database/server/models/__tests__/asyncTask.test.ts index 3d587af9c8b4..fc21d3002d5e 100644 --- a/src/database/server/models/__tests__/asyncTask.test.ts +++ b/src/database/server/models/__tests__/asyncTask.test.ts @@ -17,7 +17,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'async-task-model-test-user-id'; -const asyncTaskModel = new AsyncTaskModel(userId); +const asyncTaskModel = new AsyncTaskModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); diff --git a/src/database/server/models/__tests__/chunk.test.ts b/src/database/server/models/__tests__/chunk.test.ts index 4e94c35a55a2..81582f9b6ed1 100644 --- a/src/database/server/models/__tests__/chunk.test.ts +++ b/src/database/server/models/__tests__/chunk.test.ts @@ -24,7 +24,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'chunk-model-test-user-id'; -const chunkModel = new ChunkModel(userId); +const chunkModel = new ChunkModel(serverDB, userId); const sharedFileList = [ { id: '1', diff --git a/src/database/server/models/__tests__/file.test.ts b/src/database/server/models/__tests__/file.test.ts index b02a89bd4ebb..6847600e08a1 100644 --- a/src/database/server/models/__tests__/file.test.ts +++ b/src/database/server/models/__tests__/file.test.ts @@ -38,7 +38,7 @@ vi.mock('@/config/db', async () => ({ })); const userId = 'file-model-test-user-id'; -const fileModel = new FileModel(userId); +const fileModel = new FileModel(serverDB, userId); const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' }; beforeEach(async () => { diff --git a/src/database/server/models/__tests__/knowledgeBase.test.ts b/src/database/server/models/__tests__/knowledgeBase.test.ts index 53fa49eeec0d..28b1c06656ea 100644 --- a/src/database/server/models/__tests__/knowledgeBase.test.ts +++ b/src/database/server/models/__tests__/knowledgeBase.test.ts @@ -24,7 +24,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'session-group-model-test-user-id'; -const knowledgeBaseModel = new KnowledgeBaseModel(userId); +const knowledgeBaseModel = new KnowledgeBaseModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -82,7 +82,7 @@ describe('KnowledgeBaseModel', () => { await knowledgeBaseModel.create({ name: 'Test Group 1' }); await knowledgeBaseModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new KnowledgeBaseModel('user2'); + const anotherSessionGroupModel = new KnowledgeBaseModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await knowledgeBaseModel.deleteAll(); @@ -235,7 +235,7 @@ describe('KnowledgeBaseModel', () => { it('should find a knowledge base by id without user restriction', async () => { const { id } = await knowledgeBaseModel.create({ name: 'Test Group' }); - const group = await KnowledgeBaseModel.findById(id); + const group = await KnowledgeBaseModel.findById(serverDB, id); expect(group).toMatchObject({ id, name: 'Test Group', @@ -244,10 +244,10 @@ describe('KnowledgeBaseModel', () => { }); it('should find a knowledge base created by another user', async () => { - const anotherKnowledgeBaseModel = new KnowledgeBaseModel('user2'); + const anotherKnowledgeBaseModel = new KnowledgeBaseModel(serverDB, 'user2'); const { id } = await anotherKnowledgeBaseModel.create({ name: 'Another User Group' }); - const group = await KnowledgeBaseModel.findById(id); + const group = await KnowledgeBaseModel.findById(serverDB, id); expect(group).toMatchObject({ id, name: 'Another User Group', diff --git a/src/database/server/models/__tests__/message.test.ts b/src/database/server/models/__tests__/message.test.ts index 9079f8bb6a2a..0e9dec4020a8 100644 --- a/src/database/server/models/__tests__/message.test.ts +++ b/src/database/server/models/__tests__/message.test.ts @@ -25,7 +25,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'message-db'; -const messageModel = new MessageModel(userId); +const messageModel = new MessageModel(serverDB, userId); beforeEach(async () => { // 在每个测试用例之前,清空表 diff --git a/src/database/server/models/__tests__/plugin.test.ts b/src/database/server/models/__tests__/plugin.test.ts index 821864238533..75c06a626f0a 100644 --- a/src/database/server/models/__tests__/plugin.test.ts +++ b/src/database/server/models/__tests__/plugin.test.ts @@ -15,7 +15,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'plugin-db'; -const pluginModel = new PluginModel(userId); +const pluginModel = new PluginModel(serverDB, userId); beforeEach(async () => { await serverDB.transaction(async (trx) => { diff --git a/src/database/server/models/__tests__/session.test.ts b/src/database/server/models/__tests__/session.test.ts index 0a91adb6f7f8..88a39033bc2a 100644 --- a/src/database/server/models/__tests__/session.test.ts +++ b/src/database/server/models/__tests__/session.test.ts @@ -26,7 +26,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'session-user'; -const sessionModel = new SessionModel(userId); +const sessionModel = new SessionModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -259,7 +259,13 @@ describe('SessionModel', () => { ]); await serverDB.insert(agents).values([ - { id: 'agent-1', userId, model: 'gpt-3.5-turbo', title: 'Agent 1', description: 'Description with Keyword' }, + { + id: 'agent-1', + userId, + model: 'gpt-3.5-turbo', + title: 'Agent 1', + description: 'Description with Keyword', + }, { id: 'agent-2', userId, model: 'gpt-4', title: 'Agent 2' }, ]); @@ -338,7 +344,7 @@ describe('SessionModel', () => { }); }); - describe.skip('batchCreate', () => { + describe('batchCreate', () => { it('should batch create sessions', async () => { // 调用 batchCreate 方法 const sessions: NewSession[] = [ diff --git a/src/database/server/models/__tests__/sessionGroup.test.ts b/src/database/server/models/__tests__/sessionGroup.test.ts index 2dd58c2c755e..6072ff9613a6 100644 --- a/src/database/server/models/__tests__/sessionGroup.test.ts +++ b/src/database/server/models/__tests__/sessionGroup.test.ts @@ -17,7 +17,7 @@ vi.mock('@/database/server/core/db', async () => ({ })); const userId = 'session-group-model-test-user-id'; -const sessionGroupModel = new SessionGroupModel(userId); +const sessionGroupModel = new SessionGroupModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -75,7 +75,7 @@ describe('SessionGroupModel', () => { await sessionGroupModel.create({ name: 'Test Group 1' }); await sessionGroupModel.create({ name: 'Test Group 333' }); - const anotherSessionGroupModel = new SessionGroupModel('user2'); + const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2'); await anotherSessionGroupModel.create({ name: 'Test Group 2' }); await sessionGroupModel.deleteAll(); diff --git a/src/database/server/models/__tests__/user.test.ts b/src/database/server/models/__tests__/user.test.ts index 60c945d9fe20..36941b5744dc 100644 --- a/src/database/server/models/__tests__/user.test.ts +++ b/src/database/server/models/__tests__/user.test.ts @@ -21,7 +21,7 @@ vi.mock('@/database/server/core/db', async () => ({ const userId = 'user-db'; const userEmail = 'user@example.com'; -const userModel = new UserModel(); +const userModel = new UserModel(serverDB, userId); beforeEach(async () => { await serverDB.delete(users); @@ -44,14 +44,14 @@ describe('UserModel', () => { email: 'test@example.com', }; - await UserModel.createUser(params); + await UserModel.createUser(serverDB, params); const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(user).not.toBeNull(); expect(user?.username).toBe('testuser'); expect(user?.email).toBe('test@example.com'); - const sessionModel = new SessionModel(userId); + const sessionModel = new SessionModel(serverDB, userId); const inbox = await sessionModel.findByIdOrSlug(INBOX_SESSION_ID); expect(inbox).not.toBeNull(); }); @@ -61,7 +61,7 @@ describe('UserModel', () => { it('should delete a user', async () => { await serverDB.insert(users).values({ id: userId }); - await UserModel.deleteUser(userId); + await UserModel.deleteUser(serverDB, userId); const user = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(user).toBeUndefined(); @@ -72,7 +72,7 @@ describe('UserModel', () => { it('should find a user by ID', async () => { await serverDB.insert(users).values({ id: userId, username: 'testuser' }); - const user = await UserModel.findById(userId); + const user = await UserModel.findById(serverDB, userId); expect(user).not.toBeNull(); expect(user?.id).toBe(userId); @@ -84,7 +84,7 @@ describe('UserModel', () => { it('should find a user by email', async () => { await serverDB.insert(users).values({ id: userId, email: userEmail }); - const user = await UserModel.findByEmail(userEmail); + const user = await UserModel.findByEmail(serverDB, userEmail); expect(user).not.toBeNull(); expect(user?.id).toBe(userId); @@ -107,7 +107,7 @@ describe('UserModel', () => { keyVaults: encryptedKeyVaults, }); - const state = await userModel.getUserState(userId); + const state = await userModel.getUserState(); expect(state.userId).toBe(userId); expect(state.preference).toEqual(preference); @@ -115,7 +115,9 @@ describe('UserModel', () => { }); it('should throw an error if user not found', async () => { - await expect(userModel.getUserState('invalid-user-id')).rejects.toThrow('user not found'); + const userModel = new UserModel(serverDB, 'invalid-user-id'); + + await expect(userModel.getUserState()).rejects.toThrow('user not found'); }); }); @@ -123,7 +125,7 @@ describe('UserModel', () => { it('should update user fields', async () => { await serverDB.insert(users).values({ id: userId, username: 'oldname' }); - await userModel.updateUser(userId, { username: 'newname' }); + await userModel.updateUser({ username: 'newname' }); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId), @@ -137,7 +139,7 @@ describe('UserModel', () => { await serverDB.insert(users).values({ id: userId }); await serverDB.insert(userSettings).values({ id: userId }); - await userModel.deleteSetting(userId); + await userModel.deleteSetting(); const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -155,7 +157,7 @@ describe('UserModel', () => { } as UserSettings; await serverDB.insert(users).values({ id: userId }); - await userModel.updateSetting(userId, settings); + await userModel.updateSetting(settings); const updatedSettings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -178,7 +180,7 @@ describe('UserModel', () => { const newSettings = { general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' }, } as UserSettings; - await userModel.updateSetting(userId, newSettings); + await userModel.updateSetting(newSettings); const updatedSettings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, userId), @@ -195,7 +197,7 @@ describe('UserModel', () => { const newPreference: Partial = { guide: { topic: true, moveSettingsToAvatar: true }, }; - await userModel.updatePreference(userId, newPreference); + await userModel.updatePreference(newPreference); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(updatedUser?.preference).toEqual({ ...preference, ...newPreference }); @@ -212,7 +214,7 @@ describe('UserModel', () => { moveSettingsToAvatar: true, uploadFileInKnowledgeBase: true, }; - await userModel.updateGuide(userId, newGuide); + await userModel.updateGuide(newGuide); const updatedUser = await serverDB.query.users.findFirst({ where: eq(users.id, userId) }); expect(updatedUser?.preference).toEqual({ ...preference, guide: newGuide }); diff --git a/src/database/server/models/_template.ts b/src/database/server/models/_template.ts index 3699bf87a346..272ec594b12f 100644 --- a/src/database/server/models/_template.ts +++ b/src/database/server/models/_template.ts @@ -1,19 +1,21 @@ import { eq } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { NewSessionGroup, SessionGroupItem, sessionGroups } from '../schemas/lobechat'; export class TemplateModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: NewSessionGroup) => { - const [result] = await serverDB + const [result] = await this.db .insert(sessionGroups) .values({ ...params, userId: this.userId }) .returning(); @@ -22,30 +24,30 @@ export class TemplateModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(sessionGroups) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); + return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); }; query = async () => { - return serverDB.query.sessionGroups.findMany({ + return this.db.query.sessionGroups.findMany({ orderBy: [desc(sessionGroups.updatedAt)], where: eq(sessionGroups.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.sessionGroups.findFirst({ + return this.db.query.sessionGroups.findFirst({ where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); diff --git a/src/database/server/models/agent.ts b/src/database/server/models/agent.ts index a65eb21cdfb8..e90abf0bfe73 100644 --- a/src/database/server/models/agent.ts +++ b/src/database/server/models/agent.ts @@ -1,7 +1,8 @@ import { inArray } from 'drizzle-orm'; import { and, desc, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; + import { agents, agentsFiles, @@ -9,16 +10,19 @@ import { agentsToSessions, files, knowledgeBases, -} from '@/database/server/schemas/lobechat'; +} from '../schemas/lobechat'; export class AgentModel { private userId: string; - constructor(userId: string) { + private db: LobeChatDatabase; + + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } async getAgentConfigById(id: string) { - const agent = await serverDB.query.agents.findFirst({ where: eq(agents.id, id) }); + const agent = await this.db.query.agents.findFirst({ where: eq(agents.id, id) }); const knowledge = await this.getAgentAssignedKnowledge(id); @@ -26,14 +30,14 @@ export class AgentModel { } async getAgentAssignedKnowledge(id: string) { - const knowledgeBaseResult = await serverDB + const knowledgeBaseResult = await this.db .select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases }) .from(agentsKnowledgeBases) .where(eq(agentsKnowledgeBases.agentId, id)) .orderBy(desc(agentsKnowledgeBases.createdAt)) .leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId)); - const fileResult = await serverDB + const fileResult = await this.db .select({ enabled: agentsFiles.enabled, files }) .from(agentsFiles) .where(eq(agentsFiles.agentId, id)) @@ -56,7 +60,7 @@ export class AgentModel { * Find agent by session id */ async findBySessionId(sessionId: string) { - const item = await serverDB.query.agentsToSessions.findFirst({ + const item = await this.db.query.agentsToSessions.findFirst({ where: eq(agentsToSessions.sessionId, sessionId), }); if (!item) return; @@ -71,7 +75,7 @@ export class AgentModel { knowledgeBaseId: string, enabled: boolean = true, ) => { - return serverDB + return this.db .insert(agentsKnowledgeBases) .values({ agentId, @@ -83,7 +87,7 @@ export class AgentModel { }; deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => { - return serverDB + return this.db .delete(agentsKnowledgeBases) .where( and( @@ -96,7 +100,7 @@ export class AgentModel { }; toggleKnowledgeBase = async (agentId: string, knowledgeBaseId: string, enabled?: boolean) => { - return serverDB + return this.db .update(agentsKnowledgeBases) .set({ enabled }) .where( @@ -111,7 +115,7 @@ export class AgentModel { createAgentFiles = async (agentId: string, fileIds: string[], enabled: boolean = true) => { // Exclude the fileIds that already exist in agentsFiles, and then insert them - const existingFiles = await serverDB + const existingFiles = await this.db .select({ id: agentsFiles.fileId }) .from(agentsFiles) .where( @@ -128,7 +132,7 @@ export class AgentModel { if (needToInsertFileIds.length === 0) return; - return serverDB + return this.db .insert(agentsFiles) .values( needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })), @@ -137,7 +141,7 @@ export class AgentModel { }; deleteAgentFile = async (agentId: string, fileId: string) => { - return serverDB + return this.db .delete(agentsFiles) .where( and( @@ -150,7 +154,7 @@ export class AgentModel { }; toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => { - return serverDB + return this.db .update(agentsFiles) .set({ enabled }) .where( diff --git a/src/database/server/models/asyncTask.ts b/src/database/server/models/asyncTask.ts index 94235505b3cb..f95aede4248b 100644 --- a/src/database/server/models/asyncTask.ts +++ b/src/database/server/models/asyncTask.ts @@ -1,7 +1,7 @@ import { eq, inArray, lt } from 'drizzle-orm'; import { and } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { AsyncTaskError, AsyncTaskErrorType, @@ -16,13 +16,15 @@ export const ASYNC_TASK_TIMEOUT = 298 * 1000; export class AsyncTaskModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: Pick): Promise => { - const data = await serverDB + const data = await this.db .insert(asyncTasks) .values({ ...params, userId: this.userId }) .returning(); @@ -31,17 +33,17 @@ export class AsyncTaskModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(asyncTasks) .where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId))); }; findById = async (id: string) => { - return serverDB.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) }); + return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) }); }; update(taskId: string, value: Partial) { - return serverDB + return this.db .update(asyncTasks) .set({ ...value, updatedAt: new Date() }) .where(and(eq(asyncTasks.id, taskId))); @@ -52,7 +54,7 @@ export class AsyncTaskModel { if (taskIds.length > 0) { await this.checkTimeoutTasks(taskIds); - chunkTasks = await serverDB.query.asyncTasks.findMany({ + chunkTasks = await this.db.query.asyncTasks.findMany({ where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)), }); } @@ -64,7 +66,7 @@ export class AsyncTaskModel { * make the task status to be `error` if the task is not finished in 20 seconds */ async checkTimeoutTasks(ids: string[]) { - const tasks = await serverDB + const tasks = await this.db .select({ id: asyncTasks.id }) .from(asyncTasks) .where( @@ -76,7 +78,7 @@ export class AsyncTaskModel { ); if (tasks.length > 0) { - await serverDB + await this.db .update(asyncTasks) .set({ error: new AsyncTaskError( diff --git a/src/database/server/models/chunk.ts b/src/database/server/models/chunk.ts index 9c4dc7000def..91e8298daf6d 100644 --- a/src/database/server/models/chunk.ts +++ b/src/database/server/models/chunk.ts @@ -2,7 +2,7 @@ import { asc, cosineDistance, count, eq, inArray, sql } from 'drizzle-orm'; import { and, desc, isNull } from 'drizzle-orm/expressions'; import { chunk } from 'lodash-es'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { ChunkMetadata, FileChunk } from '@/types/chunk'; import { @@ -18,12 +18,15 @@ import { export class ChunkModel { private userId: string; - constructor(userId: string) { + private db: LobeChatDatabase; + + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } bulkCreate = async (params: NewChunkItem[], fileId: string) => { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { const result = await trx.insert(chunks).values(params).returning(); const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId })); @@ -37,15 +40,15 @@ export class ChunkModel { }; bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => { - return serverDB.insert(unstructuredChunks).values(params); + return this.db.insert(unstructuredChunks).values(params); }; delete = async (id: string) => { - return serverDB.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); + return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); }; deleteOrphanChunks = async () => { - const orphanedChunks = await serverDB + const orphanedChunks = await this.db .select({ chunkId: chunks.id }) .from(chunks) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) @@ -56,7 +59,7 @@ export class ChunkModel { const list = chunk(ids, 500); - await serverDB.transaction(async (trx) => { + await this.db.transaction(async (trx) => { await Promise.all( list.map(async (chunkIds) => { await trx.delete(chunks).where(inArray(chunks.id, chunkIds)); @@ -66,13 +69,13 @@ export class ChunkModel { }; findById = async (id: string) => { - return serverDB.query.chunks.findFirst({ + return this.db.query.chunks.findFirst({ where: and(eq(chunks.id, id)), }); }; async findByFileId(id: string, page = 0) { - const data = await serverDB + const data = await this.db .select({ abstract: chunks.abstract, createdAt: chunks.createdAt, @@ -98,7 +101,7 @@ export class ChunkModel { } async getChunksTextByFileId(id: string): Promise<{ id: string; text: string }[]> { - const data = await serverDB + const data = await this.db .select() .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) @@ -113,7 +116,7 @@ export class ChunkModel { async countByFileIds(ids: string[]) { if (ids.length === 0) return []; - return serverDB + return this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, @@ -124,7 +127,7 @@ export class ChunkModel { } async countByFileId(ids: string) { - const data = await serverDB + const data = await this.db .select({ count: count(fileChunks.chunkId), id: fileChunks.fileId, @@ -146,7 +149,7 @@ export class ChunkModel { }) { const similarity = sql`1 - (${cosineDistance(embeddings.embeddings, embedding)})`; - const data = await serverDB + const data = await this.db .select({ fileId: fileChunks.fileId, fileName: files.name, @@ -185,7 +188,7 @@ export class ChunkModel { if (!hasFiles) return []; - const result = await serverDB + const result = await this.db .select({ fileId: files.id, fileName: files.name, diff --git a/src/database/server/models/embedding.ts b/src/database/server/models/embedding.ts index ffe36ba6a4ef..f0b06666a7b4 100644 --- a/src/database/server/models/embedding.ts +++ b/src/database/server/models/embedding.ts @@ -1,19 +1,21 @@ import { count, eq } from 'drizzle-orm'; import { and } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { NewEmbeddingsItem, embeddings } from '../schemas/lobechat'; export class EmbeddingModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (value: Omit) => { - const [item] = await serverDB + const [item] = await this.db .insert(embeddings) .values({ ...value, userId: this.userId }) .returning(); @@ -22,7 +24,7 @@ export class EmbeddingModel { }; bulkCreate = async (values: Omit[]) => { - return serverDB + return this.db .insert(embeddings) .values(values.map((item) => ({ ...item, userId: this.userId }))) .onConflictDoNothing({ @@ -31,25 +33,25 @@ export class EmbeddingModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(embeddings) .where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId))); }; query = async () => { - return serverDB.query.embeddings.findMany({ + return this.db.query.embeddings.findMany({ where: eq(embeddings.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.embeddings.findFirst({ + return this.db.query.embeddings.findFirst({ where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)), }); }; countUsage = async () => { - const result = await serverDB + const result = await this.db .select({ count: count(), }) diff --git a/src/database/server/models/file.ts b/src/database/server/models/file.ts index 189065490174..52cda3d56c61 100644 --- a/src/database/server/models/file.ts +++ b/src/database/server/models/file.ts @@ -2,8 +2,7 @@ import { asc, count, eq, ilike, inArray, notExists, or, sum } from 'drizzle-orm' import { and, desc, like } from 'drizzle-orm/expressions'; import type { PgTransaction } from 'drizzle-orm/pg-core'; -import { serverDBEnv } from '@/config/db'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { FilesTabs, QueryFileListParams, SortType } from '@/types/files'; import { @@ -20,13 +19,15 @@ import { export class FileModel { private readonly userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: Omit & { knowledgeBaseId?: string }) => { - const result = await serverDB.transaction(async (trx) => { + const result = await this.db.transaction(async (trx) => { const result = await trx .insert(files) .values({ ...params, userId: this.userId }) @@ -47,11 +48,11 @@ export class FileModel { }; createGlobalFile = async (file: Omit) => { - return serverDB.insert(globalFiles).values(file).returning(); + return this.db.insert(globalFiles).values(file).returning(); }; checkHash = async (hash: string) => { - const item = await serverDB.query.globalFiles.findFirst({ + const item = await this.db.query.globalFiles.findFirst({ where: eq(globalFiles.hashId, hash), }); if (!item) return { isExist: false }; @@ -71,7 +72,7 @@ export class FileModel { const fileHash = file.fileHash!; - return await serverDB.transaction(async (trx) => { + return await this.db.transaction(async (trx) => { // 1. 删除相关的 chunks await this.deleteFileChunks(trx as any, [id]); @@ -86,8 +87,7 @@ export class FileModel { const fileCount = result[0].count; // delete the file from global file if it is not used by other files - // if `DISABLE_REMOVE_GLOBAL_FILE` is true, we will not remove the global file - if (fileCount === 0 && !serverDBEnv.DISABLE_REMOVE_GLOBAL_FILE) { + if (fileCount === 0) { await trx.delete(globalFiles).where(eq(globalFiles.hashId, fileHash)); return file; @@ -96,11 +96,11 @@ export class FileModel { }; deleteGlobalFile = async (hashId: string) => { - return serverDB.delete(globalFiles).where(eq(globalFiles.hashId, hashId)); + return this.db.delete(globalFiles).where(eq(globalFiles.hashId, hashId)); }; countUsage = async () => { - const result = await serverDB + const result = await this.db .select({ totalSize: sum(files.size), }) @@ -114,7 +114,7 @@ export class FileModel { const fileList = await this.findByIds(ids); const hashList = fileList.map((file) => file.fileHash!); - return await serverDB.transaction(async (trx) => { + return await this.db.transaction(async (trx) => { // 1. 删除相关的 chunks await this.deleteFileChunks(trx as any, ids); @@ -142,7 +142,7 @@ export class FileModel { const needToDeleteList = fileHashCounts.filter((item) => item.count === 0); - if (needToDeleteList.length === 0 || serverDBEnv.DISABLE_REMOVE_GLOBAL_FILE) return; + if (needToDeleteList.length === 0) return; // delete the file from global file if it is not used by other files await trx.delete(globalFiles).where( @@ -159,7 +159,7 @@ export class FileModel { }; clear = async () => { - return serverDB.delete(files).where(eq(files.userId, this.userId)); + return this.db.delete(files).where(eq(files.userId, this.userId)); }; query = async ({ @@ -198,7 +198,7 @@ export class FileModel { } // 3. build query - let query = serverDB + let query = this.db .select({ chunkTaskId: files.chunkTaskId, createdAt: files.createdAt, @@ -230,7 +230,7 @@ export class FileModel { whereClause = and( whereClause, notExists( - serverDB.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)), + this.db.select().from(knowledgeBaseFiles).where(eq(knowledgeBaseFiles.fileId, files.id)), ), ); } @@ -240,19 +240,19 @@ export class FileModel { }; findByIds = async (ids: string[]) => { - return serverDB.query.files.findMany({ + return this.db.query.files.findMany({ where: and(inArray(files.id, ids), eq(files.userId, this.userId)), }); }; findById = async (id: string) => { - return serverDB.query.files.findFirst({ + return this.db.query.files.findFirst({ where: and(eq(files.id, id), eq(files.userId, this.userId)), }); }; countFilesByHash = async (hash: string) => { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -263,7 +263,7 @@ export class FileModel { }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(files) .set({ ...value, updatedAt: new Date() }) .where(and(eq(files.id, id), eq(files.userId, this.userId))); @@ -293,7 +293,7 @@ export class FileModel { }; async findByNames(fileNames: string[]) { - return serverDB.query.files.findMany({ + return this.db.query.files.findMany({ where: and( or(...fileNames.map((name) => like(files.name, `${name}%`))), eq(files.userId, this.userId), diff --git a/src/database/server/models/knowledgeBase.ts b/src/database/server/models/knowledgeBase.ts index 1dd852f70f1e..0af61a198bde 100644 --- a/src/database/server/models/knowledgeBase.ts +++ b/src/database/server/models/knowledgeBase.ts @@ -1,22 +1,24 @@ import { eq, inArray } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { KnowledgeBaseItem } from '@/types/knowledgeBase'; import { NewKnowledgeBase, knowledgeBaseFiles, knowledgeBases } from '../schemas/lobechat'; export class KnowledgeBaseModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // create create = async (params: Omit) => { - const [result] = await serverDB + const [result] = await this.db .insert(knowledgeBases) .values({ ...params, userId: this.userId }) .returning(); @@ -25,7 +27,7 @@ export class KnowledgeBaseModel { }; addFilesToKnowledgeBase = async (id: string, fileIds: string[]) => { - return serverDB + return this.db .insert(knowledgeBaseFiles) .values(fileIds.map((fileId) => ({ fileId, knowledgeBaseId: id, userId: this.userId }))) .returning(); @@ -33,17 +35,17 @@ export class KnowledgeBaseModel { // delete delete = async (id: string) => { - return serverDB + return this.db .delete(knowledgeBases) .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); + return this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); }; removeFilesFromKnowledgeBase = async (knowledgeBaseId: string, ids: string[]) => { - return serverDB.delete(knowledgeBaseFiles).where( + return this.db.delete(knowledgeBaseFiles).where( and( eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), inArray(knowledgeBaseFiles.fileId, ids), @@ -53,7 +55,7 @@ export class KnowledgeBaseModel { }; // query query = async () => { - const data = await serverDB + const data = await this.db .select({ avatar: knowledgeBases.avatar, createdAt: knowledgeBases.createdAt, @@ -73,21 +75,21 @@ export class KnowledgeBaseModel { }; findById = async (id: string) => { - return serverDB.query.knowledgeBases.findFirst({ + return this.db.query.knowledgeBases.findFirst({ where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), }); }; // update async update(id: string, value: Partial) { - return serverDB + return this.db .update(knowledgeBases) .set({ ...value, updatedAt: new Date() }) .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); } - static async findById(id: string) { - return serverDB.query.knowledgeBases.findFirst({ + static async findById(db: LobeChatDatabase, id: string) { + return db.query.knowledgeBases.findFirst({ where: eq(knowledgeBases.id, id), }); } diff --git a/src/database/server/models/message.ts b/src/database/server/models/message.ts index 24670a680e9d..9a828bb3aa40 100644 --- a/src/database/server/models/message.ts +++ b/src/database/server/models/message.ts @@ -1,8 +1,7 @@ import { count } from 'drizzle-orm'; import { and, asc, desc, eq, gte, inArray, isNull, like, lt } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server/core/db'; -import { idGenerator } from '@/database/server/utils/idGenerator'; +import { LobeChatDatabase } from '@/database/type'; import { getFullFileUrl } from '@/server/utils/files'; import { ChatFileItem, @@ -29,6 +28,7 @@ import { messages, messagesFiles, } from '../schemas/lobechat'; +import { idGenerator } from '../utils/idGenerator'; export interface QueryMessageParams { current?: number; @@ -39,9 +39,11 @@ export interface QueryMessageParams { export class MessageModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // **************** Query *************** // @@ -54,7 +56,7 @@ export class MessageModel { const offset = current * pageSize; // 1. get basic messages - const result = await serverDB + const result = await this.db .select({ /* eslint-disable sort-keys-fix/sort-keys-fix*/ id: messages.id, @@ -115,7 +117,7 @@ export class MessageModel { if (messageIds.length === 0) return []; // 2. get relative files - const rawRelatedFileList = await serverDB + const rawRelatedFileList = await this.db .select({ fileType: files.fileType, id: messagesFiles.fileId, @@ -139,7 +141,7 @@ export class MessageModel { const fileList = relatedFileList.filter((i) => !(i.fileType || '').startsWith('image')); // 3. get relative file chunks - const chunksList = await serverDB + const chunksList = await this.db .select({ fileId: files.id, fileType: files.fileType, @@ -157,7 +159,7 @@ export class MessageModel { .where(inArray(messageQueryChunks.messageId, messageIds)); // 3. get relative message query - const messageQueriesList = await serverDB + const messageQueriesList = await this.db .select({ id: messageQueries.id, messageId: messageQueries.messageId, @@ -216,13 +218,13 @@ export class MessageModel { } async findById(id: string) { - return serverDB.query.messages.findFirst({ + return this.db.query.messages.findFirst({ where: and(eq(messages.id, id), eq(messages.userId, this.userId)), }); } async findMessageQueriesById(messageId: string) { - const result = await serverDB + const result = await this.db .select({ embeddings: embeddings.embeddings, id: messageQueries.id, @@ -240,7 +242,7 @@ export class MessageModel { } async queryAll(): Promise { - return serverDB + return this.db .select() .from(messages) .orderBy(messages.createdAt) @@ -250,7 +252,7 @@ export class MessageModel { } async queryBySessionId(sessionId?: string | null): Promise { - return serverDB.query.messages.findMany({ + return this.db.query.messages.findMany({ orderBy: [asc(messages.createdAt)], where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)), }); @@ -259,14 +261,14 @@ export class MessageModel { async queryByKeyword(keyword: string): Promise { if (!keyword) return []; - return serverDB.query.messages.findMany({ + return this.db.query.messages.findMany({ orderBy: [desc(messages.createdAt)], where: and(eq(messages.userId, this.userId), like(messages.content, `%${keyword}%`)), }); } async count() { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -283,7 +285,7 @@ export class MessageModel { const tomorrow = new Date(today); tomorrow.setDate(tomorrow.getDate() + 1); - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -315,7 +317,7 @@ export class MessageModel { }: CreateMessageParams, id: string = this.genId(), ): Promise { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { const [item] = (await trx .insert(messages) .values({ @@ -366,72 +368,72 @@ export class MessageModel { return { ...m, userId: this.userId }; }); - return serverDB.insert(messages).values(messagesToInsert); + return this.db.insert(messages).values(messagesToInsert); } async createMessageQuery(params: NewMessageQuery) { - const result = await serverDB.insert(messageQueries).values(params).returning(); + const result = await this.db.insert(messageQueries).values(params).returning(); return result[0]; } // **************** Update *************** // async update(id: string, message: Partial) { - return serverDB + return this.db .update(messages) .set(message) .where(and(eq(messages.id, id), eq(messages.userId, this.userId))); } async updatePluginState(id: string, state: Record) { - const item = await serverDB.query.messagePlugins.findFirst({ + const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); if (!item) throw new Error('Plugin not found'); - return serverDB + return this.db .update(messagePlugins) .set({ state: merge(item.state || {}, state) }) .where(eq(messagePlugins.id, id)); } async updateMessagePlugin(id: string, value: Partial) { - const item = await serverDB.query.messagePlugins.findFirst({ + const item = await this.db.query.messagePlugins.findFirst({ where: eq(messagePlugins.id, id), }); if (!item) throw new Error('Plugin not found'); - return serverDB.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); + return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); } async updateTranslate(id: string, translate: Partial) { - const result = await serverDB.query.messageTranslates.findFirst({ + const result = await this.db.query.messageTranslates.findFirst({ where: and(eq(messageTranslates.id, id)), }); // If the message does not exist in the translate table, insert it if (!result) { - return serverDB.insert(messageTranslates).values({ ...translate, id }); + return this.db.insert(messageTranslates).values({ ...translate, id }); } // or just update the existing one - return serverDB.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); + return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); } async updateTTS(id: string, tts: Partial) { - const result = await serverDB.query.messageTTS.findFirst({ + const result = await this.db.query.messageTTS.findFirst({ where: and(eq(messageTTS.id, id)), }); // If the message does not exist in the translate table, insert it if (!result) { - return serverDB + return this.db .insert(messageTTS) .values({ contentMd5: tts.contentMd5, fileId: tts.file, id, voice: tts.voice }); } // or just update the existing one - return serverDB + return this.db .update(messageTTS) .set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice }) .where(eq(messageTTS.id, id)); @@ -440,7 +442,7 @@ export class MessageModel { // **************** Delete *************** // async deleteMessage(id: string) { - return serverDB.transaction(async (tx) => { + return this.db.transaction(async (tx) => { // 1. 查询要删除的 message 的完整信息 const message = await tx .select() @@ -476,25 +478,25 @@ export class MessageModel { } async deleteMessages(ids: string[]) { - return serverDB + return this.db .delete(messages) .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids))); } async deleteMessageTranslate(id: string) { - return serverDB.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); + return this.db.delete(messageTranslates).where(and(eq(messageTranslates.id, id))); } async deleteMessageTTS(id: string) { - return serverDB.delete(messageTTS).where(and(eq(messageTTS.id, id))); + return this.db.delete(messageTTS).where(and(eq(messageTTS.id, id))); } async deleteMessageQuery(id: string) { - return serverDB.delete(messageQueries).where(and(eq(messageQueries.id, id))); + return this.db.delete(messageQueries).where(and(eq(messageQueries.id, id))); } async deleteMessagesBySession(sessionId?: string | null, topicId?: string | null) { - return serverDB + return this.db .delete(messages) .where( and( @@ -506,7 +508,7 @@ export class MessageModel { } async deleteAllMessages() { - return serverDB.delete(messages).where(eq(messages.userId, this.userId)); + return this.db.delete(messages).where(eq(messages.userId, this.userId)); } // **************** Helper *************** // diff --git a/src/database/server/models/plugin.ts b/src/database/server/models/plugin.ts index 99089827c0cf..7816b0a8f37f 100644 --- a/src/database/server/models/plugin.ts +++ b/src/database/server/models/plugin.ts @@ -1,20 +1,22 @@ import { and, desc, eq } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { InstalledPluginItem, NewInstalledPlugin, installedPlugins } from '../schemas/lobechat'; export class PluginModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async ( params: Pick, ) => { - const [result] = await serverDB + const [result] = await this.db .insert(installedPlugins) .values({ ...params, createdAt: new Date(), updatedAt: new Date(), userId: this.userId }) .returning(); @@ -23,17 +25,17 @@ export class PluginModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(installedPlugins) .where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId)); + return this.db.delete(installedPlugins).where(eq(installedPlugins.userId, this.userId)); }; query = async () => { - return serverDB + return this.db .select({ createdAt: installedPlugins.createdAt, customParams: installedPlugins.customParams, @@ -49,13 +51,13 @@ export class PluginModel { }; findById = async (id: string) => { - return serverDB.query.installedPlugins.findFirst({ + return this.db.query.installedPlugins.findFirst({ where: and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(installedPlugins) .set({ ...value, updatedAt: new Date() }) .where(and(eq(installedPlugins.identifier, id), eq(installedPlugins.userId, this.userId))); diff --git a/src/database/server/models/ragEval/evaluationRecord.ts b/src/database/server/models/ragEval/evaluationRecord.ts index b8a7374697a1..390e8248c59d 100644 --- a/src/database/server/models/ragEval/evaluationRecord.ts +++ b/src/database/server/models/ragEval/evaluationRecord.ts @@ -1,7 +1,8 @@ import { and, eq } from 'drizzle-orm'; import { serverDB } from '@/database/server'; -import { NewEvaluationRecordsItem, evaluationRecords } from '@/database/server/schemas/lobechat'; + +import { NewEvaluationRecordsItem, evaluationRecords } from '../../schemas/lobechat'; export class EvaluationRecordModel { private userId: string; diff --git a/src/database/server/models/session.ts b/src/database/server/models/session.ts index 01d4bb287d86..057cfc5086e9 100644 --- a/src/database/server/models/session.ts +++ b/src/database/server/models/session.ts @@ -4,7 +4,7 @@ import { and, desc, eq, isNull, not, or } from 'drizzle-orm/expressions'; import { appEnv } from '@/config/app'; import { INBOX_SESSION_ID } from '@/const/session'; import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; -import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { parseAgentConfig } from '@/server/globalConfig/parseDefaultAgent'; import { ChatSessionList, LobeAgentSession } from '@/types/session'; import { merge } from '@/utils/merge'; @@ -23,16 +23,18 @@ import { idGenerator } from '../utils/idGenerator'; export class SessionModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } // **************** Query *************** // async query({ current = 0, pageSize = 9999 } = {}) { const offset = current * pageSize; - return serverDB.query.sessions.findMany({ + return this.db.query.sessions.findMany({ limit: pageSize, offset, orderBy: [desc(sessions.updatedAt)], @@ -45,7 +47,7 @@ export class SessionModel { // 查询所有会话 const result = await this.query(); - const groups = await serverDB.query.sessionGroups.findMany({ + const groups = await this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], where: eq(sessions.userId, this.userId), }); @@ -69,7 +71,7 @@ export class SessionModel { async findByIdOrSlug( idOrSlug: string, ): Promise<(SessionItem & { agent: AgentItem }) | undefined> { - const result = await serverDB.query.sessions.findFirst({ + const result = await this.db.query.sessions.findFirst({ where: and( or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), eq(sessions.userId, this.userId), @@ -83,7 +85,7 @@ export class SessionModel { } async count() { - const result = await serverDB + const result = await this.db .select({ count: count(), }) @@ -109,7 +111,7 @@ export class SessionModel { slug?: string; type: 'agent' | 'group'; }): Promise { - return serverDB.transaction(async (trx) => { + return this.db.transaction(async (trx) => { const newAgents = await trx .insert(agents) .values({ @@ -144,7 +146,7 @@ export class SessionModel { } async createInbox() { - const item = await serverDB.query.sessions.findFirst({ + const item = await this.db.query.sessions.findFirst({ where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)), }); if (item) return; @@ -167,7 +169,7 @@ export class SessionModel { }; }); - return serverDB.insert(sessions).values(sessionsToInsert); + return this.db.insert(sessions).values(sessionsToInsert); } async duplicate(id: string, newTitle?: string) { @@ -199,7 +201,7 @@ export class SessionModel { * Delete a session, also delete all messages and topics associated with it. */ async delete(id: string) { - return serverDB + return this.db .delete(sessions) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))); } @@ -208,18 +210,18 @@ export class SessionModel { * Batch delete sessions, also delete all messages and topics associated with them. */ async batchDelete(ids: string[]) { - return serverDB + return this.db .delete(sessions) .where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId))); } async deleteAll() { - return serverDB.delete(sessions).where(eq(sessions.userId, this.userId)); + return this.db.delete(sessions).where(eq(sessions.userId, this.userId)); } // **************** Update *************** // async update(id: string, data: Partial) { - return serverDB + return this.db .update(sessions) .set(data) .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))) @@ -227,7 +229,7 @@ export class SessionModel { } async updateConfig(id: string, data: Partial) { - return serverDB + return this.db .update(agents) .set(data) .where(and(eq(agents.id, id), eq(agents.userId, this.userId))); @@ -272,7 +274,7 @@ export class SessionModel { const { pinned, keyword, group, pageSize = 9999, current = 0 } = params; const offset = current * pageSize; - return serverDB.query.sessions.findMany({ + return this.db.query.sessions.findMany({ limit: pageSize, offset, orderBy: [desc(sessions.updatedAt)], @@ -281,15 +283,15 @@ export class SessionModel { pinned !== undefined ? eq(sessions.pinned, pinned) : eq(sessions.userId, this.userId), keyword ? or( - like( - sql`lower(${sessions.title})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), - like( - sql`lower(${sessions.description})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), - ) + like( + sql`lower(${sessions.title})` as unknown as Column, + `%${keyword.toLowerCase()}%`, + ), + like( + sql`lower(${sessions.description})` as unknown as Column, + `%${keyword.toLowerCase()}%`, + ), + ) : eq(sessions.userId, this.userId), group ? eq(sessions.groupId, group) : isNull(sessions.groupId), ), @@ -298,29 +300,22 @@ export class SessionModel { }); } - async findSessionsByKeywords(params: { - current?: number; - keyword: string; - pageSize?: number; - }) { + async findSessionsByKeywords(params: { current?: number; keyword: string; pageSize?: number }) { const { keyword, pageSize = 9999, current = 0 } = params; const offset = current * pageSize; - const results = await serverDB.query.agents.findMany({ + const results = await this.db.query.agents.findMany({ limit: pageSize, offset, orderBy: [desc(agents.updatedAt)], where: and( eq(agents.userId, this.userId), or( - like( - sql`lower(${agents.title})` as unknown as Column, - `%${keyword.toLowerCase()}%`, - ), + like(sql`lower(${agents.title})` as unknown as Column, `%${keyword.toLowerCase()}%`), like( sql`lower(${agents.description})` as unknown as Column, `%${keyword.toLowerCase()}%`, ), - ) + ), ), with: { agentsToSessions: { columns: {}, with: { session: true } } }, }); @@ -328,6 +323,6 @@ export class SessionModel { // @ts-expect-error return results.map((item) => item.agentsToSessions[0].session); } catch {} - return [] + return []; } } diff --git a/src/database/server/models/sessionGroup.ts b/src/database/server/models/sessionGroup.ts index ab257184a946..17557cd63247 100644 --- a/src/database/server/models/sessionGroup.ts +++ b/src/database/server/models/sessionGroup.ts @@ -1,20 +1,22 @@ import { eq } from 'drizzle-orm'; import { and, asc, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; -import { idGenerator } from '@/database/server/utils/idGenerator'; +import { LobeChatDatabase } from '@/database/type'; import { SessionGroupItem, sessionGroups } from '../schemas/lobechat'; +import { idGenerator } from '../utils/idGenerator'; export class SessionGroupModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: { name: string; sort?: number }) => { - const [result] = await serverDB + const [result] = await this.db .insert(sessionGroups) .values({ ...params, id: this.genId(), userId: this.userId }) .returning(); @@ -23,37 +25,37 @@ export class SessionGroupModel { }; delete = async (id: string) => { - return serverDB + return this.db .delete(sessionGroups) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); + return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); }; query = async () => { - return serverDB.query.sessionGroups.findMany({ + return this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], where: eq(sessionGroups.userId, this.userId), }); }; findById = async (id: string) => { - return serverDB.query.sessionGroups.findFirst({ + return this.db.query.sessionGroups.findFirst({ where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); } async updateOrder(sortMap: { id: string; sort: number }[]) { - await serverDB.transaction(async (tx) => { + await this.db.transaction(async (tx) => { const updates = sortMap.map(({ id, sort }) => { return tx .update(sessionGroups) diff --git a/src/database/server/models/thread.ts b/src/database/server/models/thread.ts index c8bc4542cb52..90706423089a 100644 --- a/src/database/server/models/thread.ts +++ b/src/database/server/models/thread.ts @@ -1,7 +1,7 @@ import { eq } from 'drizzle-orm'; import { and, desc } from 'drizzle-orm/expressions'; -import { serverDB } from '@/database/server'; +import { LobeChatDatabase } from '@/database/type'; import { CreateThreadParams, ThreadStatus } from '@/types/topic'; import { ThreadItem, threads } from '../schemas/lobechat'; @@ -20,14 +20,16 @@ const queryColumns = { export class ThreadModel { private userId: string; + private db: LobeChatDatabase; - constructor(userId: string) { + constructor(db: LobeChatDatabase, userId: string) { this.userId = userId; + this.db = db; } create = async (params: CreateThreadParams) => { // @ts-ignore - const [result] = await serverDB + const [result] = await this.db .insert(threads) .values({ ...params, status: ThreadStatus.Active, userId: this.userId }) .onConflictDoNothing() @@ -37,15 +39,15 @@ export class ThreadModel { }; delete = async (id: string) => { - return serverDB.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId))); + return this.db.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId))); }; deleteAll = async () => { - return serverDB.delete(threads).where(eq(threads.userId, this.userId)); + return this.db.delete(threads).where(eq(threads.userId, this.userId)); }; query = async () => { - const data = await serverDB + const data = await this.db .select(queryColumns) .from(threads) .where(eq(threads.userId, this.userId)) @@ -55,7 +57,7 @@ export class ThreadModel { }; queryByTopicId = async (topicId: string) => { - const data = await serverDB + const data = await this.db .select(queryColumns) .from(threads) .where(and(eq(threads.topicId, topicId), eq(threads.userId, this.userId))) @@ -65,13 +67,13 @@ export class ThreadModel { }; findById = async (id: string) => { - return serverDB.query.threads.findFirst({ + return this.db.query.threads.findFirst({ where: and(eq(threads.id, id), eq(threads.userId, this.userId)), }); }; async update(id: string, value: Partial) { - return serverDB + return this.db .update(threads) .set({ ...value, updatedAt: new Date() }) .where(and(eq(threads.id, id), eq(threads.userId, this.userId))); diff --git a/src/database/server/models/topic.ts b/src/database/server/models/topic.ts index 890558e4fe4b..eeee76c2bbdc 100644 --- a/src/database/server/models/topic.ts +++ b/src/database/server/models/topic.ts @@ -1,7 +1,7 @@ import { Column, count, inArray, sql } from 'drizzle-orm'; import { and, desc, eq, exists, isNull, like, or } from 'drizzle-orm/expressions'; -import { LobeChatDatabase } from '@/database/server/type'; +import { LobeChatDatabase } from '@/database/type'; import { NewMessage, TopicItem, messages, topics } from '../schemas/lobechat'; import { idGenerator } from '../utils/idGenerator'; diff --git a/src/database/server/models/user.ts b/src/database/server/models/user.ts index 86008a8fc896..7e792438ca96 100644 --- a/src/database/server/models/user.ts +++ b/src/database/server/models/user.ts @@ -3,6 +3,7 @@ import { eq } from 'drizzle-orm'; import { DeepPartial } from 'utility-types'; import { serverDB } from '@/database/server/core/db'; +import { LobeChatDatabase } from '@/database/type'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { UserGuide, UserPreference } from '@/types/user'; import { UserKeyVaults, UserSettings } from '@/types/user/settings'; @@ -18,38 +19,16 @@ export class UserNotFoundError extends TRPCError { } export class UserModel { - static createUser = async (params: NewUser) => { - // if user already exists, skip creation - if (params.id) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, params.id) }); - if (!!user) return; - } - - const [user] = await serverDB - .insert(users) - .values({ ...params }) - .returning(); - - // Create an inbox session for the user - const model = new SessionModel(user.id); - - await model.createInbox(); - }; + private userId: string; + private db: LobeChatDatabase; - static deleteUser = async (id: string) => { - return serverDB.delete(users).where(eq(users.id, id)); - }; - - static findById = async (id: string) => { - return serverDB.query.users.findFirst({ where: eq(users.id, id) }); - }; - - static findByEmail = async (email: string) => { - return serverDB.query.users.findFirst({ where: eq(users.email, email) }); - }; + constructor(db: LobeChatDatabase, userId: string) { + this.userId = userId; + this.db = db; + } - getUserState = async (id: string) => { - const result = await serverDB + async getUserState() { + const result = await this.db .select({ isOnboarded: users.isOnboarded, preference: users.preference, @@ -63,7 +42,7 @@ export class UserModel { settingsTool: userSettings.tool, }) .from(users) - .where(eq(users.id, id)) + .where(eq(users.id, this.userId)) .leftJoin(userSettings, eq(users.id, userSettings.id)); if (!result || !result[0]) { @@ -82,7 +61,7 @@ export class UserModel { try { decryptKeyVaults = JSON.parse(plaintext); } catch (e) { - console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); + console.error(`Failed to parse keyVaults ,userId: ${this.userId}. Error:`, e); } } } @@ -101,54 +80,22 @@ export class UserModel { isOnboarded: state.isOnboarded, preference: state.preference as UserPreference, settings, - userId: id, + userId: this.userId, }; - }; - - static getUserApiKeys = async (id: string) => { - const result = await serverDB - .select({ - settingsKeyVaults: userSettings.keyVaults, - }) - .from(userSettings) - .where(eq(userSettings.id, id)); - - if (!result || !result[0]) { - throw new UserNotFoundError(); - } - - const state = result[0]; - - // Decrypt keyVaults - let decryptKeyVaults = {}; - if (state.settingsKeyVaults) { - const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults); - - if (wasAuthentic) { - try { - decryptKeyVaults = JSON.parse(plaintext); - } catch (e) { - console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); - } - } - } - - return decryptKeyVaults as UserKeyVaults; - }; + } - async updateUser(id: string, value: Partial) { - return serverDB + async updateUser(value: Partial) { + return this.db .update(users) .set({ ...value, updatedAt: new Date() }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } - async deleteSetting(id: string) { - return serverDB.delete(userSettings).where(eq(userSettings.id, id)); + async deleteSetting() { + return this.db.delete(userSettings).where(eq(userSettings.id, this.userId)); } - async updateSetting(id: string, value: Partial) { + async updateSetting(value: Partial) { const { keyVaults, ...res } = value; // Encrypt keyVaults @@ -165,33 +112,99 @@ export class UserModel { const newValue = { ...res, keyVaults: encryptedKeyVaults }; // update or create user settings - const settings = await serverDB.query.userSettings.findFirst({ where: eq(users.id, id) }); + const settings = await this.db.query.userSettings.findFirst({ + where: eq(users.id, this.userId), + }); if (!settings) { - await serverDB.insert(userSettings).values({ id, ...newValue }); + await serverDB.insert(userSettings).values({ id: this.userId, ...newValue }); return; } - return serverDB.update(userSettings).set(newValue).where(eq(userSettings.id, id)); + return serverDB.update(userSettings).set(newValue).where(eq(userSettings.id, this.userId)); } - async updatePreference(id: string, value: Partial) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) }); + async updatePreference(value: Partial) { + const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; - return serverDB + return this.db .update(users) .set({ preference: merge(user.preference, value) }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } - async updateGuide(id: string, value: Partial) { - const user = await serverDB.query.users.findFirst({ where: eq(users.id, id) }); + async updateGuide(value: Partial) { + const user = await this.db.query.users.findFirst({ where: eq(users.id, this.userId) }); if (!user) return; const prevPreference = (user.preference || {}) as UserPreference; - return serverDB + return this.db .update(users) .set({ preference: { ...prevPreference, guide: merge(prevPreference.guide || {}, value) } }) - .where(eq(users.id, id)); + .where(eq(users.id, this.userId)); } + + // Static method + + static createUser = async (db: LobeChatDatabase, params: NewUser) => { + // if user already exists, skip creation + if (params.id) { + const user = await db.query.users.findFirst({ where: eq(users.id, params.id) }); + if (!!user) return; + } + + const [user] = await db + .insert(users) + .values({ ...params }) + .returning(); + + // Create an inbox session for the user + const model = new SessionModel(db, user.id); + + await model.createInbox(); + }; + + static deleteUser = async (db: LobeChatDatabase, id: string) => { + return db.delete(users).where(eq(users.id, id)); + }; + + static findById = async (db: LobeChatDatabase, id: string) => { + return db.query.users.findFirst({ where: eq(users.id, id) }); + }; + + static findByEmail = async (db: LobeChatDatabase, email: string) => { + return db.query.users.findFirst({ where: eq(users.email, email) }); + }; + + static getUserApiKeys = async (db: LobeChatDatabase, id: string) => { + const result = await db + .select({ + settingsKeyVaults: userSettings.keyVaults, + }) + .from(userSettings) + .where(eq(userSettings.id, id)); + + if (!result || !result[0]) { + throw new UserNotFoundError(); + } + + const state = result[0]; + + // Decrypt keyVaults + let decryptKeyVaults = {}; + if (state.settingsKeyVaults) { + const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); + const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults); + + if (wasAuthentic) { + try { + decryptKeyVaults = JSON.parse(plaintext); + } catch (e) { + console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e); + } + } + } + + return decryptKeyVaults as UserKeyVaults; + }; } diff --git a/src/database/server/type.ts b/src/database/type.ts similarity index 53% rename from src/database/server/type.ts rename to src/database/type.ts index 3da18c2bbb8b..d47c1aa1017c 100644 --- a/src/database/server/type.ts +++ b/src/database/type.ts @@ -1,6 +1,6 @@ -import { PgliteDatabase } from 'drizzle-orm/pglite'; +import type { PgliteDatabase } from 'drizzle-orm/pglite'; -import * as schema from '../server/schemas/lobechat'; +import * as schema from './server/schemas/lobechat'; export type LobeChatDatabaseSchema = typeof schema; diff --git a/src/libs/next-auth/adapter/index.ts b/src/libs/next-auth/adapter/index.ts index 129da498f185..0ab51e5fd80a 100644 --- a/src/libs/next-auth/adapter/index.ts +++ b/src/libs/next-auth/adapter/index.ts @@ -33,8 +33,6 @@ const { * @returns {Adapter} */ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Adapter { - const userModel = new UserModel(); - return { async createAuthenticator(authenticator): Promise { const result = await serverDB @@ -55,10 +53,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad async createUser(user): Promise { const { id, name, email, emailVerified, image, providerAccountId } = user; // return the user if it already exists - let existingUser = await UserModel.findByEmail(email); + let existingUser = await UserModel.findByEmail(serverDB, email); // If the user is not found by email, try to find by providerAccountId if (!existingUser && providerAccountId) { - existingUser = await UserModel.findById(providerAccountId); + existingUser = await UserModel.findById(serverDB, providerAccountId); } if (existingUser) { const adapterUser = mapLobeUserToAdapterUser(existingUser); @@ -66,6 +64,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad } // create a new user if it does not exist await UserModel.createUser( + serverDB, mapAdapterUserToLobeUser({ email, emailVerified, @@ -91,10 +90,10 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad return; }, async deleteUser(id): Promise { - const user = await UserModel.findById(id); + const user = await UserModel.findById(serverDB, id); if (!user) throw new Error('NextAuth: Delete User not found'); - await UserModel.deleteUser(id); + await UserModel.deleteUser(serverDB, id); return; }, @@ -145,7 +144,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async getUser(id): Promise { - const lobeUser = await UserModel.findById(id); + const lobeUser = await UserModel.findById(serverDB, id); if (!lobeUser) return null; return mapLobeUserToAdapterUser(lobeUser); }, @@ -170,7 +169,7 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async getUserByEmail(email): Promise { - const lobeUser = await UserModel.findByEmail(email); + const lobeUser = await UserModel.findByEmail(serverDB, email); return lobeUser ? mapLobeUserToAdapterUser(lobeUser) : null; }, @@ -228,10 +227,11 @@ export function LobeNextAuthDbAdapter(serverDB: NeonDatabase): Ad }, async updateUser(user): Promise { - const lobeUser = await UserModel.findById(user?.id); + const lobeUser = await UserModel.findById(serverDB, user?.id); if (!lobeUser) throw new Error('NextAuth: User not found'); + const userModel = new UserModel(serverDB, user.id); - const updatedUser = await userModel.updateUser(user.id, { + const updatedUser = await userModel.updateUser({ ...partialMapAdapterUserToLobeUser(user), }); if (!updatedUser) throw new Error('NextAuth: Failed to update user'); diff --git a/src/libs/trpc/async/asyncAuth.ts b/src/libs/trpc/async/asyncAuth.ts index 3ebf3248a273..69b52d5d5441 100644 --- a/src/libs/trpc/async/asyncAuth.ts +++ b/src/libs/trpc/async/asyncAuth.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { serverDBEnv } from '@/config/db'; +import { serverDB } from '@/database/server'; import { UserModel } from '@/database/server/models/user'; import { asyncTrpc } from './init'; @@ -12,7 +13,7 @@ export const asyncAuth = asyncTrpc.middleware(async (opts) => { throw new TRPCError({ code: 'UNAUTHORIZED' }); } - const result = await UserModel.findById(ctx.userId); + const result = await UserModel.findById(serverDB, ctx.userId); if (!result) { throw new TRPCError({ code: 'UNAUTHORIZED', message: 'user is invalid' }); diff --git a/src/server/routers/async/file.ts b/src/server/routers/async/file.ts index bce5cd784b75..a60ca354220b 100644 --- a/src/server/routers/async/file.ts +++ b/src/server/routers/async/file.ts @@ -5,6 +5,7 @@ import { z } from 'zod'; import { fileEnv } from '@/config/file'; import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '@/database/server/models/asyncTask'; import { ChunkModel } from '@/database/server/models/chunk'; import { EmbeddingModel } from '@/database/server/models/embedding'; @@ -28,11 +29,11 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/async/ragEval.ts b/src/server/routers/async/ragEval.ts index 5fc637337f23..de4190ee668e 100644 --- a/src/server/routers/async/ragEval.ts +++ b/src/server/routers/async/ragEval.ts @@ -4,6 +4,7 @@ import { z } from 'zod'; import { chainAnswerWithContext } from '@/chains/answerWithContext'; import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { ChunkModel } from '@/database/server/models/chunk'; import { EmbeddingModel } from '@/database/server/models/embedding'; import { FileModel } from '@/database/server/models/file'; @@ -24,13 +25,13 @@ const ragEvalProcedure = asyncAuthedProcedure.use(async (opts) => { return opts.next({ ctx: { - chunkModel: new ChunkModel(ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), datasetRecordModel: new EvalDatasetRecordModel(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), evalRecordModel: new EvaluationRecordModel(ctx.userId), evaluationModel: new EvalEvaluationModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/_template.ts b/src/server/routers/lambda/_template.ts index e7dace3277ec..530402a59b94 100644 --- a/src/server/routers/lambda/_template.ts +++ b/src/server/routers/lambda/_template.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertSessionGroupSchema } from '@/database/server/schemas/lobechat'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/agent.ts b/src/server/routers/lambda/agent.ts index 1ce31bef76b0..a3735ac732ef 100644 --- a/src/server/routers/lambda/agent.ts +++ b/src/server/routers/lambda/agent.ts @@ -2,6 +2,7 @@ import { z } from 'zod'; import { INBOX_SESSION_ID } from '@/const/session'; import { DEFAULT_AGENT_CONFIG } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { AgentModel } from '@/database/server/models/agent'; import { FileModel } from '@/database/server/models/file'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; @@ -16,10 +17,10 @@ const agentProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - agentModel: new AgentModel(ctx.userId), - fileModel: new FileModel(ctx.userId), - knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId), - sessionModel: new SessionModel(ctx.userId), + agentModel: new AgentModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), + knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId), + sessionModel: new SessionModel(serverDB, ctx.userId), }, }); }); @@ -87,7 +88,7 @@ export const agentRouter = router({ // if there is no session for user, create one if (!item) { // if there is no user, return default config - const user = await UserModel.findById(ctx.userId); + const user = await UserModel.findById(serverDB, ctx.userId); if (!user) return DEFAULT_AGENT_CONFIG; const res = await ctx.sessionModel.createInbox(); diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts index aa8947e1e0da..d1f2d53d2a5c 100644 --- a/src/server/routers/lambda/chunk.ts +++ b/src/server/routers/lambda/chunk.ts @@ -21,12 +21,12 @@ const chunkProcedure = authedProcedure.use(keyVaults).use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), chunkService: new ChunkService(ctx.userId), - embeddingModel: new EmbeddingModel(ctx.userId), - fileModel: new FileModel(ctx.userId), - messageModel: new MessageModel(ctx.userId), + embeddingModel: new EmbeddingModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), + messageModel: new MessageModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/file.ts b/src/server/routers/lambda/file.ts index 8b73cc715ca7..02b4366aa9bc 100644 --- a/src/server/routers/lambda/file.ts +++ b/src/server/routers/lambda/file.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { AsyncTaskModel } from '@/database/server/models/asyncTask'; import { ChunkModel } from '@/database/server/models/chunk'; import { FileModel } from '@/database/server/models/file'; @@ -15,9 +16,9 @@ const fileProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.userId), - chunkModel: new ChunkModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId), + chunkModel: new ChunkModel(serverDB, ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/knowledgeBase.ts b/src/server/routers/lambda/knowledgeBase.ts index 140eb7c1f00d..1aceeaf5c97f 100644 --- a/src/server/routers/lambda/knowledgeBase.ts +++ b/src/server/routers/lambda/knowledgeBase.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase'; import { insertKnowledgeBasesSchema } from '@/database/server/schemas/lobechat'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const knowledgeBaseProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - knowledgeBaseModel: new KnowledgeBaseModel(ctx.userId), + knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/message.ts b/src/server/routers/lambda/message.ts index ffcb9390ca60..a01d5cc6d801 100644 --- a/src/server/routers/lambda/message.ts +++ b/src/server/routers/lambda/message.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { updateMessagePluginSchema } from '@/database/server/schemas/lobechat'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc'; @@ -12,7 +13,7 @@ const messageProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { messageModel: new MessageModel(ctx.userId) }, + ctx: { messageModel: new MessageModel(serverDB, ctx.userId) }, }); }); @@ -54,6 +55,7 @@ export const messageRouter = router({ return ctx.messageModel.queryBySessionId(input.sessionId); }), + // TODO: 未来这部分方法也需要使用 authedProcedure getMessages: publicProcedure .input( z.object({ @@ -66,7 +68,7 @@ export const messageRouter = router({ .query(async ({ input, ctx }) => { if (!ctx.userId) return []; - const messageModel = new MessageModel(ctx.userId); + const messageModel = new MessageModel(serverDB, ctx.userId); return messageModel.query(input); }), diff --git a/src/server/routers/lambda/plugin.ts b/src/server/routers/lambda/plugin.ts index 3c691c51df0c..13880b1ff32e 100644 --- a/src/server/routers/lambda/plugin.ts +++ b/src/server/routers/lambda/plugin.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { PluginModel } from '@/database/server/models/plugin'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc'; import { LobeTool } from '@/types/tool'; @@ -8,7 +9,7 @@ const pluginProcedure = authedProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { pluginModel: new PluginModel(ctx.userId) }, + ctx: { pluginModel: new PluginModel(serverDB, ctx.userId) }, }); }); @@ -61,10 +62,11 @@ export const pluginRouter = router({ return data.identifier; }), + // TODO: 未来这部分方法也需要使用 authedProcedure getPlugins: publicProcedure.query(async ({ ctx }): Promise => { if (!ctx.userId) return []; - const pluginModel = new PluginModel(ctx.userId); + const pluginModel = new PluginModel(serverDB, ctx.userId); return pluginModel.query(); }), diff --git a/src/server/routers/lambda/ragEval.ts b/src/server/routers/lambda/ragEval.ts index 33b33a715944..150abe4a4e1c 100644 --- a/src/server/routers/lambda/ragEval.ts +++ b/src/server/routers/lambda/ragEval.ts @@ -6,6 +6,7 @@ import pMap from 'p-map'; import { z } from 'zod'; import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings'; +import { serverDB } from '@/database/server'; import { FileModel } from '@/database/server/models/file'; import { EvalDatasetModel, @@ -34,7 +35,7 @@ const ragEvalProcedure = authedProcedure.use(keyVaults).use(async (opts) => { return opts.next({ ctx: { datasetModel: new EvalDatasetModel(ctx.userId), - fileModel: new FileModel(ctx.userId), + fileModel: new FileModel(serverDB, ctx.userId), datasetRecordModel: new EvalDatasetRecordModel(ctx.userId), evaluationModel: new EvalEvaluationModel(ctx.userId), evaluationRecordModel: new EvaluationRecordModel(ctx.userId), diff --git a/src/server/routers/lambda/session.ts b/src/server/routers/lambda/session.ts index a2099efba215..063be3a3fcc1 100644 --- a/src/server/routers/lambda/session.ts +++ b/src/server/routers/lambda/session.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionModel } from '@/database/server/models/session'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertAgentSchema, insertSessionSchema } from '@/database/server/schemas/lobechat'; @@ -15,8 +16,8 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), - sessionModel: new SessionModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), + sessionModel: new SessionModel(serverDB, ctx.userId), }, }); }); @@ -84,7 +85,7 @@ export const sessionRouter = router({ sessions: [], }; - const sessionModel = new SessionModel(ctx.userId); + const sessionModel = new SessionModel(serverDB, ctx.userId); return sessionModel.queryWithGroups(); }), diff --git a/src/server/routers/lambda/sessionGroup.ts b/src/server/routers/lambda/sessionGroup.ts index e7dace3277ec..530402a59b94 100644 --- a/src/server/routers/lambda/sessionGroup.ts +++ b/src/server/routers/lambda/sessionGroup.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { SessionGroupModel } from '@/database/server/models/sessionGroup'; import { insertSessionGroupSchema } from '@/database/server/schemas/lobechat'; import { authedProcedure, router } from '@/libs/trpc'; @@ -10,7 +11,7 @@ const sessionProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.userId), + sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/thread.ts b/src/server/routers/lambda/thread.ts index 91a8f71ca0f9..479e9e360b71 100644 --- a/src/server/routers/lambda/thread.ts +++ b/src/server/routers/lambda/thread.ts @@ -1,5 +1,6 @@ import { z } from 'zod'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { ThreadModel } from '@/database/server/models/thread'; import { insertThreadSchema } from '@/database/server/schemas/lobechat'; @@ -11,8 +12,8 @@ const threadProcedure = authedProcedure.use(async (opts) => { return opts.next({ ctx: { - messageModel: new MessageModel(ctx.userId), - threadModel: new ThreadModel(ctx.userId), + messageModel: new MessageModel(serverDB, ctx.userId), + threadModel: new ThreadModel(serverDB, ctx.userId), }, }); }); diff --git a/src/server/routers/lambda/user.ts b/src/server/routers/lambda/user.ts index 9660a955e691..f7522004940b 100644 --- a/src/server/routers/lambda/user.ts +++ b/src/server/routers/lambda/user.ts @@ -3,6 +3,7 @@ import { currentUser } from '@clerk/nextjs/server'; import { z } from 'zod'; import { enableClerk } from '@/const/auth'; +import { serverDB } from '@/database/server'; import { MessageModel } from '@/database/server/models/message'; import { SessionModel } from '@/database/server/models/session'; import { UserModel, UserNotFoundError } from '@/database/server/models/user'; @@ -12,7 +13,7 @@ import { UserGuideSchema, UserInitializationState, UserPreference } from '@/type const userProcedure = authedProcedure.use(async (opts) => { return opts.next({ - ctx: { userModel: new UserModel() }, + ctx: { userModel: new UserModel(serverDB, opts.ctx.userId) }, }); }); @@ -23,7 +24,7 @@ export const userRouter = router({ // get or create first-time user while (!state) { try { - state = await ctx.userModel.getUserState(ctx.userId); + state = await ctx.userModel.getUserState(); } catch (error) { if (enableClerk && error instanceof UserNotFoundError) { const user = await currentUser(); @@ -56,10 +57,10 @@ export const userRouter = router({ } } - const messageModel = new MessageModel(ctx.userId); + const messageModel = new MessageModel(serverDB, ctx.userId); const messageCount = await messageModel.count(); - const sessionModel = new SessionModel(ctx.userId); + const sessionModel = new SessionModel(serverDB, ctx.userId); const sessionCount = await sessionModel.count(); return { @@ -77,25 +78,25 @@ export const userRouter = router({ }), makeUserOnboarded: userProcedure.mutation(async ({ ctx }) => { - return ctx.userModel.updateUser(ctx.userId, { isOnboarded: true }); + return ctx.userModel.updateUser({ isOnboarded: true }); }), resetSettings: userProcedure.mutation(async ({ ctx }) => { - return ctx.userModel.deleteSetting(ctx.userId); + return ctx.userModel.deleteSetting(); }), updateGuide: userProcedure.input(UserGuideSchema).mutation(async ({ ctx, input }) => { - return ctx.userModel.updateGuide(ctx.userId, input); + return ctx.userModel.updateGuide(input); }), updatePreference: userProcedure.input(z.any()).mutation(async ({ ctx, input }) => { - return ctx.userModel.updatePreference(ctx.userId, input); + return ctx.userModel.updatePreference(input); }), updateSettings: userProcedure .input(z.object({}).passthrough()) .mutation(async ({ ctx, input }) => { - return ctx.userModel.updateSetting(ctx.userId, input); + return ctx.userModel.updateSetting(input); }), }); diff --git a/src/server/services/chunk/index.ts b/src/server/services/chunk/index.ts index 3ef1ba13275d..09c8e1d820b6 100644 --- a/src/server/services/chunk/index.ts +++ b/src/server/services/chunk/index.ts @@ -1,4 +1,5 @@ import { JWTPayload } from '@/const/auth'; +import { serverDB } from '@/database/server'; import { AsyncTaskModel } from '@/database/server/models/asyncTask'; import { FileModel } from '@/database/server/models/file'; import { ChunkContentParams, ContentChunk } from '@/server/modules/ContentChunk'; @@ -21,8 +22,8 @@ export class ChunkService { this.chunkClient = new ContentChunk(); - this.fileModel = new FileModel(userId); - this.asyncTaskModel = new AsyncTaskModel(userId); + this.fileModel = new FileModel(serverDB, userId); + this.asyncTaskModel = new AsyncTaskModel(serverDB, userId); } async chunkContent(params: ChunkContentParams) { diff --git a/src/server/services/nextAuthUser/index.ts b/src/server/services/nextAuthUser/index.ts index ac17b0de0a46..2108e8cac3c4 100644 --- a/src/server/services/nextAuthUser/index.ts +++ b/src/server/services/nextAuthUser/index.ts @@ -7,11 +7,9 @@ import { pino } from '@/libs/logger'; import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter'; export class NextAuthUserService { - userModel; adapter; constructor() { - this.userModel = new UserModel(); this.adapter = LobeNextAuthDbAdapter(serverDB); } @@ -29,8 +27,10 @@ export class NextAuthUserService { // 2. If found, Update user data from provider if (user?.id) { + const userModel = new UserModel(serverDB, user.id); + // Perform update - await this.userModel.updateUser(user.id, { + await userModel.updateUser({ avatar: data?.avatar, email: data?.email, fullName: data?.fullName, diff --git a/src/server/services/user/index.ts b/src/server/services/user/index.ts index 1bc3480fed74..ed5ee2395099 100644 --- a/src/server/services/user/index.ts +++ b/src/server/services/user/index.ts @@ -1,13 +1,14 @@ import { UserJSON } from '@clerk/backend'; import { NextResponse } from 'next/server'; +import { serverDB } from '@/database/server'; import { UserModel } from '@/database/server/models/user'; import { pino } from '@/libs/logger'; export class UserService { createUser = async (id: string, params: UserJSON) => { // Check if user already exists - const res = await UserModel.findById(id); + const res = await UserModel.findById(serverDB, id); // If user already exists, skip creating a new user if (res) @@ -27,7 +28,7 @@ export class UserService { /* ↑ cloud slot ↑ */ // 2. create user in database - await UserModel.createUser({ + await UserModel.createUser(serverDB, { avatar: params.image_url, clerkCreatedAt: new Date(params.created_at), email: email?.email_address, @@ -49,7 +50,7 @@ export class UserService { if (id) { pino.info('delete user due to clerk webhook'); - await UserModel.deleteUser(id); + await UserModel.deleteUser(serverDB, id); return NextResponse.json({ message: 'user deleted' }, { status: 200 }); } else { @@ -61,10 +62,10 @@ export class UserService { updateUser = async (id: string, params: UserJSON) => { pino.info('updating user due to clerk webhook'); - const userModel = new UserModel(); + const userModel = new UserModel(serverDB, id); // Check if user already exists - const res = await UserModel.findById(id); + const res = await UserModel.findById(serverDB, id); // If user not exists, skip update the user if (!res) @@ -79,7 +80,7 @@ export class UserService { const email = params.email_addresses.find((e) => e.id === params.primary_email_address_id); const phone = params.phone_numbers.find((e) => e.id === params.primary_phone_number_id); - await userModel.updateUser(id, { + await userModel.updateUser({ avatar: params.image_url, email: email?.email_address, firstName: params.first_name, diff --git a/src/types/meta.ts b/src/types/meta.ts index 459aece85cdb..23ce2b941c2f 100644 --- a/src/types/meta.ts +++ b/src/types/meta.ts @@ -21,19 +21,10 @@ export const LobeMetaDataSchema = z.object({ export type MetaData = z.infer; export interface BaseDataModel { - /** - * @deprecated - */ - createAt?: number; - createdAt: number; id: string; meta: MetaData; - /** - * @deprecated - */ - updateAt?: number; updatedAt: number; } diff --git a/src/types/topic/topic.ts b/src/types/topic/topic.ts index 9f3085602f43..83201e7cb7e5 100644 --- a/src/types/topic/topic.ts +++ b/src/types/topic/topic.ts @@ -36,7 +36,7 @@ export interface ChatTopicSummary { } export interface ChatTopic extends Omit { - favorite?: boolean | null; + favorite?: boolean; historySummary?: string | null; metadata?: ChatTopicMetadata; sessionId?: string;