Skip to content

Commit

Permalink
fully working local sqlite semantic search db
Browse files Browse the repository at this point in the history
  • Loading branch information
jonluca committed Apr 19, 2023
1 parent 14ca5cc commit 25d00b0
Show file tree
Hide file tree
Showing 27 changed files with 455 additions and 329 deletions.
8 changes: 8 additions & 0 deletions _generated/embeddings-db.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
/**
 * Column types for the `embeddings` entry of the generated `DB` interface.
 */
export interface Embeddings {
// Text column. Presumably the source text the embedding was computed from — confirm against the code that writes rows.
text: string;
// Binary column holding the embedding. The byte layout is not visible in this file — TODO confirm.
embedding: Buffer;
}

/**
 * Schema map for the embeddings database: each key names an entry and maps it to its column types.
 * Presumably passed as the type argument to `Kysely<...>` — confirm at the usage site.
 */
export interface DB {
embeddings: Embeddings;
}
111 changes: 111 additions & 0 deletions electron-src/data/base-database.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import { Kysely } from "kysely";
import logger from "../utils/logger";
import type { Database } from "better-sqlite3";
import SqliteDb from "better-sqlite3";
import { SqliteDialect } from "kysely";
import type { KyselyConfig } from "kysely/dist/cjs/kysely";
import { format } from "sql-formatter";

/**
 * Optional hook that `BaseDatabase.trySetupDb` awaits after the better-sqlite3 handle and the
 * Kysely instance have been created, and before the Kysely instance is stored as `dbWriter`.
 * `db` is the raw better-sqlite3 handle; `ks` is the typed Kysely instance built on top of it.
 */
type PostSetupCallback<T> = (db: Database, ks: Kysely<T>) => Promise<void>;
// When the DEBUG_LOGGING environment variable is exactly "true", the Kysely log hook also logs
// successful queries; otherwise only error events are logged.
const debugLoggingEnabled = process.env.DEBUG_LOGGING === "true";
/**
 * Lazily initialised SQLite database exposed through Kysely.
 *
 * Construct it with a display name (used in log and error messages), the path of the SQLite
 * file, and an optional `postSetup` hook. Call `initialize()` once and await the returned
 * promise; afterwards the typed Kysely instance is available through the `db` getter.
 */
export class BaseDatabase<T> {
  /** Delay, in milliseconds, between failed setup attempts made by `initialize`. */
  private static readonly retryDelayMs = 1000;

  path: string;
  name: string;
  postSetup: PostSetupCallback<T> | undefined;
  dbWriter: Kysely<T> | undefined;
  isSettingUpDb = false;

  // Memoised so that every caller of `initialize` shares the same retry loop.
  private initializationPromise: Promise<void> | undefined;

  /**
   * @param name Human-readable name used in log lines and in the "not initialized" error.
   * @param path Path handed to better-sqlite3; the file is created when it does not exist.
   * @param postSetup Optional hook awaited after the connection is opened and before the
   *   database is considered initialised.
   */
  constructor(name: string, path: string, postSetup?: PostSetupCallback<T>) {
    this.path = path;
    this.name = name;
    this.postSetup = postSetup;
  }

  /** Returns `true` once a setup attempt has succeeded and `dbWriter` is set. */
  isDbInitialized = (): boolean => {
    return !!this.dbWriter;
  };

  /**
   * Starts (or joins) initialisation. The first call kicks off setup and keeps retrying every
   * second until an attempt succeeds; later calls return the same promise.
   *
   * @returns A promise that resolves once `trySetupDb` has succeeded.
   */
  initialize = (): Promise<void> => {
    if (!this.initializationPromise) {
      this.initializationPromise = this.setupWithRetry();
    }
    return this.initializationPromise;
  };

  /**
   * Performs one setup attempt: opens the SQLite file, wraps it in Kysely, runs `postSetup`,
   * and only then publishes the instance as `dbWriter`.
   *
   * @returns `true` on success; `false` when another attempt is already in flight or when any
   *   step threw (the error is logged, never rethrown).
   */
  trySetupDb = async (): Promise<boolean> => {
    // The guard must live outside the try/finally: a caller that merely bounces off the guard
    // must not clear a flag that is owned by the attempt still in progress.
    if (this.isSettingUpDb) {
      return false;
    }
    this.isSettingUpDb = true;

    let sqliteDb: Database | undefined;
    try {
      logger.info(`Setting up ${this.name}`);
      sqliteDb = new SqliteDb(this.path, { fileMustExist: false });
      const db = new Kysely<T>(this.createConfig(sqliteDb));
      if (this.postSetup) {
        await this.postSetup(sqliteDb, db);
      }
      this.dbWriter = db;
      return true;
    } catch (e) {
      const detail = e instanceof Error ? (e.stack ?? e.message) : String(e);
      logger.error(`Failed to set up ${this.name}: ${detail}`);
      // `initialize` retries after a failure; close the handle so failed attempts do not pile
      // up open connections to the same file.
      this.closeQuietly(sqliteDb);
      return false;
    } finally {
      this.isSettingUpDb = false;
    }
  };

  /**
   * The initialised Kysely instance.
   *
   * @throws Error when no setup attempt has succeeded yet.
   */
  get db() {
    if (!this.dbWriter) {
      throw new Error(`${this.name} not initialized!`);
    }
    return this.dbWriter;
  }

  /** Runs `trySetupDb` until it succeeds, sleeping `retryDelayMs` between attempts. */
  private async setupWithRetry(): Promise<void> {
    while (!(await this.trySetupDb())) {
      await new Promise<void>((resolve) => setTimeout(resolve, BaseDatabase.retryDelayMs));
    }
    logger.info(`${this.name} initialized`);
  }

  /** Builds the Kysely config: SQLite dialect plus a log hook for errors and, optionally, all queries. */
  private createConfig(sqliteDb: Database): KyselyConfig {
    return {
      dialect: new SqliteDialect({ database: sqliteDb }),
      log: (event): void => {
        // Errors are always logged; successful queries only when debug logging is enabled.
        if (event.level !== "error" && !debugLoggingEnabled) {
          return;
        }
        const { sql, parameters } = event.query;
        const duration = event.queryDurationMillis.toFixed(2);
        const formattedSql = this.formatSqlForLog(sql, parameters);

        if (event.level === "error") {
          logger.error(`[SQL Error - ${duration}ms]: ${event.error}\n\n${formattedSql}\n`);
        } else {
          logger.debug(`[Query - ${duration}ms]:\n${formattedSql}\n`);
        }
      },
    };
  }

  /**
   * Pretty-prints a statement for logging. The formatter can reject SQL it cannot parse; a
   * logging helper must never throw into the query path, so fall back to the raw statement.
   */
  private formatSqlForLog(sql: string, parameters: readonly unknown[] | undefined): string {
    try {
      return format(sql, { params: (parameters ?? []).map((p) => String(p)), language: "sqlite" });
    } catch {
      return sql;
    }
  }

  /** Closes a better-sqlite3 handle, logging (not throwing) if the close itself fails. */
  private closeQuietly(sqliteDb: Database | undefined): void {
    try {
      sqliteDb?.close();
    } catch (closeError) {
      logger.error(`Failed to close ${this.name} after a failed setup: ${String(closeError)}`);
    }
  }
}

export default BaseDatabase;
24 changes: 0 additions & 24 deletions electron-src/data/chroma.ts

This file was deleted.

219 changes: 74 additions & 145 deletions electron-src/data/database.ts
Original file line number Diff line number Diff line change
@@ -1,157 +1,19 @@
import SqliteDb from "better-sqlite3";
import type { SelectQueryBuilder } from "kysely";
import { Kysely, sql, SqliteDialect } from "kysely";
import type { SelectQueryBuilder, Kysely } from "kysely";
import { sql } from "kysely";
import type { DB as MesssagesDatabase } from "../../_generated/types";
import logger from "../utils/logger";
import type { KyselyConfig } from "kysely/dist/cjs/kysely";
import { countBy, groupBy, partition } from "lodash-es";
import type { Contact } from "electron-mac-contacts";
import { format } from "sql-formatter";
import { decodeMessageBuffer, getTextFromBuffer } from "../utils/buffer";
import { localDbExists } from "./db-file-utils";
import { appMessagesDbCopy } from "../utils/constants";
import { getStatsForText } from "./semantic-search-stats";
import { removeStopWords } from "./text";
import { getStatsForText } from "../semantic-search/semantic-search-stats";
import { removeStopWords } from "../utils/text";
import BaseDatabase from "./base-database";

// Pulls the third type argument (`O`) out of a Kysely `SelectQueryBuilder`; resolves to `never` for any other type.
type ExtractO<T> = T extends SelectQueryBuilder<any, any, infer O> ? O : never;
// The `O` type of the query builder returned by `SQLDatabase.getJoinedMessageQuery`.
type JoinedMessageType = ExtractO<ReturnType<SQLDatabase["getJoinedMessageQuery"]>>;
const debugLoggingEnabled = process.env.DEBUG_LOGGING === "true";

export class SQLDatabase {
path: string = appMessagesDbCopy;
private dbWriter: Kysely<MesssagesDatabase> | undefined;

initializationPromise!: Promise<void>;

/** Tells whether a Kysely instance has been stored in `dbWriter`. */
isDbInitialized = () => Boolean(this.dbWriter);

/**
 * Starts initialisation on the first call and returns the same promise on every later call.
 * Setup is attempted immediately and then once per second until `trySetupDb` reports success.
 */
initialize = () => {
  if (!this.initializationPromise) {
    this.initializationPromise = new Promise<void>((resolve) => {
      // One attempt; on failure it re-arms itself through a one-second timer.
      const attempt = async () => {
        if (await this.trySetupDb()) {
          resolve();
          return;
        }
        setTimeout(attempt, 1000);
      };
      attempt();
    });
    this.initializationPromise.then(() => {
      logger.info("DB initialized");
    });
  }
  return this.initializationPromise;
};

/**
 * Back-fills `message.text` for rows whose `text` is NULL but whose `attributedBody` is set,
 * using the first decoded entry's `value.string`. All updates run in one transaction; a row
 * that fails to decode or update is skipped.
 */
addParsedTextToNullMessages = async () => {
  const kysely = this.db;
  const pending = await kysely
    .selectFrom("message")
    .select(["ROWID", "attributedBody"])
    .where("text", "is", null)
    .where("attributedBody", "is not", null)
    .execute();

  const total = pending.length;
  if (!total) {
    return;
  }

  const startedAt = performance.now();
  logger.info(`Adding parsed text to ${total} messages`);

  await kysely.transaction().execute(async (trx) => {
    for (const row of pending) {
      try {
        const body = row.attributedBody;
        if (!body) {
          continue;
        }
        const decoded = await decodeMessageBuffer(body);
        if (!decoded) {
          continue;
        }
        const text = decoded[0]?.value?.string;
        if (typeof text !== "string") {
          continue;
        }
        await trx.updateTable("message").set({ text }).where("ROWID", "=", row.ROWID).executeTakeFirst();
      } catch {
        // Best effort: leave this row untouched and move on.
      }
    }
  });
  logger.info(`Done adding parsed text to ${total} messages in ${performance.now() - startedAt}ms`);
};
/** The Kysely instance stored in `dbWriter`; throws when none has been stored yet. */
get db() {
  const writer = this.dbWriter;
  if (writer) {
    return writer;
  }
  throw new Error("DB not initialized!");
}

isSettingUpDb = false;
/**
 * Makes one attempt at opening the messages database and preparing it for search.
 *
 * Resolves to `true` when every step succeeded. Resolves to `false` when another attempt is
 * flagged as running, when `localDbExists()` resolves falsy, or when any step throws (the error
 * is printed with `console.error` and not rethrown).
 */
trySetupDb = async () => {
try {
// NOTE(review): this early return is inside the `try`, so the `finally` below still runs and
// clears `isSettingUpDb` even when a different call set it — confirm whether that is intended.
if (this.isSettingUpDb || !(await localDbExists())) {
return false;
}
this.isSettingUpDb = true;
logger.info("Setting up db");
const sqliteDb = new SqliteDb(this.path);
const dialect = new SqliteDialect({ database: sqliteDb });
const options: KyselyConfig = {
dialect,
// Log hook: error events are always logged; query events only when debug logging is enabled.
log(event): void {
const isError = event.level === "error";

if (isError || debugLoggingEnabled) {
const { sql, parameters } = event.query;

const { queryDurationMillis } = event;
const duration = queryDurationMillis.toFixed(2);
const params = (parameters as string[]) || [];
// Every parameter is stringified before being handed to the SQL formatter.
const formattedSql = format(sql, { params: params.map((l) => String(l)), language: "sqlite" });
if (event.level === "query") {
logger.debug(`[Query - ${duration}ms]:\n${formattedSql}\n`);
}

if (isError) {
logger.error(`[SQL Error - ${duration}ms]: ${event.error}\n\n${formattedSql}\n`);
}
}
},
};

const db = new Kysely<MesssagesDatabase>(options);
// Assigned before the steps below because `addParsedTextToNullMessages` reads `this.db`.
// NOTE(review): if a later step throws, `dbWriter` stays set and `sqliteDb` is not closed in
// this function — verify that callers tolerate this.
this.dbWriter = db;
// add in text
logger.info("Adding text column to messages");
await this.addParsedTextToNullMessages();
logger.info("Adding text column to messages done");

// create virtual table if not exists
logger.info("Creating virtual table");
await sqliteDb.exec("CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5(text,message_id)");
// Probe for a single row; `undefined` means `message_fts` is empty and gets filled from `message`.
const count = await db.selectFrom("message_fts").select("message_id").limit(1).executeTakeFirst();
if (count === undefined) {
await sqliteDb.exec("INSERT INTO message_fts SELECT text, guid as message_id FROM message");
}
logger.info("Creating virtual table done");
return true;
} catch (e) {
console.error(e);
return false;
} finally {
this.isSettingUpDb = false;
}
};
export class SQLDatabase extends BaseDatabase<MesssagesDatabase> {
private getChatsWithMessagesQuery = () => {
const db = this.db;
return db
Expand Down Expand Up @@ -243,6 +105,25 @@ export class SQLDatabase {
.select(sql<number>`message.ROWID`.as("message_id"));
};

/**
 * Finds the guids of all messages whose `text` equals one of the given strings.
 *
 * @param texts Candidate texts, in the order the results should follow. When a text
 *   occurs more than once, its first position decides its place in the result.
 * @returns Message guids ordered by the position of their text in `texts`. Messages sharing the
 *   same text keep the order in which the query returned them. Resolves to an empty array when
 *   `texts` is empty.
 */
getMessageGuidsFromText = async (texts: string[]) => {
  // Nothing to look up: skip the database round trip entirely.
  if (texts.length === 0) {
    return [];
  }

  // Rank each distinct text by its FIRST position so a repeated entry cannot demote it.
  const rankByText = new Map<string, number>();
  texts.forEach((text, index) => {
    if (!rankByText.has(text)) {
      rankByText.set(text, index);
    }
  });

  const results = await this.db
    .selectFrom("message")
    .select(["guid", "text"])
    .where("text", "in", [...rankByText.keys()])
    .execute();

  // Rows whose text has no rank sort after every ranked row, which keeps the comparator a
  // consistent ordering instead of reporting "equal" for unrelated pairs.
  const rankOf = (text: string | null): number =>
    (text === null ? undefined : rankByText.get(text)) ?? Number.MAX_SAFE_INTEGER;

  return [...results].sort((a, b) => rankOf(a.text) - rankOf(b.text)).map((row) => row.guid);
};

globalSearchTextBased = async (
searchTerm: string,
chatIds?: number[],
Expand Down Expand Up @@ -697,7 +578,55 @@ export class SQLDatabase {
};
}

const db = new SQLDatabase();
/**
 * Back-fills `message.text` for rows whose `text` is NULL but whose `attributedBody` is set,
 * using the first decoded entry's `value.string`. All updates run in one transaction; a row
 * that fails to decode or update is skipped.
 *
 * @param db Kysely instance for the messages database.
 */
const addParsedTextToNullMessages = async (db: Kysely<MesssagesDatabase>) => {
  const pending = await db
    .selectFrom("message")
    .select(["ROWID", "attributedBody"])
    .where("text", "is", null)
    .where("attributedBody", "is not", null)
    .execute();

  const total = pending.length;
  if (!total) {
    return;
  }

  const startedAt = performance.now();
  logger.info(`Adding parsed text to ${total} messages`);

  await db.transaction().execute(async (trx) => {
    for (const row of pending) {
      try {
        const body = row.attributedBody;
        if (!body) {
          continue;
        }
        const decoded = await decodeMessageBuffer(body);
        if (!decoded) {
          continue;
        }
        const text = decoded[0]?.value?.string;
        if (typeof text !== "string") {
          continue;
        }
        await trx.updateTable("message").set({ text }).where("ROWID", "=", row.ROWID).executeTakeFirst();
      } catch {
        // Best effort: leave this row untouched and move on.
      }
    }
  });
  logger.info(`Done adding parsed text to ${total} messages in ${performance.now() - startedAt}ms`);
};

/**
 * The messages database instance. Its post-setup hook first back-fills missing message text and
 * then makes sure the `message_fts` FTS5 table exists, filling it from `message` when it is empty.
 */
const db = new SQLDatabase("Messages DB", appMessagesDbCopy, async (rawSqlite, messagesDb) => {
  // Step 1: give every message with only an attributed body a plain `text` value.
  logger.info("Adding text column to messages");
  await addParsedTextToNullMessages(messagesDb);
  logger.info("Adding text column to messages done");

  // Step 2: create the full-text index if needed and seed it when it holds no rows.
  logger.info("Creating virtual table");
  await rawSqlite.exec("CREATE VIRTUAL TABLE IF NOT EXISTS message_fts USING fts5(text,message_id)");
  const firstIndexedRow = await messagesDb
    .selectFrom("message_fts")
    .select("message_id")
    .limit(1)
    .executeTakeFirst();
  const indexIsEmpty = firstIndexedRow === undefined;
  if (indexIsEmpty) {
    await rawSqlite.exec("INSERT INTO message_fts SELECT text, guid as message_id FROM message");
  }
  logger.info("Creating virtual table done");
});

// monkey patch to handle ipc calls

Expand Down
Loading

0 comments on commit 25d00b0

Please sign in to comment.