Skip to content

Commit

Permalink
entirely local embeddings logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jonluca committed Apr 19, 2023
1 parent 25d00b0 commit 8964d19
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 58 deletions.
4 changes: 2 additions & 2 deletions electron-src/data/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ export class SQLDatabase extends BaseDatabase<MesssagesDatabase> {
indexMap.set(texts[i], i);
}
results.sort((a, b) => {
const aIndex = indexMap.get(a.text);
const bIndex = indexMap.get(b.text);
const aIndex = indexMap.get(a.text!);
const bIndex = indexMap.get(b.text!);
if (aIndex === undefined || bIndex === undefined) {
return 0;
}
Expand Down
15 changes: 14 additions & 1 deletion electron-src/data/embeddings-database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import type { DB as EmbeddingsDb } from "../../_generated/embeddings-db";
import logger from "../utils/logger";
import { embeddingsDbPath } from "../utils/constants";
import BaseDatabase from "./base-database";
import cosineSimilarity from "../semantic-search/vector-comparison";

export class EmbeddingsDatabase extends BaseDatabase<EmbeddingsDb> {
embeddingsCache: { text: string; embedding: Float32Array }[] = [];
Expand All @@ -14,7 +15,16 @@ export class EmbeddingsDatabase extends BaseDatabase<EmbeddingsDb> {
return result[0].count as number;
};

getAllEmbeddings = async () => {
/**
 * Scores every cached embedding against `embedding` with cosine similarity
 * and returns the texts of the best matches, most similar first.
 *
 * @param embedding - Query vector to compare against the cached embeddings.
 * @param limit - Maximum number of texts to return; defaults to 100.
 * @returns Texts of the highest-scoring embeddings, in descending similarity order.
 */
calculateSimilarity = async (embedding: Float32Array, limit = 100) => {
  const allEmbeddings = await this.getAllEmbeddings();
  const scored = allEmbeddings.map(({ text, embedding: candidate }) => ({
    text,
    similarity: cosineSimilarity(embedding, candidate),
  }));
  // Descending order so the closest matches come first.
  scored.sort((a, b) => b.similarity - a.similarity);
  return scored.slice(0, limit).map(({ text }) => text);
};
loadVectorsIntoMemory = async () => {
if (this.embeddingsCache.length) {
return this.embeddingsCache;
}
Expand All @@ -32,6 +42,9 @@ export class EmbeddingsDatabase extends BaseDatabase<EmbeddingsDb> {
),
};
});
};
/**
 * Awaits `loadVectorsIntoMemory` and then returns the in-memory embeddings cache.
 */
getAllEmbeddings = async () => {
  await this.loadVectorsIntoMemory();
  const { embeddingsCache } = this;
  return embeddingsCache;
};
getEmbeddingByText = async (text: string) => {
Expand Down
1 change: 1 addition & 0 deletions electron-src/esbuild.main.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ const config: BuildOptions = {
entryPoints: [
path.resolve("electron-src/index.ts"),
path.resolve("electron-src/workers/worker.ts"),
path.resolve("electron-src/workers/embeddings-worker.ts"),
path.resolve("electron-src/utils/preload.ts"),
],
bundle: true,
Expand Down
22 changes: 7 additions & 15 deletions electron-src/semantic-search/semantic-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import { handleIpc } from "../ipc/ipc";
import logger from "../utils/logger";
import { BatchOpenAi, OPENAI_EMBEDDING_MODEL } from "./batch-utils";
import pMap from "p-map";
import embeddingsDb from "../data/embeddings-database";
import cosineSimilarity from "./vector-comparison";
import { uniqBy } from "lodash-es";

export interface SemanticSearchMetadata {
Expand Down Expand Up @@ -82,9 +80,9 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) =>
});
const openai = new OpenAIApi(configuration);

await embeddingsDb.initialize();
await dbWorker.embeddingsWorker.initialize();

const existingText = await embeddingsDb.getAllText();
const existingText = await dbWorker.embeddingsWorker.getAllText();
const set = new Set(existingText);
numCompleted = existingText.length;
const notParsed = messages.filter((m) => m.text && !set.has(m.text));
Expand All @@ -105,7 +103,7 @@ export const createEmbeddings = async ({ openAiKey }: { openAiKey: string }) =>
try {
logger.info(`Inserting ${itemEmbeddings.length} vectors`);
const embeddings = itemEmbeddings.map((l) => ({ embedding: l.values, text: l.metadata.text }));
await embeddingsDb.insertEmbeddings(embeddings);
await dbWorker.embeddingsWorker.insertEmbeddings(embeddings);
logger.info(`Inserted ${itemEmbeddings.length} vectors`);
numCompleted += itemEmbeddings.length;
} catch (e) {
Expand Down Expand Up @@ -134,7 +132,7 @@ interface SemanticQueryOpts {
}

export async function semanticQuery({ queryText, openAiKey }: SemanticQueryOpts) {
const existingEmbedding = await embeddingsDb.getEmbeddingByText(queryText);
const existingEmbedding = await dbWorker.embeddingsWorker.getEmbeddingByText(queryText);
let floatEmbedding = existingEmbedding?.embedding;

if (!existingEmbedding) {
Expand All @@ -153,17 +151,11 @@ export async function semanticQuery({ queryText, openAiKey }: SemanticQueryOpts)
return [];
}
// save embedding
await embeddingsDb.insertEmbeddings([{ embedding, text: queryText }]);
await dbWorker.embeddingsWorker.insertEmbeddings([{ embedding, text: queryText }]);
floatEmbedding = new Float32Array(embedding);
}

const allEmbeddings = await embeddingsDb.getAllEmbeddings();
const similarities = allEmbeddings.map((e) => {
const similarity = cosineSimilarity(floatEmbedding!, e.embedding);
return { similarity, text: e.text };
});
similarities.sort((a, b) => b.similarity - a.similarity);
return similarities.slice(0, 100).map((l) => l.text!);
return dbWorker.embeddingsWorker.calculateSimilarity(floatEmbedding!);
}

handleIpc("createEmbeddings", async ({ openAiKey: openAiKey }) => {
Expand All @@ -178,7 +170,7 @@ handleIpc("getEmbeddingsCompleted", async () => {

handleIpc("calculateSemanticSearchStatsEnhanced", async () => {
const stats = await dbWorker.worker.calculateSemanticSearchStats();
const localDb = embeddingsDb;
const localDb = dbWorker.embeddingsWorker;
try {
await localDb.initialize();
const count = await localDb.countEmbeddings();
Expand Down
30 changes: 23 additions & 7 deletions electron-src/semantic-search/vector-comparison.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,27 @@ function dot(x: ArrayLike, y: ArrayLike) {
}
return sum;
}
// Cosine similarity of x and y: the result of `dot(x, y)` divided by the
// product of `l2norm(x)` and `l2norm(y)`.
// NOTE(review): the denominator is not guarded — presumably this yields NaN
// when either vector has zero norm; confirm against `l2norm`.
function cosineSimilarity(x: ArrayLike, y: ArrayLike): number {
const a = dot(x, y);
const b = l2norm(x);
const c = l2norm(y);
return a / (b * c);
}
// function cosineSimilarity(x: ArrayLike, y: ArrayLike): number {
// const a = dot(x, y);
// const b = l2norm(x);
// const c = l2norm(y);
// return a / (b * c);
// }
/**
 * Computes the cosine similarity between two vectors.
 *
 * Only the first `min(vectorA.length, vectorB.length)` components take part,
 * so trailing components of the longer vector are ignored.
 *
 * @param vectorA - First vector.
 * @param vectorB - Second vector.
 * @returns The cosine of the angle between the vectors, or 0 when the combined
 *   magnitude is 0 (for example an all-zero or empty vector).
 */
export default function cosineSimilarity(vectorA: Float32Array, vectorB: Float32Array): number {
  const dimensionality = Math.min(vectorA.length, vectorB.length);
  let dotAB = 0;
  let dotA = 0;
  let dotB = 0;
  // Single pass: accumulate the dot product and both squared norms together.
  for (let dimension = 0; dimension < dimensionality; dimension += 1) {
    const componentA = vectorA[dimension];
    const componentB = vectorB[dimension];
    dotAB += componentA * componentB;
    dotA += componentA * componentA;
    dotB += componentB * componentB;
  }

  // sqrt(|a|^2 * |b|^2) equals |a| * |b|; a zero magnitude would divide by zero.
  const magnitude = Math.sqrt(dotA * dotB);
  return magnitude === 0 ? 0 : dotAB / magnitude;
}
16 changes: 16 additions & 0 deletions electron-src/workers/database-worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@ import { copyLatestDb, localDbExists } from "../data/db-file-utils";
import isDev from "electron-is-dev";
import { join } from "path";
import logger from "../utils/logger";
import type { EmbeddingsDatabase } from "../data/embeddings-database";
import embeddingsDb from "../data/embeddings-database";

/**
 * Maps each method of `T` to the asynchronous signature it has when called
 * through a worker: same parameters, with the result delivered via a Promise.
 *
 * `Awaited<R>` flattens methods that already return a Promise, so an async
 * method is typed as `Promise<X>` rather than `Promise<Promise<X>>`.
 * Members that are not functions map to `never`.
 */
type WorkerType<T> = {
  [P in keyof T]: T[P] extends (...args: infer A) => infer R ? (...args: A) => Promise<Awaited<R>> : never;
};

class DbWorker {
worker!: WorkerType<SQLDatabase> | SQLDatabase;
embeddingsWorker!: WorkerType<EmbeddingsDatabase> | EmbeddingsDatabase;

startWorker = async () => {
const path = isDev ? "workers/worker.js" : join("..", "..", "..", "app.asar.unpacked", "worker.js");
const embeddingWorkerPath = isDev
? "workers/embeddings-worker.js"
: join("..", "..", "..", "app.asar.unpacked", "embeddings-worker.js");

this.worker = isDev
? db
Expand All @@ -23,6 +30,13 @@ class DbWorker {
resourceLimits: { maxOldGenerationSizeMb: 32678, maxYoungGenerationSizeMb: 32678 },
}),
);
this.embeddingsWorker = isDev
? embeddingsDb
: await spawn<WorkerType<EmbeddingsDatabase>>(
new Worker(embeddingWorkerPath, {
resourceLimits: { maxOldGenerationSizeMb: 32678, maxYoungGenerationSizeMb: 32678 },
}),
);
};

setupHandlers() {
Expand All @@ -40,6 +54,8 @@ class DbWorker {
handleIpc(prop, dbElement as any);
}
}

handleIpc("loadVectorsIntoMemory", this.embeddingsWorker.loadVectorsIntoMemory);
}
isCopying = false;
doesLocalDbCopyExist = async () => {
Expand Down
14 changes: 14 additions & 0 deletions electron-src/workers/embeddings-worker.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import type { EmbeddingsDatabase } from "../data/embeddings-database";
import { expose } from "threads/worker";
import embeddingsDb from "../data/embeddings-database";

// Gather every function-valued member of the embeddings database instance and
// hand the resulting map to threads' `expose`.
const exposed: Partial<Record<Partial<keyof EmbeddingsDatabase>, any>> = {};
for (const key in embeddingsDb) {
  const memberName = key as keyof EmbeddingsDatabase;
  const member = embeddingsDb[memberName];
  if (typeof member !== "function") {
    // Only callable members are exposed; everything else is skipped.
    continue;
  }
  exposed[memberName] = member;
}

expose(exposed);
9 changes: 7 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"author": "JonLuca DeCaro <mimessage@jonlu.ca>",
"main": "build/electron-src/index.js",
"name": "MiMessage",
"version": "1.0.10",
"version": "1.0.11",
"productName": "Mimessage",
"description": "Apple Messages UI alternative, with export, search, and more.",
"scripts": {
Expand Down Expand Up @@ -155,11 +155,16 @@
{
"from": "build/electron-src/workers/worker.js",
"to": "app.asar.unpacked/worker.js"
},
{
"from": "build/electron-src/workers/embeddings-worker.js",
"to": "app.asar.unpacked/embeddings-worker.js"
}
],
"mac": {
"binaries": [
"build/electron-src/workers/worker.js"
"build/electron-src/workers/worker.js",
"build/electron-src/workers/embeddings-worker.js"
],
"target": {
"target": "default",
Expand Down
88 changes: 58 additions & 30 deletions src/components/global-search/GlobalSearch.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ import {
useGlobalSearch,
useGroupChatList,
useHandleMap,
useLoadSemanticResultsIntoMemory,
} from "../../hooks/dataHooks";
import { LinearProgress } from "@mui/material";
import { CircularProgress, LinearProgress } from "@mui/material";
import { Button, Checkbox, FormControlLabel, FormGroup } from "@mui/material";

import { Virtuoso } from "react-virtuoso";
Expand All @@ -29,6 +30,7 @@ import { DayPicker } from "react-day-picker";
import Popover from "@mui/material/Popover";
import { shallow } from "zustand/shallow";
import { SemanticSearchInfo } from "../chat/OpenAiKey";
import Backdrop from "@mui/material/Backdrop";

const GloablSearchInput = () => {
const globalSearch = useMimessage((state) => state.globalSearch);
Expand Down Expand Up @@ -190,28 +192,69 @@ const GroupChatFilter = () => {
);
};

/**
 * Checkbox that toggles semantic search in the global search filter bar.
 *
 * Enabling the option asks the main process to load the embedding vectors into
 * memory and shows a blocking backdrop while that request is in flight. The
 * checkbox is disabled when no OpenAI key is configured.
 */
const ToggleSemanticSearch = () => {
  // `mutate` (rather than `mutateAsync`) is used so a failed load is handled by
  // the mutation itself instead of surfacing as an unhandled promise rejection.
  const { mutate, isLoading } = useLoadSemanticResultsIntoMemory();
  const { openAiKey, setUseSemanticSearch, useSemanticSearch } = useMimessage(
    (state) => ({
      useSemanticSearch: state.useSemanticSearch,
      setUseSemanticSearch: state.setUseSemanticSearch,
      openAiKey: state.openAiKey,
    }),
    shallow,
  );

  const handleToggle = () => {
    const enabled = !useSemanticSearch;
    setUseSemanticSearch(enabled);
    // Vectors are only needed when semantic search is being switched on.
    if (enabled) {
      mutate();
    }
  };

  return (
    <>
      {isLoading && (
        <Backdrop open>
          <Box
            onClick={(e) => e.stopPropagation()}
            sx={{ background: "#2c2c2c", maxWidth: 600, p: 2, m: 2 }}
            display={"flex"}
            flexDirection={"column"}
          >
            <Typography variant="h1" sx={{ color: "white" }}>
              Loading Vectors into Memory
            </Typography>
            <Typography variant="h6" sx={{ color: "white" }}>
              This takes ~2s per 100k messages
            </Typography>
            <CircularProgress />
          </Box>
        </Backdrop>
      )}
      <FormGroup>
        <FormControlLabel
          control={
            <Checkbox
              style={{
                color: "white",
              }}
              checked={useSemanticSearch}
              onChange={handleToggle}
              disabled={!openAiKey}
              title={openAiKey ? "" : "OpenAI Key Required"}
            />
          }
          label="Use Semantic Search"
        />
      </FormGroup>
    </>
  );
};

const GlobalSearchFilter = () => {
const { data: results } = useGlobalSearch();
const count = results?.length || 0;
const {
openAiKey,
setUseSemanticSearch,
useSemanticSearch,
startDate,
setStartDate,
setEndDate,
endDate,
globalSearch,
} = useMimessage(
const { startDate, setStartDate, setEndDate, endDate, globalSearch } = useMimessage(
(state) => ({
startDate: state.startDate,
endDate: state.endDate,
setStartDate: state.setStartDate,
setEndDate: state.setEndDate,
globalSearch: state.globalSearch,
useSemanticSearch: state.useSemanticSearch,
setUseSemanticSearch: state.setUseSemanticSearch,
openAiKey: state.openAiKey,
}),
shallow,
);
Expand All @@ -237,22 +280,7 @@ const GlobalSearchFilter = () => {
<GroupChatFilter />
<DateFilter selection={startDate} setSelection={setStartDate} text={"Start Date"} />
<DateFilter selection={endDate} setSelection={setEndDate} text={"End Date"} />
<FormGroup>
<FormControlLabel
control={
<Checkbox
style={{
color: "white",
}}
checked={useSemanticSearch}
onChange={() => setUseSemanticSearch(!useSemanticSearch)}
disabled={!openAiKey}
title={openAiKey ? "" : "OpenAI Key Required"}
/>
}
label="Use Semantic Search"
/>
</FormGroup>
<ToggleSemanticSearch />
</Box>
);
};
Expand Down
7 changes: 6 additions & 1 deletion src/hooks/dataHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,6 @@ export const useGlobalSearch = () => {
}),
shallow,
);

const handleMap = useContactToHandleMap();
return useQuery<GlobalSearchResponse>(
[
Expand Down Expand Up @@ -612,6 +611,12 @@ export const useOpenFileAtLocation = () => {
});
};

/**
 * Mutation hook that invokes the "loadVectorsIntoMemory" IPC channel.
 * The mutation resolves with no value once the invoke call completes.
 */
export const useLoadSemanticResultsIntoMemory = () => {
  const requestLoad = async () => {
    await ipcRenderer.invoke("loadVectorsIntoMemory");
  };
  return useMutation(["loadVectorsIntoMemory"], requestLoad);
};

export const useAccessibilityPermissionsCheck = () => {
return useQuery<boolean | null>(
["accessibility-permissions"],
Expand Down

0 comments on commit 8964d19

Please sign in to comment.