diff --git a/electron-src/data/database.ts b/electron-src/data/database.ts index 15d7d69..a109967 100644 --- a/electron-src/data/database.ts +++ b/electron-src/data/database.ts @@ -114,8 +114,8 @@ export class SQLDatabase extends BaseDatabase { indexMap.set(texts[i], i); } results.sort((a, b) => { - const aIndex = indexMap.get(a.text); - const bIndex = indexMap.get(b.text); + const aIndex = indexMap.get(a.text!); + const bIndex = indexMap.get(b.text!); if (aIndex === undefined || bIndex === undefined) { return 0; } diff --git a/electron-src/data/embeddings-database.ts b/electron-src/data/embeddings-database.ts index 18d6ef8..d09aa73 100644 --- a/electron-src/data/embeddings-database.ts +++ b/electron-src/data/embeddings-database.ts @@ -2,6 +2,7 @@ import type { DB as EmbeddingsDb } from "../../_generated/embeddings-db"; import logger from "../utils/logger"; import { embeddingsDbPath } from "../utils/constants"; import BaseDatabase from "./base-database"; +import cosineSimilarity from "../semantic-search/vector-comparison"; export class EmbeddingsDatabase extends BaseDatabase { embeddingsCache: { text: string; embedding: Float32Array }[] = []; @@ -14,7 +15,16 @@ export class EmbeddingsDatabase extends BaseDatabase { return result[0].count as number; }; - getAllEmbeddings = async () => { + calculateSimilarity = async (embedding: Float32Array) => { + const allEmbeddings = await this.getAllEmbeddings(); + const similarities = allEmbeddings.map((e) => { + const similarity = cosineSimilarity(embedding!, e.embedding); + return { similarity, text: e.text }; + }); + similarities.sort((a, b) => b.similarity - a.similarity); + return similarities.slice(0, 100).map((l) => l.text!); + }; + loadVectorsIntoMemory = async () => { if (this.embeddingsCache.length) { return this.embeddingsCache; } @@ -32,6 +42,9 @@ export class EmbeddingsDatabase extends BaseDatabase { ), }; }); + }; + getAllEmbeddings = async () => { + await this.loadVectorsIntoMemory(); return this.embeddingsCache; }; getEmbeddingByText = async (text: string) => { diff --git a/electron-src/esbuild.main.config.ts b/electron-src/esbuild.main.config.ts index 64560c6..771c539 100644 --- a/electron-src/esbuild.main.config.ts +++ b/electron-src/esbuild.main.config.ts @@ -6,6 +6,7 @@ const config: BuildOptions = { entryPoints: [ path.resolve("electron-src/index.ts"), path.resolve("electron-src/workers/worker.ts"), + path.resolve("electron-src/workers/embeddings-worker.ts"), path.resolve("electron-src/utils/preload.ts"), ], bundle: true, diff --git a/electron-src/semantic-search/semantic-search.ts b/electron-src/semantic-search/semantic-search.ts index b6f367a..ba9da85 100644 --- a/electron-src/semantic-search/semantic-search.ts +++ b/electron-src/semantic-search/semantic-search.ts @@ -5,8 +5,6 @@ import { handleIpc } from "../ipc/ipc"; import logger from "../utils/logger"; import { BatchOpenAi, OPENAI_EMBEDDING_MODEL } from "./batch-utils"; import pMap from "p-map"; -import embeddingsDb from "../data/embeddings-database"; -import cosineSimilarity from "./vector-comparison"; import { uniqBy } from "lodash-es"; export interface SemanticSearchMetadata { @@ -82,9 +80,9 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) => }); const openai = new OpenAIApi(configuration); - await embeddingsDb.initialize(); + await dbWorker.embeddingsWorker.initialize(); - const existingText = await embeddingsDb.getAllText(); + const existingText = await dbWorker.embeddingsWorker.getAllText(); const set = new Set(existingText); numCompleted = existingText.length; const notParsed = messages.filter((m) => m.text && !set.has(m.text)); @@ -105,7 +103,7 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) => try { logger.info(`Inserting ${itemEmbeddings.length} vectors`); const embeddings = itemEmbeddings.map((l) => ({ embedding: l.values, text: l.metadata.text })); - await embeddingsDb.insertEmbeddings(embeddings); + await dbWorker.embeddingsWorker.insertEmbeddings(embeddings); logger.info(`Inserted ${itemEmbeddings.length} vectors`); numCompleted += itemEmbeddings.length; } catch (e) { @@ -134,7 +132,7 @@ interface SemanticQueryOpts { } export async function semanticQuery({ queryText, openAiKey }: SemanticQueryOpts) { - const existingEmbedding = await embeddingsDb.getEmbeddingByText(queryText); + const existingEmbedding = await dbWorker.embeddingsWorker.getEmbeddingByText(queryText); let floatEmbedding = existingEmbedding?.embedding; if (!existingEmbedding) { @@ -153,17 +151,11 @@ export async function semanticQuery({ queryText, openAiKey }: SemanticQueryOpts) return []; } // save embedding - await embeddingsDb.insertEmbeddings([{ embedding, text: queryText }]); + await dbWorker.embeddingsWorker.insertEmbeddings([{ embedding, text: queryText }]); floatEmbedding = new Float32Array(embedding); } - const allEmbeddings = await embeddingsDb.getAllEmbeddings(); - const similarities = allEmbeddings.map((e) => { - const similarity = cosineSimilarity(floatEmbedding!, e.embedding); - return { similarity, text: e.text }; - }); - similarities.sort((a, b) => b.similarity - a.similarity); - return similarities.slice(0, 100).map((l) => l.text!); + return dbWorker.embeddingsWorker.calculateSimilarity(floatEmbedding!); } handleIpc("createEmbeddings", async ({ openAiKey: openAiKey }) => { @@ -178,7 +170,7 @@ handleIpc("getEmbeddingsCompleted", async () => { handleIpc("calculateSemanticSearchStatsEnhanced", async () => { const stats = await dbWorker.worker.calculateSemanticSearchStats(); - const localDb = embeddingsDb; + const localDb = dbWorker.embeddingsWorker; try { await localDb.initialize(); const count = await localDb.countEmbeddings(); diff --git a/electron-src/semantic-search/vector-comparison.ts b/electron-src/semantic-search/vector-comparison.ts index 5389768..8ebcdd6 100644 --- a/electron-src/semantic-search/vector-comparison.ts +++ b/electron-src/semantic-search/vector-comparison.ts @@ -35,11 +35,27 @@ function dot(x: ArrayLike, y: ArrayLike) { } return sum; } -function cosineSimilarity(x: ArrayLike, y: ArrayLike): number { - const a = dot(x, y); - const b = l2norm(x); - const c = l2norm(y); - return a / (b * c); -} +// function cosineSimilarity(x: ArrayLike, y: ArrayLike): number { +// const a = dot(x, y); +// const b = l2norm(x); +// const c = l2norm(y); +// return a / (b * c); +// } +export default function cosineSimilarity(vectorA: Float32Array, vectorB: Float32Array) { + const dimensionality = Math.min(vectorA.length, vectorB.length); + let dotAB = 0; + let dotA = 0; + let dotB = 0; + let dimension = 0; + while (dimension < dimensionality) { + const componentA = vectorA[dimension]; + const componentB = vectorB[dimension]; + dotAB += componentA * componentB; + dotA += componentA * componentA; + dotB += componentB * componentB; + dimension += 1; + } -export default cosineSimilarity; + const magnitude = Math.sqrt(dotA * dotB); + return magnitude === 0 ? 0 : dotAB / magnitude; +} diff --git a/electron-src/workers/database-worker.ts b/electron-src/workers/database-worker.ts index 9e05c65..ba36640 100644 --- a/electron-src/workers/database-worker.ts +++ b/electron-src/workers/database-worker.ts @@ -6,15 +6,22 @@ import { copyLatestDb, localDbExists } from "../data/db-file-utils"; import isDev from "electron-is-dev"; import { join } from "path"; import logger from "../utils/logger"; +import type { EmbeddingsDatabase } from "../data/embeddings-database"; +import embeddingsDb from "../data/embeddings-database"; + type WorkerType = { [P in keyof T]: T[P] extends (...args: infer A) => infer R ? (...args: A) => Promise : never; }; class DbWorker { worker!: WorkerType | SQLDatabase; + embeddingsWorker!: WorkerType | EmbeddingsDatabase; startWorker = async () => { const path = isDev ? "workers/worker.js" : join("..", "..", "..", "app.asar.unpacked", "worker.js"); + const embeddingWorkerPath = isDev + ? "workers/embeddings-worker.js" + : join("..", "..", "..", "app.asar.unpacked", "embeddings-worker.js"); this.worker = isDev ? db @@ -23,6 +30,13 @@ class DbWorker { resourceLimits: { maxOldGenerationSizeMb: 32678, maxYoungGenerationSizeMb: 32678 }, }), ); + this.embeddingsWorker = isDev + ? embeddingsDb + : await spawn>( + new Worker(embeddingWorkerPath, { + resourceLimits: { maxOldGenerationSizeMb: 32678, maxYoungGenerationSizeMb: 32678 }, + }), + ); }; setupHandlers() { @@ -40,6 +54,8 @@ class DbWorker { handleIpc(prop, dbElement as any); } } + + handleIpc("loadVectorsIntoMemory", this.embeddingsWorker.loadVectorsIntoMemory); } isCopying = false; doesLocalDbCopyExist = async () => { diff --git a/electron-src/workers/embeddings-worker.ts b/electron-src/workers/embeddings-worker.ts new file mode 100644 index 0000000..56fba4e --- /dev/null +++ b/electron-src/workers/embeddings-worker.ts @@ -0,0 +1,14 @@ +import type { EmbeddingsDatabase } from "../data/embeddings-database"; +import { expose } from "threads/worker"; +import embeddingsDb from "../data/embeddings-database"; + +const exposed: Partial, any>> = {}; +for (const property in embeddingsDb) { + const prop = property as keyof EmbeddingsDatabase; + const dbElement = embeddingsDb[prop]; + if (typeof dbElement === "function") { + exposed[prop] = dbElement; + } +} + +expose(exposed); diff --git a/package.json b/package.json index efa328c..3a80307 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,7 @@ "author": "JonLuca DeCaro ", "main": "build/electron-src/index.js", "name": "MiMessage", - "version": "1.0.10", + "version": "1.0.11", "productName": "Mimessage", "description": "Apple Messages UI alternative, with export, search, and more.", "scripts": { @@ -155,11 +155,16 @@ { "from": "build/electron-src/workers/worker.js", "to": "app.asar.unpacked/worker.js" + }, + { + "from": "build/electron-src/workers/embeddings-worker.js", + "to": "app.asar.unpacked/embeddings-worker.js" } ], "mac": { "binaries": [ - "build/electron-src/workers/worker.js" + "build/electron-src/workers/worker.js", + "build/electron-src/workers/embeddings-worker.js" ], "target": { "target": "default", diff --git a/src/components/global-search/GlobalSearch.tsx b/src/components/global-search/GlobalSearch.tsx index 12b23dd..3164721 100644 --- a/src/components/global-search/GlobalSearch.tsx +++ b/src/components/global-search/GlobalSearch.tsx @@ -8,8 +8,9 @@ import { useGlobalSearch, useGroupChatList, useHandleMap, + useLoadSemanticResultsIntoMemory, } from "../../hooks/dataHooks"; -import { LinearProgress } from "@mui/material"; +import { CircularProgress, LinearProgress } from "@mui/material"; import { Button, Checkbox, FormControlLabel, FormGroup } from "@mui/material"; import { Virtuoso } from "react-virtuoso"; @@ -29,6 +30,7 @@ import { DayPicker } from "react-day-picker"; import Popover from "@mui/material/Popover"; import { shallow } from "zustand/shallow"; import { SemanticSearchInfo } from "../chat/OpenAiKey"; +import Backdrop from "@mui/material/Backdrop"; const GloablSearchInput = () => { const globalSearch = useMimessage((state) => state.globalSearch); @@ -190,28 +192,69 @@ const GroupChatFilter = () => { ); }; +const ToggleSemanticSearch = () => { + const { mutateAsync, isLoading } = useLoadSemanticResultsIntoMemory(); + const { openAiKey, setUseSemanticSearch, useSemanticSearch } = useMimessage( + (state) => ({ + useSemanticSearch: state.useSemanticSearch, + setUseSemanticSearch: state.setUseSemanticSearch, + openAiKey: state.openAiKey, + }), + shallow, + ); + return ( + <> + {isLoading && ( + + e.stopPropagation()} + sx={{ background: "#2c2c2c", maxWidth: 600, p: 2, m: 2 }} + display={"flex"} + flexDirection={"column"} + > + + Loading Vectors into Memory + + + This takes ~2s per 100k messages + + + + + )} + + { + setUseSemanticSearch(!useSemanticSearch); + mutateAsync(); + }} + disabled={!openAiKey} + title={openAiKey ? "" : "OpenAI Key Required"} + /> + } + label="Use Semantic Search" + /> + + + ); +}; + const GlobalSearchFilter = () => { const { data: results } = useGlobalSearch(); const count = results?.length || 0; - const { - openAiKey, - setUseSemanticSearch, - useSemanticSearch, - startDate, - setStartDate, - setEndDate, - endDate, - globalSearch, - } = useMimessage( + const { startDate, setStartDate, setEndDate, endDate, globalSearch } = useMimessage( (state) => ({ startDate: state.startDate, endDate: state.endDate, setStartDate: state.setStartDate, setEndDate: state.setEndDate, globalSearch: state.globalSearch, - useSemanticSearch: state.useSemanticSearch, - setUseSemanticSearch: state.setUseSemanticSearch, - openAiKey: state.openAiKey, }), shallow, ); @@ -237,22 +280,7 @@ const GlobalSearchFilter = () => { - - setUseSemanticSearch(!useSemanticSearch)} - disabled={!openAiKey} - title={openAiKey ? "" : "OpenAI Key Required"} - /> - } - label="Use Semantic Search" - /> - + ); }; diff --git a/src/hooks/dataHooks.ts b/src/hooks/dataHooks.ts index fe577da..1199fb5 100644 --- a/src/hooks/dataHooks.ts +++ b/src/hooks/dataHooks.ts @@ -396,7 +396,6 @@ export const useGlobalSearch = () => { }), shallow, ); - const handleMap = useContactToHandleMap(); return useQuery( [ @@ -612,6 +611,12 @@ export const useOpenFileAtLocation = () => { }); }; +export const useLoadSemanticResultsIntoMemory = () => { + return useMutation(["loadVectorsIntoMemory"], async () => { + await ipcRenderer.invoke("loadVectorsIntoMemory"); + }); +}; + export const useAccessibilityPermissionsCheck = () => { return useQuery( ["accessibility-permissions"],