diff --git a/_generated/embeddings-db.d.ts b/_generated/embeddings-db.d.ts new file mode 100644 index 0000000..5602919 --- /dev/null +++ b/_generated/embeddings-db.d.ts @@ -0,0 +1,8 @@ +export interface Embeddings { + text: string; + embedding: Buffer; +} + +export interface DB { + embeddings: Embeddings; +} diff --git a/electron-src/data/base-database.ts b/electron-src/data/base-database.ts new file mode 100644 index 0000000..31c95f8 --- /dev/null +++ b/electron-src/data/base-database.ts @@ -0,0 +1,111 @@ +import { Kysely } from "kysely"; +import logger from "../utils/logger"; +import type { Database } from "better-sqlite3"; +import SqliteDb from "better-sqlite3"; +import { SqliteDialect } from "kysely"; +import type { KyselyConfig } from "kysely/dist/cjs/kysely"; +import { format } from "sql-formatter"; + +type PostSetupCallback = (db: Database, ks: Kysely) => Promise; +const debugLoggingEnabled = process.env.DEBUG_LOGGING === "true"; +export class BaseDatabase { + path: string; + name: string; + postSetup: PostSetupCallback | undefined; + dbWriter: Kysely | undefined; + isSettingUpDb = false; + + private initializationPromise!: Promise; + + constructor(name: string, path: string, postSetup?: PostSetupCallback) { + this.path = path; + this.name = name; + this.postSetup = postSetup; + } + + isDbInitialized = () => { + return !!this.dbWriter; + }; + + initialize = () => { + if (this.initializationPromise) { + return this.initializationPromise; + } + this.initializationPromise = new Promise(async (resolve) => { + const success = await this.trySetupDb(); + if (!success) { + const startTimeout = () => + setTimeout(async () => { + const success = await this.trySetupDb(); + if (success) { + resolve(); + } else { + startTimeout(); + } + }, 1000); + startTimeout(); + } else { + resolve(); + } + }); + this.initializationPromise.then(() => { + logger.info(`${this.name} initialized`); + }); + + return this.initializationPromise; + }; + trySetupDb = async () => { + try { + if 
(this.isSettingUpDb) { + return false; + } + this.isSettingUpDb = true; + logger.info(`Setting up ${this.name}`); + const sqliteDb = new SqliteDb(this.path, { fileMustExist: false }); + const dialect = new SqliteDialect({ database: sqliteDb }); + const options: KyselyConfig = { + dialect, + log(event): void { + const isError = event.level === "error"; + + if (isError || debugLoggingEnabled) { + const { sql, parameters } = event.query; + + const { queryDurationMillis } = event; + const duration = queryDurationMillis.toFixed(2); + const params = (parameters as string[]) || []; + const formattedSql = format(sql, { params: params.map((l) => String(l)), language: "sqlite" }); + if (event.level === "query") { + logger.debug(`[Query - ${duration}ms]:\n${formattedSql}\n`); + } + + if (isError) { + logger.error(`[SQL Error - ${duration}ms]: ${event.error}\n\n${formattedSql}\n`); + } + } + }, + }; + + const db = new Kysely(options); + if (this.postSetup) { + await this.postSetup(sqliteDb, db); + } + this.dbWriter = db; + return true; + } catch (e) { + console.error(e); + return false; + } finally { + this.isSettingUpDb = false; + } + }; + + get db() { + if (!this.dbWriter) { + throw new Error(`${this.name} not initialized!`); + } + return this.dbWriter; + } +} + +export default BaseDatabase; diff --git a/electron-src/data/chroma.ts b/electron-src/data/chroma.ts deleted file mode 100644 index c3ce76d..0000000 --- a/electron-src/data/chroma.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { Collection } from "chromadb"; -import { ChromaClient } from "chromadb"; - -const COLLECTION_NAME = "mimessage-embeddings"; -export const getCollection = async () => { - try { - const collection: Collection | null = await Promise.race([ - new Promise((resolve) => setTimeout(() => resolve(null), 10000)), - (async () => { - const client = new ChromaClient(); - const collections = await client.listCollections(); - const collection: Collection = collections.find((l: any) => l.name === 
COLLECTION_NAME) - ? await client.getCollection(COLLECTION_NAME) - : await client.createCollection(COLLECTION_NAME, {}); - - return collection; - })(), - ]); - - return collection; - } catch (e) { - return null; - } -}; diff --git a/electron-src/data/database.ts b/electron-src/data/database.ts index 7bdc978..15d7d69 100644 --- a/electron-src/data/database.ts +++ b/electron-src/data/database.ts @@ -1,157 +1,19 @@ -import SqliteDb from "better-sqlite3"; -import type { SelectQueryBuilder } from "kysely"; -import { Kysely, sql, SqliteDialect } from "kysely"; +import type { SelectQueryBuilder, Kysely } from "kysely"; +import { sql } from "kysely"; import type { DB as MesssagesDatabase } from "../../_generated/types"; import logger from "../utils/logger"; -import type { KyselyConfig } from "kysely/dist/cjs/kysely"; import { countBy, groupBy, partition } from "lodash-es"; import type { Contact } from "electron-mac-contacts"; -import { format } from "sql-formatter"; import { decodeMessageBuffer, getTextFromBuffer } from "../utils/buffer"; -import { localDbExists } from "./db-file-utils"; import { appMessagesDbCopy } from "../utils/constants"; -import { getStatsForText } from "./semantic-search-stats"; -import { removeStopWords } from "./text"; +import { getStatsForText } from "../semantic-search/semantic-search-stats"; +import { removeStopWords } from "../utils/text"; +import BaseDatabase from "./base-database"; type ExtractO = T extends SelectQueryBuilder ? 
O : never; type JoinedMessageType = ExtractO>; -const debugLoggingEnabled = process.env.DEBUG_LOGGING === "true"; -export class SQLDatabase { - path: string = appMessagesDbCopy; - private dbWriter: Kysely | undefined; - - initializationPromise!: Promise; - - isDbInitialized = () => { - return !!this.dbWriter; - }; - - initialize = () => { - if (this.initializationPromise) { - return this.initializationPromise; - } - this.initializationPromise = new Promise(async (resolve) => { - const success = await this.trySetupDb(); - if (!success) { - const startTimeout = () => - setTimeout(async () => { - const success = await this.trySetupDb(); - if (success) { - resolve(); - } else { - startTimeout(); - } - }, 1000); - startTimeout(); - } else { - resolve(); - } - }); - this.initializationPromise.then(() => { - logger.info("DB initialized"); - }); - - return this.initializationPromise; - }; - - addParsedTextToNullMessages = async () => { - const db = this.db; - const messagesWithNullText = await db - .selectFrom("message") - .select(["ROWID", "attributedBody"]) - .where("text", "is", null) - .where("attributedBody", "is not", null) - .execute(); - - if (messagesWithNullText.length) { - const now = performance.now(); - logger.info(`Adding parsed text to ${messagesWithNullText.length} messages`); - - await db.transaction().execute(async (trx) => { - for (const message of messagesWithNullText) { - try { - const { attributedBody, ROWID } = message; - if (attributedBody) { - const parsed = await decodeMessageBuffer(attributedBody); - if (parsed) { - const string = parsed[0]?.value?.string; - if (typeof string === "string") { - await trx.updateTable("message").set({ text: string }).where("ROWID", "=", ROWID).executeTakeFirst(); - } - } - } - } catch { - //skip - } - } - }); - logger.info(`Done adding parsed text to ${messagesWithNullText.length} messages in ${performance.now() - now}ms`); - } - }; - get db() { - if (!this.dbWriter) { - throw new Error("DB not initialized!"); - } - 
return this.dbWriter; - } - - isSettingUpDb = false; - trySetupDb = async () => { - try { - if (this.isSettingUpDb || !(await localDbExists())) { - return false; - } - this.isSettingUpDb = true; - logger.info("Setting up db"); - const sqliteDb = new SqliteDb(this.path); - const dialect = new SqliteDialect({ database: sqliteDb }); - const options: KyselyConfig = { - dialect, - log(event): void { - const isError = event.level === "error"; - - if (isError || debugLoggingEnabled) { - const { sql, parameters } = event.query; - - const { queryDurationMillis } = event; - const duration = queryDurationMillis.toFixed(2); - const params = (parameters as string[]) || []; - const formattedSql = format(sql, { params: params.map((l) => String(l)), language: "sqlite" }); - if (event.level === "query") { - logger.debug(`[Query - ${duration}ms]:\n${formattedSql}\n`); - } - - if (isError) { - logger.error(`[SQL Error - ${duration}ms]: ${event.error}\n\n${formattedSql}\n`); - } - } - }, - }; - - const db = new Kysely(options); - this.dbWriter = db; - // add in text - logger.info("Adding text column to messages"); - await this.addParsedTextToNullMessages(); - logger.info("Adding text column to messages done"); - - // create virtual table if not exists - logger.info("Creating virtual table"); - await sqliteDb.exec("CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5(text,message_id)"); - const count = await db.selectFrom("message_fts").select("message_id").limit(1).executeTakeFirst(); - if (count === undefined) { - await sqliteDb.exec("INSERT INTO message_fts SELECT text, guid as message_id FROM message"); - } - logger.info("Creating virtual table done"); - return true; - } catch (e) { - console.error(e); - return false; - } finally { - this.isSettingUpDb = false; - } - }; +export class SQLDatabase extends BaseDatabase { private getChatsWithMessagesQuery = () => { const db = this.db; return db @@ -243,6 +105,25 @@ export class SQLDatabase { 
.select(sql`message.ROWID`.as("message_id")); }; + getMessageGuidsFromText = async (texts: string[]) => { + const db = this.db; + const query = db.selectFrom("message").select(["guid", "text"]).where("text", "in", texts); + const results = await query.execute(); + const indexMap = new Map(); + for (let i = 0; i < texts.length; i++) { + indexMap.set(texts[i], i); + } + results.sort((a, b) => { + const aIndex = indexMap.get(a.text); + const bIndex = indexMap.get(b.text); + if (aIndex === undefined || bIndex === undefined) { + return 0; + } + return aIndex - bIndex; + }); + return results.map((r) => r.guid); + }; + globalSearchTextBased = async ( searchTerm: string, chatIds?: number[], @@ -697,7 +578,55 @@ export class SQLDatabase { }; } -const db = new SQLDatabase(); +const addParsedTextToNullMessages = async (db: Kysely) => { + const messagesWithNullText = await db + .selectFrom("message") + .select(["ROWID", "attributedBody"]) + .where("text", "is", null) + .where("attributedBody", "is not", null) + .execute(); + + if (messagesWithNullText.length) { + const now = performance.now(); + logger.info(`Adding parsed text to ${messagesWithNullText.length} messages`); + + await db.transaction().execute(async (trx) => { + for (const message of messagesWithNullText) { + try { + const { attributedBody, ROWID } = message; + if (attributedBody) { + const parsed = await decodeMessageBuffer(attributedBody); + if (parsed) { + const string = parsed[0]?.value?.string; + if (typeof string === "string") { + await trx.updateTable("message").set({ text: string }).where("ROWID", "=", ROWID).executeTakeFirst(); + } + } + } + } catch { + //skip + } + } + }); + logger.info(`Done adding parsed text to ${messagesWithNullText.length} messages in ${performance.now() - now}ms`); + } +}; + +const db = new SQLDatabase("Messages DB", appMessagesDbCopy, async (sqliteDb, kdb) => { + // add in text + logger.info("Adding text column to messages"); + await addParsedTextToNullMessages(kdb); + 
logger.info("Adding text column to messages done"); + + // create virtual table if not exists + logger.info("Creating virtual table"); + await sqliteDb.exec("CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5(text,message_id)"); + const count = await kdb.selectFrom("message_fts").select("message_id").limit(1).executeTakeFirst(); + if (count === undefined) { + await sqliteDb.exec("INSERT INTO message_fts SELECT text, guid as message_id FROM message"); + } + logger.info("Creating virtual table done"); +}); // monkey patch to handle ipc calls diff --git a/electron-src/data/db-file-utils.ts b/electron-src/data/db-file-utils.ts index 194b5dd..2eaeba3 100644 --- a/electron-src/data/db-file-utils.ts +++ b/electron-src/data/db-file-utils.ts @@ -1,8 +1,19 @@ import jetpack from "fs-jetpack"; import logger from "../utils/logger"; import { appMessagesDbCopy, messagesDb } from "../utils/constants"; +import childProcess from "child_process"; +const exec = childProcess.exec; export const copyLatestDb = async () => { + logger.info("Killing iMessage if it's running"); + try { + exec("pkill -f Messages"); + // wait 500ms for iMessage to close + await new Promise((resolve) => setTimeout(resolve, 500)); + logger.info("iMessage killed"); + } catch (e) { + logger.info("iMessage was not running"); + } await copyDbAtPath(messagesDb); }; export const copyDbAtPath = async (path: string) => { diff --git a/electron-src/data/embeddings-database.ts b/electron-src/data/embeddings-database.ts new file mode 100644 index 0000000..18d6ef8 --- /dev/null +++ b/electron-src/data/embeddings-database.ts @@ -0,0 +1,92 @@ +import type { DB as EmbeddingsDb } from "../../_generated/embeddings-db"; +import logger from "../utils/logger"; +import { embeddingsDbPath } from "../utils/constants"; +import BaseDatabase from "./base-database"; + +export class EmbeddingsDatabase extends BaseDatabase { + embeddingsCache: { text: string; embedding: Float32Array }[] = []; + countEmbeddings = async (): Promise => 
{ + await this.initialize(); + const result = await this.db + .selectFrom("embeddings") + .select((e) => e.fn.count("embeddings.text").as("count")) + .execute(); + return result[0].count as number; + }; + + getAllEmbeddings = async () => { + if (this.embeddingsCache.length) { + return this.embeddingsCache; + } + await this.initialize(); + const result = await this.db.selectFrom("embeddings").selectAll().execute(); + + this.embeddingsCache = result.map((r) => { + const embedding = r.embedding!; + return { + text: r.text, + embedding: new Float32Array( + embedding.buffer, + embedding.byteOffset, + embedding.byteLength / Float32Array.BYTES_PER_ELEMENT, + ), + }; + }); + return this.embeddingsCache; + }; + getEmbeddingByText = async (text: string) => { + await this.initialize(); + const result = await this.db.selectFrom("embeddings").where("text", "=", text).selectAll().executeTakeFirst(); + if (!result) { + return null; + } + const embedding = result.embedding!; + return { + text: result.text, + embedding: new Float32Array( + embedding.buffer, + embedding.byteOffset, + embedding.byteLength / Float32Array.BYTES_PER_ELEMENT, + ), + }; + }; + + getAllText = async (): Promise => { + await this.initialize(); + const result = await this.db.selectFrom("embeddings").select("text").execute(); + return result.map((l) => l.text!); + }; + + insertEmbeddings = async (embeddings: { text: string; embedding: number[] }[]) => { + await this.initialize(); + const values = embeddings.map((e) => { + const typedBuffer = new Float32Array(e.embedding); + const buffer = Buffer.from(typedBuffer.buffer); + return { + text: e.text, + embedding: buffer, + }; + }); + const insert = this.db + .insertInto("embeddings") + .values(values) + .onConflict((oc) => oc.column("text").doNothing()); + await insert.execute(); + }; +} + +const embeddingsDb = new EmbeddingsDatabase("Embeddings DB", embeddingsDbPath, async (db) => { + // create virtual table if not exists + logger.info("Creating index table"); + 
await db.exec(` +CREATE TABLE if not exists embeddings ( + text TEXT PRIMARY KEY NOT NULL, + embedding BLOB NOT NULL +); +CREATE UNIQUE INDEX if not exists idx_embeddings +ON embeddings (text); + `); + logger.info("Creating index table done"); +}); + +export default embeddingsDb; diff --git a/electron-src/esbuild.main.config.ts b/electron-src/esbuild.main.config.ts index 41f5757..64560c6 100644 --- a/electron-src/esbuild.main.config.ts +++ b/electron-src/esbuild.main.config.ts @@ -5,7 +5,7 @@ const config: BuildOptions = { platform: "node", entryPoints: [ path.resolve("electron-src/index.ts"), - path.resolve("electron-src/data/worker.ts"), + path.resolve("electron-src/workers/worker.ts"), path.resolve("electron-src/utils/preload.ts"), ], bundle: true, diff --git a/electron-src/index.ts b/electron-src/index.ts index 6ce3418..5ca1fd5 100644 --- a/electron-src/index.ts +++ b/electron-src/index.ts @@ -1,9 +1,9 @@ import { app, Menu, nativeTheme, protocol, shell } from "electron"; // Global imports to monkeypatch/polyfill/register -import "./data/semantic-search"; -import "./data/ipc"; -import "./data/options"; -import "./data/ipc-onboarding"; +import "./semantic-search/semantic-search"; +import "./ipc/ipc"; +import "./options"; +import "./ipc/ipc-onboarding"; import "./utils/dns-cache"; // normal imports import type { CustomScheme } from "electron"; @@ -16,10 +16,10 @@ import { getMenu } from "./window/menu"; import "better-sqlite3"; import { logPath, logStream, mainAppIconDevPng } from "./constants"; import logger from "./utils/logger"; -import { setupRouteHandlers } from "./data/routes"; +import { setupRouteHandlers } from "./utils/routes"; import { DESKTOP_VERSION } from "./versions"; import { autoUpdater } from "electron-updater"; -import dbWorker from "./data/database-worker"; +import dbWorker from "./workers/database-worker"; import winston from "winston"; addFlags(app); diff --git a/electron-src/data/ipc-onboarding.ts b/electron-src/ipc/ipc-onboarding.ts 
similarity index 99% rename from electron-src/data/ipc-onboarding.ts rename to electron-src/ipc/ipc-onboarding.ts index 523b3a0..917eef1 100644 --- a/electron-src/data/ipc-onboarding.ts +++ b/electron-src/ipc/ipc-onboarding.ts @@ -2,7 +2,7 @@ import { getAllContacts, getAuthStatus, requestAccess } from "electron-mac-contacts"; import logger from "../utils/logger"; import { askForFullDiskAccess, getAuthStatus as getPermissionsStatus } from "node-electron-permissions"; -import { setSkipContactsPermsCheck, shouldSkipContactsCheck } from "./options"; +import { setSkipContactsPermsCheck, shouldSkipContactsCheck } from "../options"; import { handleIpc } from "./ipc"; import { v4 as uuid } from "uuid"; handleIpc("contacts", async () => { diff --git a/electron-src/data/ipc.ts b/electron-src/ipc/ipc.ts similarity index 97% rename from electron-src/data/ipc.ts rename to electron-src/ipc/ipc.ts index 480304a..8571055 100644 --- a/electron-src/data/ipc.ts +++ b/electron-src/ipc/ipc.ts @@ -1,6 +1,6 @@ import type { IpcMainInvokeEvent } from "electron"; import { dialog, ipcMain, shell } from "electron"; // deconstructing assignment -import type { SQLDatabase } from "./database"; +import type { SQLDatabase } from "../data/database"; import fs from "fs-extra"; import jsonexport from "jsonexport"; import jetpack from "fs-jetpack"; @@ -10,7 +10,7 @@ import { fileTypeFromFile } from "file-type"; import { debugLoggingEnabled } from "../constants"; import logger from "../utils/logger"; import { decodeMessageBuffer } from "../utils/buffer"; -import dbWorker from "./database-worker"; +import dbWorker from "../workers/database-worker"; export const handleIpc = (event: string, handler: (...args: any[]) => unknown) => { ipcMain.handle(event, async (e: IpcMainInvokeEvent, ...args) => { if (debugLoggingEnabled) { @@ -69,7 +69,7 @@ handleIpc( } const messages = await dbWorker.worker.getMessagesForChatId(chat.chat_id!); - type HandleType = typeof handles[number]; + type HandleType = (typeof 
handles)[number]; const handleMap: Record = {}; for (const handle of handles) { handleMap[handle.ROWID!] = handle; diff --git a/electron-src/data/options.ts b/electron-src/options.ts similarity index 100% rename from electron-src/data/options.ts rename to electron-src/options.ts diff --git a/electron-src/data/batch-utils.ts b/electron-src/semantic-search/batch-utils.ts similarity index 60% rename from electron-src/data/batch-utils.ts rename to electron-src/semantic-search/batch-utils.ts index 215e52c..4e2af5f 100644 --- a/electron-src/data/batch-utils.ts +++ b/electron-src/semantic-search/batch-utils.ts @@ -1,4 +1,3 @@ -import type { Collection } from "chromadb"; import logger from "../utils/logger"; import type { OpenAIApi } from "openai"; import type { Chunk, SemanticSearchMetadata, SemanticSearchVector } from "./semantic-search"; @@ -8,7 +7,7 @@ import { pRateLimit } from "p-ratelimit"; export class BatchOpenAi { private openai: OpenAIApi; private batch: PendingVector[] = []; - private batchSize = 250; // create 100 embeddings at a time with the openai api + private batchSize = 500; // create 500 embeddings at a time with the openai api constructor(openai: OpenAIApi) { this.openai = openai; @@ -45,55 +44,6 @@ export class BatchOpenAi { return []; } } -const chromaLimit = pRateLimit({ - concurrency: 8, // no more than 60 running at once -}); -export class BatchChroma { - private collection: Collection; - private batch: SemanticSearchVector[] = []; - private batchSize = 1000; - - flushPromise: Promise | null = null; - constructor(collection: Collection) { - this.collection = collection; - // @ts-ignore - this.collection.api.axios.defaults.maxContentLength = Infinity; - // @ts-ignore - this.collection.api.axios.defaults.maxBodyLength = Infinity; - } - - public async insert(vector: SemanticSearchVector[]) { - const flushPromise = this.flushPromise; - if (flushPromise) { - await flushPromise; - } - this.batch.push(...vector); - if (this.batch.length >= 
this.batchSize) { - this.flushPromise = this.flush(); - await this.flushPromise; - } - } - - public async flush() { - if (this.batch.length === 0) { - return; - } - const batch = this.batch; - this.batch = []; - logger.info(`Inserting ${batch.length} vectors`); - const ids = batch.map((item) => item.id); - const embeddings = batch.map((item) => item.values); - const text = batch.map((item) => item.metadata.text); - const metadata = batch.map((item) => item.metadata); - try { - await chromaLimit(async () => { - await this.collection.add(ids, embeddings, metadata, text); - }); - } catch (e) { - logger.error(e); - } - } -} interface PendingVector { id: string; @@ -110,12 +60,13 @@ const rateLimit = pRateLimit({ rate: 3500, // 3500 calls per minute concurrency: 60, // no more than 60 running at once }); - const embeddingsFromPendingVectors = async (pendingVectors: PendingVector[], openai: OpenAIApi) => { - const input = pendingVectors.map((c) => c.input); + const vectors: SemanticSearchVector[] = []; + let timeout = 10_000; - while (input.length) { + while (pendingVectors.length) { try { + const input = pendingVectors.map((l) => l.input); const { data: embed } = await rateLimit(() => openai.createEmbedding({ input, @@ -123,7 +74,6 @@ const embeddingsFromPendingVectors = async (pendingVectors: PendingVector[], ope }), ); const embeddings = embed.data; - const vectors: SemanticSearchVector[] = []; for (let i = 0; i < embeddings.length; i++) { const embedding = embeddings[i].embedding; if (embedding) { diff --git a/electron-src/data/semantic-search-stats.ts b/electron-src/semantic-search/semantic-search-stats.ts similarity index 73% rename from electron-src/data/semantic-search-stats.ts rename to electron-src/semantic-search/semantic-search-stats.ts index 32ff9a9..0501cd6 100644 --- a/electron-src/data/semantic-search-stats.ts +++ b/electron-src/semantic-search/semantic-search-stats.ts @@ -4,16 +4,18 @@ const tokenizer = new GPT4Tokenizer({ type: "gpt3" }); export const 
getStatsForText = (text: { text: string | null }[]) => { let totalTokens = 0; + const uniqueText = new Set(); for (const line of text) { - if (line.text) { + if (line.text && !uniqueText.has(line.text)) { + uniqueText.add(line.text); const tokens = tokenizer.estimateTokenCount(line.text); totalTokens += tokens; } } - const totalMessages = text.length; - const estimatedTimeRpm = totalMessages / 3500 / 10; // we batch so divide it by 50 - const estimatedTimeTpm = totalTokens / 350000 / 10; // we batch so divide it by 50 + const totalMessages = uniqueText.size; + const estimatedTimeRpm = totalMessages / 3500 / 10; // we batch so divide it by a heuristic i eyeballed + const estimatedTimeTpm = totalTokens / 350000 / 10; const estimatedTime = Math.max(estimatedTimeRpm, estimatedTimeTpm); return { totalMessages, diff --git a/electron-src/data/semantic-search.ts b/electron-src/semantic-search/semantic-search.ts similarity index 59% rename from electron-src/data/semantic-search.ts rename to electron-src/semantic-search/semantic-search.ts index 4b7122d..b6f367a 100644 --- a/electron-src/data/semantic-search.ts +++ b/electron-src/semantic-search/semantic-search.ts @@ -1,12 +1,13 @@ import { GPT4Tokenizer } from "gpt4-tokenizer"; import { Configuration, OpenAIApi } from "openai"; -import dbWorker from "./database-worker"; -import { handleIpc } from "./ipc"; +import dbWorker from "../workers/database-worker"; +import { handleIpc } from "../ipc/ipc"; import logger from "../utils/logger"; -import { chunk } from "lodash-es"; -import { BatchChroma, BatchOpenAi, OPENAI_EMBEDDING_MODEL } from "./batch-utils"; -import { getCollection } from "./chroma"; +import { BatchOpenAi, OPENAI_EMBEDDING_MODEL } from "./batch-utils"; import pMap from "p-map"; +import embeddingsDb from "../data/embeddings-database"; +import cosineSimilarity from "./vector-comparison"; +import { uniqBy } from "lodash-es"; export interface SemanticSearchMetadata { id: string; @@ -81,31 +82,14 @@ export const 
createEmbeddings = async ({ openAiKey }: { openAiKey: string }) => }); const openai = new OpenAIApi(configuration); - const collection = await getCollection(); + await embeddingsDb.initialize(); - if (!collection) { - logger.error("Could not get collection"); - return; - } - - // remove already parsed messages - const chunked = chunk(messages, 5000); - const notParsed = []; - logger.info(`Checking if ${messages.length} messages have been parsed already`); + const existingText = await embeddingsDb.getAllText(); + const set = new Set(existingText); + numCompleted = existingText.length; + const notParsed = messages.filter((m) => m.text && !set.has(m.text)); - for (const chunk of chunked) { - const parsed = await collection.get(chunk.map((m) => `${m.guid}:0`)); - const parsedIds = new Set(parsed.ids.map((m: string) => m.split(":")[0])); - numCompleted += parsedIds.size; - for (const m of chunk) { - if (!parsedIds.has(m.guid)) { - notParsed.push(m); - } - } - } - logger.info(`Found ${notParsed.length} messages that need to be parsed`); - - const batchChroma = new BatchChroma(collection); + const uniqueMessages = uniqBy(notParsed, "text"); const batchOpenai = new BatchOpenAi(openai); const processMessage = async (message: (typeof messages)[number]) => { @@ -117,9 +101,16 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) => const chunks = splitIntoChunks(message.text); const itemEmbeddings = await batchOpenai.addPendingVectors(chunks, message.guid); - if (itemEmbeddings && itemEmbeddings.length) { - await batchChroma.insert(itemEmbeddings); - numCompleted += itemEmbeddings.length; + if (itemEmbeddings.length) { + try { + logger.info(`Inserting ${itemEmbeddings.length} vectors`); + const embeddings = itemEmbeddings.map((l) => ({ embedding: l.values, text: l.metadata.text })); + await embeddingsDb.insertEmbeddings(embeddings); + logger.info(`Inserted ${itemEmbeddings.length} vectors`); + numCompleted += itemEmbeddings.length; + } catch (e) { + 
logger.error(e); + } } if (debugLoggingEnabled) { @@ -133,8 +124,7 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) => return []; }; - await pMap(notParsed, processMessage, { concurrency: 100 }); - await batchChroma.flush(); + await pMap(uniqueMessages, processMessage, { concurrency: 100 }); logger.info("Done creating embeddings"); }; @@ -144,27 +134,36 @@ interface SemanticQueryOpts { } export async function semanticQuery({ queryText, openAiKey }: SemanticQueryOpts) { - const configuration = new Configuration({ - apiKey: openAiKey, - }); - const openai = new OpenAIApi(configuration); - const collection = await getCollection(); - if (!collection) { - logger.error("Could not get collection"); - return; - } - const embed = ( - await openai.createEmbedding({ + const existingEmbedding = await embeddingsDb.getEmbeddingByText(queryText); + let floatEmbedding = existingEmbedding?.embedding; + + if (!existingEmbedding) { + const configuration = new Configuration({ + apiKey: openAiKey, + }); + // first look up embedding in db in case we've already done it + const openai = new OpenAIApi(configuration); + const openAiResponse = await openai.createEmbedding({ input: queryText, model: OPENAI_EMBEDDING_MODEL, - }) - ).data; - const embedding = embed.data?.[0]?.embedding; - if (!embedding) { - return []; + }); + const embed = openAiResponse.data; + const embedding = embed.data?.[0]?.embedding; + if (!embedding) { + return []; + } + // save embedding + await embeddingsDb.insertEmbeddings([{ embedding, text: queryText }]); + floatEmbedding = new Float32Array(embedding); } - const results = await collection.query(embedding, 100, undefined, [queryText]); - return results.ids[0].map((id: string) => id.split(":")[0]); + + const allEmbeddings = await embeddingsDb.getAllEmbeddings(); + const similarities = allEmbeddings.map((e) => { + const similarity = cosineSimilarity(floatEmbedding!, e.embedding); + return { similarity, text: e.text }; + }); + 
similarities.sort((a, b) => b.similarity - a.similarity); + return similarities.slice(0, 100).map((l) => l.text!); } handleIpc("createEmbeddings", async ({ openAiKey: openAiKey }) => { @@ -179,12 +178,15 @@ handleIpc("getEmbeddingsCompleted", async () => { handleIpc("calculateSemanticSearchStatsEnhanced", async () => { const stats = await dbWorker.worker.calculateSemanticSearchStats(); - const collection = await getCollection(); - if (!collection) { + const localDb = embeddingsDb; + try { + await localDb.initialize(); + const count = await localDb.countEmbeddings(); + return { ...stats, completedAlready: count }; + } catch (e) { + logger.error(e); return stats; } - const count = await collection.count(); - return { ...stats, completedAlready: count }; }); handleIpc( "globalSearch", @@ -202,14 +204,16 @@ handleIpc( } if (useSemanticSearch) { logger.info("Using semantic search"); - const ids = await semanticQuery({ + const messageTexts = await semanticQuery({ openAiKey, queryText: searchTerm, }); - logger.info(`Got ${ids.length} results`); + logger.info(`Got ${messageTexts.length} results`); + + const guids = await dbWorker.worker.getMessageGuidsFromText(messageTexts); return await dbWorker.worker.fullTextMessageSearchWithGuids( - ids, + guids, searchTerm, chatIds, handleIds, diff --git a/electron-src/semantic-search/vector-comparison.ts b/electron-src/semantic-search/vector-comparison.ts new file mode 100644 index 0000000..5389768 --- /dev/null +++ b/electron-src/semantic-search/vector-comparison.ts @@ -0,0 +1,45 @@ +type ArrayLike = Float32Array; + +function l2norm(arr: ArrayLike) { + const len = arr.length; + let t = 0; + let s = 1; + let r; + let val; + let abs; + let i; + + for (i = 0; i < len; i++) { + val = arr[i]; + abs = val < 0 ? 
-val : val; + if (abs > 0) { + if (abs > t) { + r = t / val; + s = 1 + s * r * r; + t = abs; + } else { + r = val / t; + s = s + r * r; + } + } + } + return t * Math.sqrt(s); +} +function dot(x: ArrayLike, y: ArrayLike) { + const len = x.length; + let sum = 0; + let i; + + for (i = 0; i < len; i++) { + sum += x[i] * y[i]; + } + return sum; +} +function cosineSimilarity(x: ArrayLike, y: ArrayLike): number { + const a = dot(x, y); + const b = l2norm(x); + const c = l2norm(y); + return a / (b * c); +} + +export default cosineSimilarity; diff --git a/electron-src/utils/constants.ts b/electron-src/utils/constants.ts index 7b2829d..6331020 100644 --- a/electron-src/utils/constants.ts +++ b/electron-src/utils/constants.ts @@ -4,3 +4,4 @@ import { appPath } from "../versions"; export const messagesDb = os.homedir() + "/Library/Messages/chat.db"; export const appMessagesDbCopy = path.join(os.homedir(), "Library", "Application Support", appPath, "db.sqlite"); +export const embeddingsDbPath = path.join(os.homedir(), "Library", "Application Support", appPath, "embeddings.sqlite"); diff --git a/electron-src/data/routes.ts b/electron-src/utils/routes.ts similarity index 100% rename from electron-src/data/routes.ts rename to electron-src/utils/routes.ts diff --git a/electron-src/data/text.ts b/electron-src/utils/text.ts similarity index 100% rename from electron-src/data/text.ts rename to electron-src/utils/text.ts diff --git a/electron-src/window/main-window.ts b/electron-src/window/main-window.ts index 9a995e2..5708a37 100644 --- a/electron-src/window/main-window.ts +++ b/electron-src/window/main-window.ts @@ -10,7 +10,7 @@ import { showErrorAlert, withRetries } from "../utils/util"; import prepareNext from "../utils/next-helper"; import logger from "../utils/logger"; import { windows } from "../index"; -import { addWebRequestToSession } from "../data/routes"; +import { addWebRequestToSession } from "../utils/routes"; const setupNext = async () => { try { diff --git 
a/electron-src/window/menu.ts b/electron-src/window/menu.ts index 9fcb36e..0d0847d 100644 --- a/electron-src/window/menu.ts +++ b/electron-src/window/menu.ts @@ -3,8 +3,8 @@ import type { MenuItemConstructorOptions } from "electron"; import { app, dialog, Menu, shell } from "electron"; import { windows } from "../index"; import { showApp } from "../utils/util"; -import { requestContactsPerms, requestFullDiskAccess } from "../data/ipc-onboarding"; -import { clearSkipContactsPermsCheck } from "../data/options"; +import { requestContactsPerms, requestFullDiskAccess } from "../ipc/ipc-onboarding"; +import { clearSkipContactsPermsCheck } from "../options"; import { copyDbAtPath, copyLatestDb } from "../data/db-file-utils"; import { logPath } from "../constants"; diff --git a/electron-src/data/database-worker.ts b/electron-src/workers/database-worker.ts similarity index 84% rename from electron-src/data/database-worker.ts rename to electron-src/workers/database-worker.ts index 512ba73..9e05c65 100644 --- a/electron-src/data/database-worker.ts +++ b/electron-src/workers/database-worker.ts @@ -1,8 +1,8 @@ import { spawn, Worker } from "threads"; -import type { SQLDatabase } from "./database"; -import { handleIpc } from "./ipc"; -import db from "./database"; -import { copyLatestDb, localDbExists } from "./db-file-utils"; +import type { SQLDatabase } from "../data/database"; +import { handleIpc } from "../ipc/ipc"; +import db from "../data/database"; +import { copyLatestDb, localDbExists } from "../data/db-file-utils"; import isDev from "electron-is-dev"; import { join } from "path"; import logger from "../utils/logger"; @@ -14,7 +14,7 @@ class DbWorker { worker!: WorkerType | SQLDatabase; startWorker = async () => { - const path = isDev ? "data/worker.js" : join("..", "..", "..", "app.asar.unpacked", "worker.js"); + const path = isDev ? "workers/worker.js" : join("..", "..", "..", "app.asar.unpacked", "worker.js"); this.worker = isDev ? 
db diff --git a/electron-src/data/worker.ts b/electron-src/workers/worker.ts similarity index 78% rename from electron-src/data/worker.ts rename to electron-src/workers/worker.ts index 982dde2..5333724 100644 --- a/electron-src/data/worker.ts +++ b/electron-src/workers/worker.ts @@ -1,5 +1,5 @@ -import type { SQLDatabase } from "./database"; -import db from "./database"; +import type { SQLDatabase } from "../data/database"; +import db from "../data/database"; import { expose } from "threads/worker"; const exposed: Partial, any>> = {}; diff --git a/package.json b/package.json index 087b392..efa328c 100644 --- a/package.json +++ b/package.json @@ -51,7 +51,6 @@ "axios": "^1.3.5", "bplist-universal": "^1.1.0", "chart.js": "^4.2.1", - "chromadb": "^1.4.0", "date-fns": "^2.29.3", "dayjs": "^1.11.7", "electron": "^24", @@ -154,13 +153,13 @@ "extraResources": [ "assets/**/*", { - "from": "build/electron-src/data/worker.js", + "from": "build/electron-src/workers/worker.js", "to": "app.asar.unpacked/worker.js" } ], "mac": { "binaries": [ - "build/electron-src/data/worker.js" + "build/electron-src/workers/worker.js" ], "target": { "target": "default", diff --git a/scripts/gen-types.ts b/scripts/gen-types.ts index 8705db4..af6c284 100644 --- a/scripts/gen-types.ts +++ b/scripts/gen-types.ts @@ -21,8 +21,10 @@ const prettierOptions: prettier.Options = { const libDir = path.join(os.homedir(), "Library"); const dbDir = path.join(libDir, "Application Support", appPath, "db.sqlite"); +const embeddingDbDir = path.join(libDir, "Application Support", appPath, "embeddings.sqlite"); const filename = "types.d.ts"; +const filenameEmbedding = "embeddings-db.d.ts"; if (!(await jetpack.existsAsync(dbDir))) { await jetpack.copy(path.join(libDir, "/Messages/chat.db"), dbDir, { overwrite: false }); @@ -30,7 +32,7 @@ if (!(await jetpack.existsAsync(dbDir))) { console.log(dbDir); const run = async () => { if (!(await fs.pathExists(dbDir))) { - console.log("Database does not exist - make sure you 
run the recorder at least once before running this script."); + console.log("Database does not exist - make sure you run the app at least once before running this script."); return; } const sqliteDb = new SqliteDb(dbDir); @@ -46,12 +48,24 @@ const run = async () => { console.log(out.stderr); const typeStr = await fs.readFile("node_modules/kysely-codegen/dist/db.d.ts", "utf8"); await fs.writeFile(path.join(dir, filename), prettier.format(typeStr, prettierOptions)); + + try { + const embeddings = await execa(`DATABASE_URL="${embeddingDbDir}" yarn kysely-codegen`, { shell: true }); + console.log(embeddings.stdout); + console.log(embeddings.stderr); + const typeStrEmbedding = await fs.readFile("node_modules/kysely-codegen/dist/db.d.ts", "utf8"); + await fs.writeFile(path.join(dir, filenameEmbedding), prettier.format(typeStrEmbedding, prettierOptions)); + } catch (e) { + console.log(e); + } }; try { await run(); } catch (e) { + console.log("Rebuilding binaries for arch..."); await execa(`npm rebuild better-sqlite3 --update-binary`, { shell: true }); + console.log("Done rebuilding binaries"); await run(); } process.exit(0); diff --git a/src/components/chat/OpenAiKey.tsx b/src/components/chat/OpenAiKey.tsx index a155b6e..76cd4c8 100644 --- a/src/components/chat/OpenAiKey.tsx +++ b/src/components/chat/OpenAiKey.tsx @@ -119,15 +119,6 @@ export const SemanticSearchInfo = () => { You can use AI to search through your messages. To enable this feature, please enter your OpenAI API key below. Note: this will take a long time and might cost you a bit. The estimates are below. - - Important! Chroma must be running locally for this to work, on port 8000! - - - Learn more - {isLoading && } {hasProgressInEmbeddings && data ?
( @@ -151,7 +142,7 @@ export const SemanticSearchInfo = () => { {data && ( - Total Messages: {data.totalMessages.toLocaleString()} + Total Unique Messages: {data.totalMessages.toLocaleString()} Total Tokens: {data.totalTokens.toLocaleString()} Avg Tokens / msg: {data.averageTokensPerLine.toLocaleString()} diff --git a/src/pages/_app.tsx b/src/pages/_app.tsx index b5ef6a9..9a0318c 100644 --- a/src/pages/_app.tsx +++ b/src/pages/_app.tsx @@ -75,7 +75,7 @@ const Initializing = () => { }} > {"Setup"} - Initializing database... + Initializing database... diff --git a/yarn.lock b/yarn.lock index 0c1ad04..422f0fd 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1712,13 +1712,6 @@ chownr@^2.0.0: resolved "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz#15bfbe53d2eab4cf70f18a8cd68ebe5b3cb1dece" integrity sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ== -chromadb@^1.4.0: - version "1.4.0" - resolved "https://registry.npmjs.org/chromadb/-/chromadb-1.4.0.tgz#2cd3c9ca3f38327743b0b1046e86680f434f6b73" - integrity sha512-FeIidtLSXhjqapdPmY+46QlB0rzf4ZWNwIyp8ZAkGF7su//78yozafA3ayor7bs6CK3T27C0isNVxZaoq0uQOg== - dependencies: - axios "^0.26.0" - chromium-pickle-js@^0.2.0: version "0.2.0" resolved "https://registry.npmjs.org/chromium-pickle-js/-/chromium-pickle-js-0.2.0.tgz#04a106672c18b085ab774d983dfa3ea138f22205"