
Commit

improve llms selection of simple reasoning pipeline and fix non-persistent settings bug

- improve llms selection of simple reasoning pipeline
- enable llms selection for reranking
- fix non-persistent settings bug
lone17 authored Mar 28, 2024
2 parents b208924 + 14482e9 commit e8d3c70
Showing 6 changed files with 75 additions and 13 deletions.
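
Not part of the diff, for orientation only: a condensed, self-contained sketch of the selection pattern this commit introduces, where the configured LLMs are surfaced as a dropdown setting and the chosen name is later resolved back to a model instance. The pool class below is a stand-in; only options(), get_lowest_cost_name(), and name lookup are taken from the diffs that follow.

import logging

logger = logging.getLogger(__name__)


class FakeLLMPool:
    """Stand-in for ktem's `llms` component, only for illustration."""

    def __init__(self, models: dict):
        self._models = models

    def options(self) -> dict:
        # mirrors ModelPool.options() below: present a dict of models
        return self._models

    def get_lowest_cost_name(self) -> str:
        # assumed to return the name of the cheapest configured model
        return next(iter(self._models))

    def __getitem__(self, name: str):
        return self._models[name]


def llm_dropdown_setting(llms) -> dict:
    """Build a dropdown setting from the pool, as get_user_settings now does."""
    try:
        default = llms.get_lowest_cost_name()
        choices = list(llms.options().keys())
    except Exception as e:  # pool may be empty or misconfigured
        logger.error(e)
        default, choices = None, []
    return {
        "name": "LLM for reranking",
        "value": default,
        "component": "dropdown",
        "choices": choices,
    }


pool = FakeLLMPool({"cheap-llm": "model-a", "accurate-llm": "model-b"})
setting = llm_dropdown_setting(pool)
# later, a pipeline resolves the user's choice back to an instance:
chosen = pool[setting["value"]]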
libs/kotaemon/kotaemon/contribs/promptui/tunnel.py (3 changes: 1 addition & 2 deletions)
@@ -17,8 +17,7 @@
 BINARY_REMOTE_NAME = f"frpc_{platform.system().lower()}_{machine.lower()}"
 EXTENSION = ".exe" if os.name == "nt" else ""
 BINARY_URL = (
-    "some-endpoint.com"
-    f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
+    "some-endpoint.com" f"/kotaemon/tunneling/{VERSION}/{BINARY_REMOTE_NAME}{EXTENSION}"
 )

 BINARY_FILENAME = f"{BINARY_REMOTE_NAME}_v{VERSION}"
libs/kotaemon/kotaemon/llms/chats/langchain_based.py (1 change: 0 additions & 1 deletion)
@@ -194,7 +194,6 @@ def _get_lc_class(self):


 class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
-
     def __init__(
         self,
         azure_endpoint: str | None = None,
libs/ktem/ktem/components.py (3 changes: 2 additions & 1 deletion)
@@ -1,4 +1,5 @@
 """Common components, some kind of config"""
+
 import logging
 from functools import cache
 from pathlib import Path
@@ -71,7 +72,7 @@ def settings(self) -> dict:
         }

     def options(self) -> dict:
-        """Present a list of models"""
+        """Present a dict of models"""
         return self._models

     def get_random_name(self) -> str:
libs/ktem/ktem/index/file/pipelines.py (32 changes: 27 additions & 5 deletions)
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import logging
 import shutil
 import warnings
 from collections import defaultdict
@@ -8,7 +9,7 @@
 from pathlib import Path
 from typing import Optional

-from ktem.components import embeddings, filestorage_path, llms
+from ktem.components import embeddings, filestorage_path
 from ktem.db.models import engine
 from llama_index.vector_stores import (
     FilterCondition,
@@ -25,10 +26,12 @@
 from kotaemon.base import RetrievedDocument
 from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.indices.ingests import DocumentIngestor
-from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.rankings import BaseReranking

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

+logger = logging.getLogger(__name__)
+

 @lru_cache
 def dev_settings():
@@ -67,7 +70,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
     vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
         embedding=embeddings.get_default(),
     )
-    reranker: BaseReranking = LLMReranking.withx(llm=llms.get_lowest_cost())
+    reranker: BaseReranking
     get_extra_table: bool = False

     def run(
@@ -153,7 +156,23 @@ def run(

     @classmethod
     def get_user_settings(cls) -> dict:
+        from ktem.components import llms
+
+        try:
+            reranking_llm = llms.get_lowest_cost_name()
+            reranking_llm_choices = list(llms.options().keys())
+        except Exception as e:
+            logger.error(e)
+            reranking_llm = None
+            reranking_llm_choices = []
+
         return {
+            "reranking_llm": {
+                "name": "LLM for reranking",
+                "value": reranking_llm,
+                "component": "dropdown",
+                "choices": reranking_llm_choices,
+            },
             "separate_embedding": {
                 "name": "Use separate embedding",
                 "value": False,
@@ -185,7 +204,7 @@ def get_user_settings(cls) -> dict:
             },
             "use_reranking": {
                 "name": "Use reranking",
-                "value": True,
+                "value": False,
                 "choices": [True, False],
                 "component": "checkbox",
             },
@@ -199,7 +218,10 @@ def get_pipeline(cls, user_settings, index_settings, selected):
             settings: the settings of the app
             kwargs: other arguments
         """
-        retriever = cls(get_extra_table=user_settings["prioritize_table"])
+        retriever = cls(
+            get_extra_table=user_settings["prioritize_table"],
+            reranker=user_settings["reranking_llm"],
+        )
         if not user_settings["use_reranking"]:
             retriever.reranker = None  # type: ignore

libs/ktem/ktem/pages/settings.py (31 changes: 31 additions & 0 deletions)
@@ -87,6 +87,28 @@ def on_building_ui(self):
         self.reasoning_tab()

     def on_subscribe_public_events(self):
+        """
+        Subscribes to public events related to user management.
+
+        This function is responsible for subscribing to the "onSignIn" event, which is
+        triggered when a user signs in. It registers two event handlers for this event.
+
+        The first event handler, "load_setting", is responsible for loading the user's
+        settings when they sign in. It takes the user ID as input and returns the
+        settings state and a list of component outputs. The progress indicator for this
+        event is set to "hidden".
+
+        The second event handler, "get_name", is responsible for retrieving the
+        username of the current user. It takes the user ID as input and returns the
+        username if it exists, otherwise it returns "___". The progress indicator for
+        this event is also set to "hidden".
+
+        Parameters:
+            self (object): The instance of the class.
+
+        Returns:
+            None
+        """
         if self._app.f_user_management:
             self._app.subscribe_event(
                 name="onSignIn",
@@ -290,3 +312,12 @@ def components(self) -> list:
     def component_names(self):
         """Get the setting components"""
         return self._settings_keys
+
+    def _on_app_created(self):
+        if not self._app.f_user_management:
+            self._app.app.load(
+                self.load_setting,
+                inputs=self._user_id,
+                outputs=[self._settings_state] + self.components(),
+                show_progress="hidden",
+            )
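
Not part of the diff: a minimal, self-contained sketch (with simplified stand-in names) of why settings appeared non-persistent and what the new _on_app_created hook changes. Previously load_setting was apparently wired only to the onSignIn event described in the docstring above, so with user management disabled the saved settings were never loaded back into the UI on startup.

class SettingsPageSketch:
    """Simplified stand-in for the settings page in ktem/pages/settings.py."""

    def __init__(self, user_management_enabled: bool):
        self.f_user_management = user_management_enabled
        self.settings_state = None  # what the UI components read from

    def load_setting(self, user_id=None):
        # In ktem this would read the persisted settings from the database.
        return {"reasoning.lang": "en"}

    def on_sign_in(self, user_id):
        # Before this commit: the only place settings were loaded.
        self.settings_state = self.load_setting(user_id)

    def on_app_created(self):
        # The fix: also load settings when the app starts and there is no
        # sign-in flow, otherwise saved settings never reach the UI.
        if not self.f_user_management:
            self.settings_state = self.load_setting()


page = SettingsPageSketch(user_management_enabled=False)
page.on_app_created()
assert page.settings_state == {"reasoning.lang": "en"}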
libs/ktem/ktem/reasoning/simple.py (18 changes: 14 additions & 4 deletions)
@@ -159,6 +159,7 @@ class AnswerWithContextPipeline(BaseComponent):
     qa_table_template: str = DEFAULT_QA_TABLE_PROMPT
     qa_chatbot_template: str = DEFAULT_QA_CHATBOT_PROMPT

+    enable_citation: bool = False
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

@@ -200,7 +201,8 @@ async def run(  # type: ignore
             lang=self.lang,
         )

-        if evidence:
+        citation_task = None
+        if evidence and self.enable_citation:
             citation_task = asyncio.create_task(
                 self.citation_pipeline.ainvoke(context=evidence, question=question)
             )
@@ -226,7 +228,7 @@

         # retrieve the citation
         print("Waiting for citation task")
-        if evidence:
+        if citation_task is not None:
             citation = await citation_task
         else:
             citation = None
@@ -353,7 +355,15 @@ def get_pipeline(cls, settings, retrievers):
         _id = cls.get_info()["id"]

         pipeline = FullQAPipeline(retrievers=retrievers)
-        pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
+        pipeline.answering_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.main_llm"]
+        ]
+        pipeline.answering_pipeline.citation_pipeline.llm = llms[
+            settings[f"reasoning.options.{_id}.citation_llm"]
+        ]
+        pipeline.answering_pipeline.enable_citation = settings[
+            f"reasoning.options.{_id}.highlight_citation"
+        ]
         pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
@@ -384,7 +394,7 @@ def get_user_settings(cls) -> dict:
         return {
             "highlight_citation": {
                 "name": "Highlight Citation",
-                "value": True,
+                "value": False,
                 "component": "checkbox",
             },
             "citation_llm": {
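
Not part of the diff: a self-contained sketch of how the per-option setting keys read in get_pipeline above are composed. The key layout reasoning.options.<id>.<name> and the setting names are taken from the diff; the flattening helper and the example values are stand-ins.

def flatten_options(reasoning_id: str, user_settings: dict) -> dict:
    """Turn per-option settings into flat keys like 'reasoning.options.<id>.<name>'."""
    return {
        f"reasoning.options.{reasoning_id}.{name}": spec["value"]
        for name, spec in user_settings.items()
    }


_id = "simple"  # cls.get_info()["id"] in the diff
user_settings = {
    "highlight_citation": {"value": False},
    "citation_llm": {"value": "cheap-llm"},
    "main_llm": {"value": "accurate-llm"},
}
settings = flatten_options(_id, user_settings)

# get_pipeline then looks these up, e.g. llms[settings[f"reasoning.options.{_id}.main_llm"]]
assert settings[f"reasoning.options.{_id}.main_llm"] == "accurate-llm"
assert settings[f"reasoning.options.{_id}.highlight_citation"] is False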
