Skip to content

Commit

Permalink
refactor: use type[_]
Browse files Browse the repository at this point in the history
  • Loading branch information
phv2312 committed Oct 27, 2024
1 parent ad826ca commit 8c87120
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 18 deletions.
19 changes: 11 additions & 8 deletions libs/kotaemon/kotaemon/deps/container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from functools import cached_property, lru_cache
from pathlib import Path
from typing import Any, Sequence
from typing import Any, Sequence, TypeAlias

from ktem.index.file.base import BaseFileIndexIndexing, BaseFileIndexRetriever
from ktem.index.file.pipelines import DocumentRetrievalPipeline, IndexDocumentPipeline
Expand All @@ -19,6 +19,8 @@
from ..storages.vectorstores import BaseVectorStore, ChromaVectorStore
from .registry import Dependency, Registry

SourceRecords: TypeAlias = Sequence[Row[tuple[Source]]]


@lru_cache(1)
def get_engine() -> Engine:
Expand Down Expand Up @@ -58,19 +60,19 @@ def __init__(self, collection_idx: int, private: bool = False):
self.private = private

@cached_property
def source(self) -> Source:
def source(self) -> type[Source]:
source = Source.from_index(self.collection_idx, self.private)
source.metadata.create_all(self.engine)
return source

@cached_property
def index(self) -> Index:
def index(self) -> type[Index]:
index = Index.from_index(self.collection_idx)
index.metadata.create_all(self.engine)
return index

@cached_property
def filegroup(self) -> FileGroup:
def filegroup(self) -> type[FileGroup]:
filegroup = FileGroup.from_index(self.collection_idx)
filegroup.metadata.create_all(self.engine)
return filegroup
Expand All @@ -87,11 +89,12 @@ def filestorage_path(self) -> Path:
class FileCRUD:
engine = get_engine()

def __init__(self, source: type[Source]):
self.source = source

def list_docids(self) -> list[str]:
with Session(self.engine) as session:
records: Sequence[Row[tuple[Source]]] = session.execute(
select(Source)
).all()
records: SourceRecords = session.execute(select(self.source)).all()
return [record[0].id for record in records]


Expand Down Expand Up @@ -179,7 +182,7 @@ class Container:
user_idx: int = 1
private: bool = False
fileschema: FileSchemaFactory = FileSchemaFactory(collection_idx, private)
crud: FileCRUD = FileCRUD()
crud: FileCRUD = FileCRUD(fileschema.source)

vectorstores: Registry[BaseVectorStore] = Registry(
{"chroma": Dependency(VectorstoreFactory.chroma)}
Expand Down
2 changes: 1 addition & 1 deletion libs/kotaemon/kotaemon/schemas/file/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ class BaseSchema(DeclarativeBase):
__abstract__ = True

@classmethod
def from_dict(cls, cls_name: str, params: dict[str, Any]) -> Any:
def from_dict(cls, cls_name: str, params: dict[str, Any]) -> type["BaseSchema"]:
params["__abstract__"] = False
return type(cls_name, (BaseSchema,), params)
5 changes: 2 additions & 3 deletions libs/kotaemon/kotaemon/schemas/file/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sqlalchemy import JSON, Column, DateTime, Integer, String, func
from sqlalchemy.ext.mutable import MutableDict
from typing_extensions import Self

from .base import BaseSchema

Expand All @@ -17,6 +16,6 @@ class FileGroup(BaseSchema):
data = Column(MutableDict.as_mutable(JSON), default={"files": []}) # type: ignore

@classmethod
def from_index(cls, idx: int) -> Self:
def from_index(cls, idx: int) -> type["FileGroup"]:
cls.__tablename__ = f"index__{idx}__group"
return cast(Self, cls.from_dict("FileGroup", dict(vars(cls))))
return cast(type["FileGroup"], cls.from_dict("FileGroup", dict(vars(cls))))
5 changes: 2 additions & 3 deletions libs/kotaemon/kotaemon/schemas/file/index.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import cast

from sqlalchemy import Column, Integer, String
from typing_extensions import Self

from .base import BaseSchema

Expand All @@ -16,6 +15,6 @@ class Index(BaseSchema):
user = Column(Integer, default=1)

@classmethod
def from_index(cls, idx: int) -> Self:
def from_index(cls, idx: int) -> type["Index"]:
cls.__tablename__ = f"index__{idx}__index"
return cast(Self, cls.from_dict("Index", dict(vars(cls))))
return cast(type["Index"], cls.from_dict("Index", dict(vars(cls))))
5 changes: 2 additions & 3 deletions libs/kotaemon/kotaemon/schemas/file/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from sqlalchemy import JSON, Column, DateTime, Integer, String, UniqueConstraint, func
from sqlalchemy.ext.mutable import MutableDict
from typing_extensions import Self

from .base import BaseSchema

Expand All @@ -22,7 +21,7 @@ class Source(BaseSchema):
note = Column(MutableDict.as_mutable(JSON), default={}) # type: ignore

@classmethod
def from_index(cls, idx: int, private: bool = False) -> Self:
def from_index(cls, idx: int, private: bool = False) -> type["Source"]:
cls.__tablename__ = f"index__{idx}__source"
if private:
cls.__table_args__ = (
Expand All @@ -32,4 +31,4 @@ def from_index(cls, idx: int, private: bool = False) -> Self:
else:
cls.name = Column(String, unique=True)

return cast(Self, cls.from_dict("Source", dict(vars(cls))))
return cast(type["Source"], cls.from_dict("Source", dict(vars(cls))))

0 comments on commit 8c87120

Please sign in to comment.