Skip to content

Commit

Permalink
Merge pull request #11 from cocktail-collective/dev
Browse files Browse the repository at this point in the history
Implement Blur Hash & Generation Data View.
  • Loading branch information
rfletchr authored Nov 27, 2023
2 parents 6ebbe3e + 4306abf commit eef5c3d
Show file tree
Hide file tree
Showing 15 changed files with 229 additions and 44 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ def run(self):

setup(
name="cocktail",
version="0.3.1",
version="0.4.0",
description="Cocktail",
package_dir={"": "src"},
packages=find_namespace_packages(where="src"),
install_requires=[
"PySide6",
"qtawesome",
"platformdirs",
"blurhash-python",
],
entry_points={
"console_scripts": [
Expand Down
43 changes: 34 additions & 9 deletions src/cocktail/core/database/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
__all__ = [
"get_connection",
"insert_page",
"get_update_period",
"get_db_update_period",
"get_last_updated",
"set_last_updated",
"calculate_period",
]

import json
import os
import time
import logging
Expand All @@ -16,6 +17,7 @@
from PySide6 import QtSql
from cocktail.core.database import data_classes

CURRENT_SCHEMA_VERSION = 1

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +29,7 @@ def insert_or_replace(db, table_name, rows: typing.Iterable[typing.NamedTuple]):

start = time.time()
column_names = [name for name in rows[0]._fields]

column_names = ", ".join(column_names)
placeholder = ", ".join(["?"] * len(rows[0]._fields))

Expand All @@ -44,6 +47,9 @@ def insert_or_replace(db, table_name, rows: typing.Iterable[typing.NamedTuple]):
query.prepare(statement)

for index, value in enumerate(row):
if isinstance(value, (list, dict)):
value = json.dumps(value)

query.bindValue(index, value)

if not query.exec():
Expand Down Expand Up @@ -92,6 +98,9 @@ def set_last_updated(db, dt: datetime.datetime = None):


def create_tables(db):
"""
populates the database with the schema defined in schema.sql
"""
logger.info("creating new database")

schema = importlib.resources.read_text("cocktail.core.database", "schema.sql")
Expand All @@ -106,15 +115,22 @@ def create_tables(db):
logger.error(query.lastError().text())
return False

query = QtSql.QSqlQuery(db)
query.prepare("PRAGMA user_version = ?")
query.bindValue(0, CURRENT_SCHEMA_VERSION)
query.exec()

epoch = datetime.datetime.fromtimestamp(0)
set_last_updated(db, epoch)


def get_update_period(db):
now = datetime.datetime.now()

def get_db_update_period(db):
last_updated = get_last_updated(db)
return calculate_period(last_updated)


def calculate_period(last_updated):
now = datetime.datetime.now()
days = (now - last_updated).days

if days <= 2:
Expand Down Expand Up @@ -150,8 +166,17 @@ def get_connection(filepath=None):
return db


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
db = get_connection("./cocktail.sqlite3")
def get_schema_version(db):
"""
Returns the schema version of the database.
"""
query = QtSql.QSqlQuery(db)
query.prepare("PRAGMA user_version")

update_db(db, data_classes.Period.AllTime)
if not query.exec():
raise RuntimeError(f"Failed to execute statement: {query.lastError().text()}")

if not query.next():
raise RuntimeError("Failed to get schema version")
version = query.value(0)
return version
41 changes: 33 additions & 8 deletions src/cocktail/core/database/data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def items_from_model_json(data: dict):

if len(versions) == 0:
logger.warning(f"Model {model.name} has no versions, discarding.")
return [], [], [], []
return None, None, None, None
else:
return model, versions, files, images


def deserialise_page(page: typing.List[dict]):
def deserialise_items(page: typing.List[dict]):
models = []
versions = []
files = []
Expand All @@ -67,12 +67,16 @@ def deserialise_page(page: typing.List[dict]):
model, model_versions, model_files, model_images = items_from_model_json(
model_data
)

if model is None:
continue

models.append(model)
versions.extend(model_versions)
files.extend(model_files)
images.extend(model_images)

return Page(models, versions, files, images)
return Page(models, versions, images, files)


class Model(typing.NamedTuple):
Expand All @@ -84,6 +88,7 @@ class Model(typing.NamedTuple):
creator_name: str
creator_image: str
image: str
image_blur_hash: str
description: str
updated_at: int

Expand All @@ -106,6 +111,7 @@ def from_json(cls, data: dict):
creator_name=data["creator"]["username"],
creator_image=data["creator"]["image"] or "",
image=image_data.get("url", ""),
image_blur_hash=image_data.get("hash", "") or "",
description=data["description"] or "",
updated_at=timestamp,
)
Expand All @@ -121,6 +127,7 @@ def from_record(cls, record: QtSql.QSqlRecord):
creator_name=record.value("creator_name"),
creator_image=record.value("creator_image"),
image=record.value("image"),
image_blur_hash=record.value("image_blur_hash"),
description=record.value("description"),
updated_at=record.value("updated_at"),
)
Expand Down Expand Up @@ -185,20 +192,32 @@ class ModelImage(typing.NamedTuple):
model_version_id: int
url: str
generation_data: str
blur_hash: str
width: int
height: int

@classmethod
def from_json(cls, model_id, model_version_id, data: dict):
metadata = {
"prompt": data.get("prompt", ""),
"negativePrompt": data.get("negativePrompt", ""),
metadata = data["meta"] or {}

generation_data = {
"prompt": metadata.get("prompt", ""),
"negativePrompt": metadata.get("negativePrompt", ""),
"seed": metadata.get("seed", ""),
"steps": metadata.get("steps", 20),
"cfgScale": metadata.get("cfgScale", 7.0),
"sampler": metadata.get("sampler", ""),
}

return cls(
id=data["id"],
model_id=model_id,
model_version_id=model_version_id,
url=data["url"],
generation_data=json.dumps(metadata),
generation_data=generation_data,
blur_hash=data.get("hash", "") or "",
width=data["width"],
height=data["height"],
)

@classmethod
Expand All @@ -209,6 +228,9 @@ def from_record(cls, record: QtSql.QSqlRecord):
model_version_id=record.value("model_version_id"),
url=record.value("url"),
generation_data=json.loads(record.value("generation_data")),
blur_hash=record.value("blur_hash"),
width=record.value("width"),
height=record.value("height"),
)


Expand All @@ -217,6 +239,7 @@ class ModelVersion(typing.NamedTuple):
model_id: int
name: str
description: str
trained_words: typing.List[str]

@classmethod
def from_json(cls, data: dict):
Expand All @@ -225,6 +248,7 @@ def from_json(cls, data: dict):
model_id=data["modelId"],
name=data["name"],
description=data["description"] or "",
trained_words=data["trainedWords"],
)

@classmethod
Expand All @@ -234,14 +258,15 @@ def from_record(cls, record: QtSql.QSqlRecord):
model_id=record.value("model_id"),
name=record.value("name"),
description=record.value("description"),
trained_words=json.loads(record.value("trained_words")),
)


class Page(typing.NamedTuple):
models: typing.List[Model]
versions: typing.List[ModelVersion]
files: typing.List[ModelFile]
images: typing.List[ModelImage]
files: typing.List[ModelFile]


def parse_timestamp(date_str: str):
Expand Down
9 changes: 8 additions & 1 deletion src/cocktail/core/database/schema.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
pragma journal_mode = WAL;
pragma synchronous = normal;

CREATE TABLE IF NOT EXISTS model (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
Expand All @@ -7,6 +10,7 @@ CREATE TABLE IF NOT EXISTS model (
creator_name TEXT NOT NULL,
creator_image TEXT NOT NULL,
image TEXT NOT NULL,
image_blur_hash TEXT NOT NULL,
description TEXT NOT NULL,
updated_at INTEGER NOT NULL

Expand All @@ -17,6 +21,7 @@ CREATE TABLE IF NOT EXISTS model_version (
model_id INTEGER NOT NULL,
name TEXT NOT NULL,
description TEXT NOT NULL,
trained_words TEXT NOT NULL,
FOREIGN KEY (model_id) REFERENCES model (id)
);

Expand All @@ -32,7 +37,6 @@ CREATE TABLE IF NOT EXISTS model_file (
format TEXT NOT NULL,
datatype TEXT NOT NULL,
pruned BOOLEAN NOT NULL,

FOREIGN KEY (model_id) REFERENCES model (id),
FOREIGN KEY (model_version_id) REFERENCES model_version (id)
);
Expand All @@ -43,6 +47,9 @@ CREATE TABLE IF NOT EXISTS model_image (
model_version_id INTEGER NOT NULL,
url TEXT NOT NULL,
generation_data TEXT NOT NULL,
blur_hash TEXT NOT NULL,
width INTEGER NOT NULL,
height INTEGER NOT NULL,
FOREIGN KEY (model_id) REFERENCES model (id),
FOREIGN KEY (model_version_id) REFERENCES model_version (id)
);
Expand Down
29 changes: 19 additions & 10 deletions src/cocktail/core/providers/image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
__all__ = ["ImageProvider", "ImageProviderProxyModel"]
import blurhash
from PIL import ImageQt
from PySide6 import QtCore, QtGui, QtNetwork
from functools import partial

Expand All @@ -18,6 +20,7 @@ class ImageProviderProxyModel(QtCore.QIdentityProxyModel):
def __init__(self, image_provider=None, parent=None):
super().__init__(parent)
self.image_provider: ImageProvider = image_provider or ImageProvider()
self.blur_cache = FixedLengthMapping(max_entries=100)

def data(self, index: QtCore.QModelIndex, role: int = ...):
if role in self.ImageRoles:
Expand All @@ -33,13 +36,18 @@ def getImage(self, index, role=QtCore.Qt.ItemDataRole.DecorationRole):

callback = partial(self.onImageDownloaded, index=index, url=url)

self.image_provider.queueImageDownload(url, callback)
self.image_provider.queueImageDownload(
url, callback, blur_hash=self.getBlurHash(index, role)
)

return None
return self.image_provider.getImage(url)

def getUrl(self, index: QtCore.QModelIndex, role):
raise NotImplementedError

def getBlurHash(self, index, role):
raise NotImplementedError

def onImageDownloaded(self, image, url, index):
self.dataChanged.emit(index, index, [QtCore.Qt.DecorationRole])

Expand All @@ -60,19 +68,21 @@ def hasImage(self, url):
def getImage(self, url):
return self._cache[url]

def queueImageDownload(self, url, callback):
def queueImageDownload(self, url, callback, blur_hash=None):
if url in self._cache:
callback(self._cache[url])
return

self._cache[url] = None # prevent duplicate requests
if not blur_hash:
self._cache[url] = None # prevent duplicate requests
elif blurhash.is_valid_blurhash(blur_hash):
self._cache[url] = ImageQt.ImageQt(blurhash.decode(blur_hash, 8, 12))

reply = self.network_manager.get(url)
callback = partial(self.onImageDownloaded, reply=reply, callback=callback)

reply.finished.connect(callback)

return None
return self._cache[url]

def onImageDownloaded(self, reply: QtNetwork.QNetworkReply, callback):
if reply.error() != QtNetwork.QNetworkReply.NoError:
Expand All @@ -81,7 +91,6 @@ def onImageDownloaded(self, reply: QtNetwork.QNetworkReply, callback):
return

image = QtGui.QImage.fromData(reply.readAll())

self._cache[reply.url().toString()] = image

callback(image)
if not image.isNull():
self._cache[reply.url().toString()] = image
callback(image)
2 changes: 1 addition & 1 deletion src/cocktail/core/providers/model_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def onRequestFinished(self, reply: QtNetwork.QNetworkReply):
data = json.loads(bytearray(data))
items = data["items"]

self.queue.put(data_classes.deserialise_page(items))
self.queue.put(data_classes.deserialise_items(items))
self.pageReady.emit()

metadata = data["metadata"]
Expand Down
2 changes: 1 addition & 1 deletion src/cocktail/ui/database/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, connection, view=None, parent=None):

def updateModelData(self, period: data_classes.Period = None):
if period is None:
period = db_api.get_update_period(self.connection)
period = db_api.get_db_update_period(self.connection)

self.logger.info(f"Updating model data for period: {period.value}")

Expand Down
Loading

0 comments on commit eef5c3d

Please sign in to comment.