Skip to content

Commit

Permalink
♻️ refactor: seperate user keyVaults encrpyto from user model (#5102)
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx authored Dec 20, 2024
1 parent fc26f15 commit 09b63cf
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 80 deletions.
1 change: 0 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@
"i18next-resources-to-backend": "^1.2.1",
"idb-keyval": "^6.2.1",
"immer": "^10.1.1",
"ip": "^2.0.1",
"jose": "^5.9.4",
"js-sha256": "^0.11.0",
"jsonl-parse-stringify": "^1.0.3",
Expand Down
1 change: 1 addition & 0 deletions src/database/schemas/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ export const userSettings = pgTable('user_settings', {
defaultAgent: jsonb('default_agent'),
tool: jsonb('tool'),
});
export type UserSettingsItem = typeof userSettings.$inferSelect;

export const installedPlugins = pgTable(
'user_installed_plugins',
Expand Down
42 changes: 23 additions & 19 deletions src/database/server/models/__tests__/user.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { UserGuide, UserPreference } from '@/types/user';
import { UserSettings } from '@/types/user/settings';

import { userSettings, users } from '../../../schemas';
import { UserSettingsItem, userSettings, users } from '../../../schemas';
import { SessionModel } from '../session';
import { UserModel } from '../user';

Expand Down Expand Up @@ -101,7 +101,7 @@ describe('UserModel', () => {
keyVaults: encryptedKeyVaults,
});

const state = await userModel.getUserState();
const state = await userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);

expect(state.userId).toBe(userId);
expect(state.preference).toEqual(preference);
Expand All @@ -111,7 +111,9 @@ describe('UserModel', () => {
it('should throw an error if user not found', async () => {
const userModel = new UserModel(serverDB, 'invalid-user-id');

await expect(userModel.getUserState()).rejects.toThrow('user not found');
await expect(userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults)).rejects.toThrow(
'user not found',
);
});
});

Expand Down Expand Up @@ -144,11 +146,10 @@ describe('UserModel', () => {
});

describe('updateSetting', () => {
it('should update user settings with encrypted keyVaults', async () => {
it('should update user settings with new item', async () => {
const settings = {
general: { language: 'en-US' },
keyVaults: { openai: { apiKey: 'secret' } },
} as UserSettings;
} as UserSettingsItem;
await serverDB.insert(users).values({ id: userId });

await userModel.updateSetting(settings);
Expand All @@ -157,23 +158,18 @@ describe('UserModel', () => {
where: eq(users.id, userId),
});
expect(updatedSettings?.general).toEqual(settings.general);
expect(updatedSettings?.keyVaults).not.toBe(JSON.stringify(settings.keyVaults));

const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
const { plaintext } = await gateKeeper.decrypt(updatedSettings!.keyVaults!);
expect(JSON.parse(plaintext)).toEqual(settings.keyVaults);
});

it('should update user settings with encrypted keyVaults', async () => {
it('should update user settings with exist item', async () => {
const settings = {
general: { language: 'en-US' },
} as UserSettings;
} as UserSettingsItem;
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(userSettings).values({ ...settings, keyVaults: '', id: userId });

const newSettings = {
general: { fontSize: 16, language: 'zh-CN', themeMode: 'dark' },
} as UserSettings;
} as UserSettingsItem;
await userModel.updateSetting(newSettings);

const updatedSettings = await serverDB.query.userSettings.findFirst({
Expand Down Expand Up @@ -229,14 +225,18 @@ describe('UserModel', () => {
keyVaults: encryptedKeyVaults,
});

const result = await UserModel.getUserApiKeys(serverDB, userId);
const result = await UserModel.getUserApiKeys(
serverDB,
userId,
KeyVaultsGateKeeper.getUserKeyVaults,
);
expect(result).toEqual(keyVaults);
});

it('should throw error when user not found', async () => {
await expect(UserModel.getUserApiKeys(serverDB, 'non-existent-id')).rejects.toThrow(
'user not found',
);
await expect(
UserModel.getUserApiKeys(serverDB, 'non-existent-id', KeyVaultsGateKeeper.getUserKeyVaults),
).rejects.toThrow('user not found');
});

it('should handle decrypt failure and return empty object', async () => {
Expand All @@ -249,7 +249,11 @@ describe('UserModel', () => {
keyVaults: invalidEncryptedData,
});

const result = await UserModel.getUserApiKeys(serverDB, userId);
const result = await UserModel.getUserApiKeys(
serverDB,
userId,
KeyVaultsGateKeeper.getUserKeyVaults,
);
expect(result).toEqual({});
});
});
Expand Down
83 changes: 25 additions & 58 deletions src/database/server/models/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ import { eq } from 'drizzle-orm/expressions';
import { DeepPartial } from 'utility-types';

import { LobeChatDatabase } from '@/database/type';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { UserGuide, UserPreference } from '@/types/user';
import { UserKeyVaults, UserSettings } from '@/types/user/settings';
import { merge } from '@/utils/merge';

import { NewUser, UserItem, userSettings, users } from '../../schemas';
import { NewUser, UserItem, UserSettingsItem, userSettings, users } from '../../schemas';
import { SessionModel } from './session';

type DecryptUserKeyVaults = (
encryptKeyVaultsStr: string | null,
userId?: string,
) => Promise<UserKeyVaults>;

export class UserNotFoundError extends TRPCError {
constructor() {
super({ code: 'UNAUTHORIZED', message: 'user not found' });
Expand All @@ -26,7 +30,7 @@ export class UserModel {
this.db = db;
}

getUserState = async () => {
getUserState = async (decryptor: DecryptUserKeyVaults) => {
const result = await this.db
.select({
isOnboarded: users.isOnboarded,
Expand All @@ -51,19 +55,7 @@ export class UserModel {
const state = result[0];

// Decrypt keyVaults
let decryptKeyVaults = {};
if (state.settingsKeyVaults) {
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults);

if (wasAuthentic) {
try {
decryptKeyVaults = JSON.parse(plaintext);
} catch (e) {
console.error(`Failed to parse keyVaults ,userId: ${this.userId}. Error:`, e);
}
}
}
const decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);

const settings: DeepPartial<UserSettings> = {
defaultAgent: state.settingsDefaultAgent || {},
Expand Down Expand Up @@ -94,32 +86,17 @@ export class UserModel {
return this.db.delete(userSettings).where(eq(userSettings.id, this.userId));
};

updateSetting = async (value: Partial<UserSettings>) => {
const { keyVaults, ...res } = value;

// Encrypt keyVaults
let encryptedKeyVaults: string | null = null;

if (keyVaults) {
// TODO: better to add a validation
const data = JSON.stringify(keyVaults);
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();

encryptedKeyVaults = await gateKeeper.encrypt(data);
}

const newValue = { ...res, keyVaults: encryptedKeyVaults };

// update or create user settings
const settings = await this.db.query.userSettings.findFirst({
where: eq(users.id, this.userId),
});
if (!settings) {
await this.db.insert(userSettings).values({ id: this.userId, ...newValue });
return;
}

return this.db.update(userSettings).set(newValue).where(eq(userSettings.id, this.userId));
updateSetting = async (value: Partial<UserSettingsItem>) => {
return this.db
.insert(userSettings)
.values({
id: this.userId,
...value,
})
.onConflictDoUpdate({
set: value,
target: userSettings.id,
});
};

updatePreference = async (value: Partial<UserPreference>) => {
Expand Down Expand Up @@ -175,7 +152,11 @@ export class UserModel {
return db.query.users.findFirst({ where: eq(users.email, email) });
};

static getUserApiKeys = async (db: LobeChatDatabase, id: string) => {
static getUserApiKeys = async (
db: LobeChatDatabase,
id: string,
decryptor: DecryptUserKeyVaults,
) => {
const result = await db
.select({
settingsKeyVaults: userSettings.keyVaults,
Expand All @@ -190,20 +171,6 @@ export class UserModel {
const state = result[0];

// Decrypt keyVaults
let decryptKeyVaults = {};
if (state.settingsKeyVaults) {
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
const { wasAuthentic, plaintext } = await gateKeeper.decrypt(state.settingsKeyVaults);

if (wasAuthentic) {
try {
decryptKeyVaults = JSON.parse(plaintext);
} catch (e) {
console.error(`Failed to parse keyVaults ,userId: ${id}. Error:`, e);
}
}
}

return decryptKeyVaults as UserKeyVaults;
return await decryptor(state.settingsKeyVaults, id);
};
}
23 changes: 23 additions & 0 deletions src/server/modules/KeyVaultsEncrypt/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { getServerDBConfig } from '@/config/db';
import { UserKeyVaults } from '@/types/user/settings';

interface DecryptionResult {
plaintext: string;
Expand Down Expand Up @@ -90,4 +91,26 @@ If you don't have it, please run \`openssl rand -base64 32\` to create one.
};
}
};

static getUserKeyVaults = async (
encryptedKeyVaults: string | null,
userId?: string,
): Promise<UserKeyVaults> => {
if (!encryptedKeyVaults) return {};
// Decrypt keyVaults
let decryptKeyVaults = {};

const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
const { wasAuthentic, plaintext } = await gateKeeper.decrypt(encryptedKeyVaults);

if (wasAuthentic) {
try {
decryptKeyVaults = JSON.parse(plaintext);
} catch (e) {
console.error(`Failed to parse keyVaults, userId: ${userId}. Error:`, e);
}
}

return decryptKeyVaults as UserKeyVaults;
};
}
21 changes: 19 additions & 2 deletions src/server/routers/lambda/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import { MessageModel } from '@/database/server/models/message';
import { SessionModel } from '@/database/server/models/session';
import { UserModel, UserNotFoundError } from '@/database/server/models/user';
import { authedProcedure, router } from '@/libs/trpc';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { UserService } from '@/server/services/user';
import { UserGuideSchema, UserInitializationState, UserPreference } from '@/types/user';
import { UserSettings } from '@/types/user/settings';

const userProcedure = authedProcedure.use(async (opts) => {
return opts.next({
Expand All @@ -24,7 +26,7 @@ export const userRouter = router({
// get or create first-time user
while (!state) {
try {
state = await ctx.userModel.getUserState();
state = await ctx.userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults);
} catch (error) {
if (enableClerk && error instanceof UserNotFoundError) {
const user = await currentUser();
Expand Down Expand Up @@ -97,7 +99,22 @@ export const userRouter = router({
updateSettings: userProcedure
.input(z.object({}).passthrough())
.mutation(async ({ ctx, input }) => {
return ctx.userModel.updateSetting(input);
const { keyVaults, ...res } = input as Partial<UserSettings>;

// Encrypt keyVaults
let encryptedKeyVaults: string | null = null;

if (keyVaults) {
// TODO: better to add a validation
const data = JSON.stringify(keyVaults);
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();

encryptedKeyVaults = await gateKeeper.encrypt(data);
}

const nextValue = { ...res, keyVaults: encryptedKeyVaults };

return ctx.userModel.updateSetting(nextValue);
}),
});

Expand Down
5 changes: 5 additions & 0 deletions src/server/services/user/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { UserJSON } from '@clerk/backend';
import { serverDB } from '@/database/server';
import { UserModel } from '@/database/server/models/user';
import { pino } from '@/libs/logger';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';

export class UserService {
createUser = async (id: string, params: UserJSON) => {
Expand Down Expand Up @@ -84,4 +85,8 @@ export class UserService {

return { message: 'user updated', success: true };
};

getUserApiKeys = async (id: string) => {
return UserModel.getUserApiKeys(serverDB, id, KeyVaultsGateKeeper.getUserKeyVaults);
};
}

0 comments on commit 09b63cf

Please sign in to comment.