From 1d3c4f44336daf7c98f6341964c55a13650d7a2a Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Tue, 17 Dec 2024 16:49:37 +0700 Subject: [PATCH] feat: add graphrag modes (#574) #none * feat: add support for retrieval modes in LightRAG & NanoGraphRAG * feat: expose custom prompts in LightRAG & NanoGraphRAG * fix: optimize setting UI * fix: update non local mode in LightRAG * fix: update graphRAG mode --- flowsettings.py | 2 +- libs/ktem/ktem/assets/css/main.css | 11 ++ libs/ktem/ktem/assets/js/main.js | 5 + .../index/file/graph/light_graph_index.py | 20 +++- .../index/file/graph/lightrag_pipelines.py | 100 ++++++++++++++--- .../ktem/index/file/graph/nano_graph_index.py | 20 +++- .../ktem/index/file/graph/nano_pipelines.py | 105 ++++++++++++++---- libs/ktem/ktem/index/file/graph/pipelines.py | 2 +- libs/ktem/ktem/pages/settings.py | 13 ++- libs/ktem/ktem/utils/render.py | 8 +- 10 files changed, 239 insertions(+), 47 deletions(-) diff --git a/flowsettings.py b/flowsettings.py index 119adab12..0647b71a4 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -302,7 +302,7 @@ if USE_NANO_GRAPHRAG: GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex") -elif USE_LIGHTRAG: +if USE_LIGHTRAG: GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex") KH_INDEX_TYPES = [ diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css index ed3f2219d..82689a597 100644 --- a/libs/ktem/ktem/assets/css/main.css +++ b/libs/ktem/ktem/assets/css/main.css @@ -204,6 +204,11 @@ mark { right: 15px; } +/* prevent overflow of html info panel */ +#html-info-panel { + overflow-x: auto !important; +} + #chat-expand-button { position: absolute; top: 6px; @@ -211,6 +216,12 @@ mark { z-index: 1; } +#save-setting-btn { + width: 150px; + height: 30px; + min-width: 100px !important; +} + #quick-setting-labels { margin-top: 5px; margin-bottom: -10px; diff --git a/libs/ktem/ktem/assets/js/main.js b/libs/ktem/ktem/assets/js/main.js index beab6d9cd..7a9445d16 100644 --- a/libs/ktem/ktem/assets/js/main.js +++ b/libs/ktem/ktem/assets/js/main.js @@ -21,6 +21,11 @@ function run() { let chat_column = document.getElementById("main-chat-bot"); let conv_column = document.getElementById("conv-settings-panel"); + // move setting close button + let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav"); + let setting_close_button = document.getElementById("save-setting-btn"); + setting_tab_nav_bar.appendChild(setting_close_button); + let default_conv_column_min_width = "min(300px, 100%)"; conv_column.style.minWidth = default_conv_column_min_width diff --git a/libs/ktem/ktem/index/file/graph/light_graph_index.py b/libs/ktem/ktem/index/file/graph/light_graph_index.py index 583945eeb..0238ff824 100644 --- a/libs/ktem/ktem/index/file/graph/light_graph_index.py +++ b/libs/ktem/ktem/index/file/graph/light_graph_index.py @@ -1,6 +1,6 @@ from typing import Any -from ..base import BaseFileIndexRetriever +from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever from .graph_index import GraphRAGIndex from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline @@ -12,14 +12,32 @@ def _setup_indexing_cls(self): def _setup_retriever_cls(self): self._retriever_pipeline_cls = [LightRAGRetrieverPipeline] + def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing: + pipeline = super().get_indexing_pipeline(settings, user_id) + # indexing settings + prefix = f"index.options.{self.id}." + striped_settings = { + key[len(prefix) :]: value + for key, value in settings.items() + if key.startswith(prefix) + } + # set the prompts + pipeline.prompts = striped_settings + return pipeline + def get_retriever_pipelines( self, settings: dict, user_id: int, selected: Any = None ) -> list["BaseFileIndexRetriever"]: _, file_ids, _ = selected + # retrieval settings + prefix = f"index.options.{self.id}." + search_type = settings.get(prefix + "search_type", "local") + retrievers = [ LightRAGRetrieverPipeline( file_ids=file_ids, Index=self._resources["Index"], + search_type=search_type, ) ] diff --git a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py index 95dd58ec5..6a374f4a8 100644 --- a/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py +++ b/libs/ktem/ktem/index/file/graph/lightrag_pipelines.py @@ -70,7 +70,7 @@ async def llm_func( if if_cache_return is not None: return if_cache_return["return"] - output = model(input_messages).text + output = (await model.ainvoke(input_messages)).text print("-" * 50) print(output, "\n", "-" * 50) @@ -220,7 +220,37 @@ def build_graphrag(working_dir, llm_func, embedding_func): class LightRAGIndexingPipeline(GraphRAGIndexingPipeline): """GraphRAG specific indexing pipeline""" + prompts: dict[str, str] = {} + + @classmethod + def get_user_settings(cls) -> dict: + try: + from lightrag.prompt import PROMPTS + + blacklist_keywords = ["default", "response", "process"] + return { + prompt_name: { + "name": f"Prompt for '{prompt_name}'", + "value": content, + "component": "text", + } + for prompt_name, content in PROMPTS.items() + if all( + keyword not in prompt_name.lower() for keyword in blacklist_keywords + ) + } + except ImportError as e: + print(e) + return {} + def call_graphrag_index(self, graph_id: str, docs: list[Document]): + from lightrag.prompt import PROMPTS + + # modify the prompt if it is set in the settings + for prompt_name, content in self.prompts.items(): + if prompt_name in PROMPTS: + PROMPTS[prompt_name] = content + _, input_path = prepare_graph_index_path(graph_id) input_path.mkdir(parents=True, exist_ok=True) @@ -302,6 +332,19 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever): Index = Param(help="The SQLAlchemy Index table") file_ids: list[str] = [] + search_type: str = "local" + + @classmethod + def get_user_settings(cls) -> dict: + return { + "search_type": { + "name": "Search type", + "value": "local", + "choices": ["local", "global", "hybrid"], + "component": "dropdown", + "info": "Whether to use local or global search in the graph.", + } + } def _build_graph_search(self): file_id = self.file_ids[0] @@ -326,7 +369,8 @@ def _build_graph_search(self): llm_func=llm_func, embedding_func=embedding_func, ) - query_params = QueryParam(mode="local", only_need_context=True) + print("search_type", self.search_type) + query_params = QueryParam(mode=self.search_type, only_need_context=True) return graphrag_func, query_params @@ -381,20 +425,40 @@ def run( return [] graphrag_func, query_params = self._build_graph_search() - entities, relationships, sources = asyncio.run( - lightrag_build_local_query_context(graphrag_func, text, query_params) - ) - documents = self.format_context_records(entities, relationships, sources) - plot = self.plot_graph(relationships) - - return documents + [ - RetrievedDocument( - text="", - metadata={ - "file_name": "GraphRAG", - "type": "plot", - "data": plot, - }, - ), - ] + # only local mode support graph visualization + if query_params.mode == "local": + entities, relationships, sources = asyncio.run( + lightrag_build_local_query_context(graphrag_func, text, query_params) + ) + documents = self.format_context_records(entities, relationships, sources) + plot = self.plot_graph(relationships) + documents += [ + RetrievedDocument( + text="", + metadata={ + "file_name": "GraphRAG", + "type": "plot", + "data": plot, + }, + ), + ] + else: + context = graphrag_func.query(text, query_params) + + # account for missing ``` for closing code block + context += "\n```" + + documents = [ + RetrievedDocument( + text=context, + metadata={ + "file_name": "GraphRAG {} Search".format( + query_params.mode.capitalize() + ), + "type": "table", + }, + ) + ] + + return documents diff --git a/libs/ktem/ktem/index/file/graph/nano_graph_index.py b/libs/ktem/ktem/index/file/graph/nano_graph_index.py index d6ec1317f..064c46008 100644 --- a/libs/ktem/ktem/index/file/graph/nano_graph_index.py +++ b/libs/ktem/ktem/index/file/graph/nano_graph_index.py @@ -1,6 +1,6 @@ from typing import Any -from ..base import BaseFileIndexRetriever +from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever from .graph_index import GraphRAGIndex from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline @@ -12,14 +12,32 @@ def _setup_indexing_cls(self): def _setup_retriever_cls(self): self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline] + def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing: + pipeline = super().get_indexing_pipeline(settings, user_id) + # indexing settings + prefix = f"index.options.{self.id}." + striped_settings = { + key[len(prefix) :]: value + for key, value in settings.items() + if key.startswith(prefix) + } + # set the prompts + pipeline.prompts = striped_settings + return pipeline + def get_retriever_pipelines( self, settings: dict, user_id: int, selected: Any = None ) -> list["BaseFileIndexRetriever"]: _, file_ids, _ = selected + # retrieval settings + prefix = f"index.options.{self.id}." + search_type = settings.get(prefix + "search_type", "local") + retrievers = [ NanoGraphRAGRetrieverPipeline( file_ids=file_ids, Index=self._resources["Index"], + search_type=search_type, ) ] diff --git a/libs/ktem/ktem/index/file/graph/nano_pipelines.py b/libs/ktem/ktem/index/file/graph/nano_pipelines.py index bfee52286..bbfdf26be 100644 --- a/libs/ktem/ktem/index/file/graph/nano_pipelines.py +++ b/libs/ktem/ktem/index/file/graph/nano_pipelines.py @@ -71,7 +71,7 @@ async def llm_func( if if_cache_return is not None: return if_cache_return["return"] - output = model(input_messages).text + output = (await model.ainvoke(input_messages)).text print("-" * 50) print(output, "\n", "-" * 50) @@ -216,7 +216,37 @@ def build_graphrag(working_dir, llm_func, embedding_func): class NanoGraphRAGIndexingPipeline(GraphRAGIndexingPipeline): """GraphRAG specific indexing pipeline""" + prompts: dict[str, str] = {} + + @classmethod + def get_user_settings(cls) -> dict: + try: + from nano_graphrag.prompt import PROMPTS + + blacklist_keywords = ["default", "response", "process"] + return { + prompt_name: { + "name": f"Prompt for '{prompt_name}'", + "value": content, + "component": "text", + } + for prompt_name, content in PROMPTS.items() + if all( + keyword not in prompt_name.lower() for keyword in blacklist_keywords + ) + } + except ImportError as e: + print(e) + return {} + def call_graphrag_index(self, graph_id: str, docs: list[Document]): + from nano_graphrag.prompt import PROMPTS + + # modify the prompt if it is set in the settings + for prompt_name, content in self.prompts.items(): + if prompt_name in PROMPTS: + PROMPTS[prompt_name] = content + _, input_path = prepare_graph_index_path(graph_id) input_path.mkdir(parents=True, exist_ok=True) @@ -297,6 +327,19 @@ class NanoGraphRAGRetrieverPipeline(BaseFileIndexRetriever): Index = Param(help="The SQLAlchemy Index table") file_ids: list[str] = [] + search_type: str = "local" + + @classmethod + def get_user_settings(cls) -> dict: + return { + "search_type": { + "name": "Search type", + "value": "local", + "choices": ["local", "global"], + "component": "dropdown", + "info": "Whether to use local or global search in the graph.", + } + } def _build_graph_search(self): file_id = self.file_ids[0] @@ -321,7 +364,8 @@ def _build_graph_search(self): llm_func=llm_func, embedding_func=embedding_func, ) - query_params = QueryParam(mode="local", only_need_context=True) + print("search_type", self.search_type) + query_params = QueryParam(mode=self.search_type, only_need_context=True) return graphrag_func, query_params @@ -384,22 +428,43 @@ def run( return [] graphrag_func, query_params = self._build_graph_search() - entities, relationships, reports, sources = asyncio.run( - nano_graph_rag_build_local_query_context(graphrag_func, text, query_params) - ) - documents = self.format_context_records( - entities, relationships, reports, sources - ) - plot = self.plot_graph(relationships) - - return documents + [ - RetrievedDocument( - text="", - metadata={ - "file_name": "GraphRAG", - "type": "plot", - "data": plot, - }, - ), - ] + # only local mode support graph visualization + if query_params.mode == "local": + entities, relationships, reports, sources = asyncio.run( + nano_graph_rag_build_local_query_context( + graphrag_func, text, query_params + ) + ) + + documents = self.format_context_records( + entities, relationships, reports, sources + ) + plot = self.plot_graph(relationships) + + documents += [ + RetrievedDocument( + text="", + metadata={ + "file_name": "GraphRAG", + "type": "plot", + "data": plot, + }, + ), + ] + else: + context = graphrag_func.query(text, query_params) + + documents = [ + RetrievedDocument( + text=context, + metadata={ + "file_name": "GraphRAG {} Search".format( + query_params.mode.capitalize() + ), + "type": "table", + }, + ) + ] + + return documents diff --git a/libs/ktem/ktem/index/file/graph/pipelines.py b/libs/ktem/ktem/index/file/graph/pipelines.py index d1a12c677..31b491b45 100644 --- a/libs/ktem/ktem/index/file/graph/pipelines.py +++ b/libs/ktem/ktem/index/file/graph/pipelines.py @@ -180,7 +180,7 @@ def get_user_settings(cls) -> dict: "search_type": { "name": "Search type", "value": "local", - "choices": ["local", "global"], + "choices": ["local"], "component": "dropdown", "info": "Whether to use local or global search in the graph.", } diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index b74d641f0..f899a459e 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -106,6 +106,12 @@ def __init__(self, app): self.on_building_ui() def on_building_ui(self): + self.setting_save_btn = gr.Button( + "Save & Close", + variant="primary", + elem_classes=["right-button"], + elem_id="save-setting-btn", + ) if self._app.f_user_management: with gr.Tab("User settings"): self.user_tab() @@ -114,10 +120,6 @@ def on_building_ui(self): self.index_tab() self.reasoning_tab() - self.setting_save_btn = gr.Button( - "Save changes", variant="primary", scale=1, elem_classes=["right-button"] - ) - def on_subscribe_public_events(self): """ Subscribes to public events related to user management. @@ -177,6 +179,9 @@ def on_register_events(self): self.save_setting, inputs=[self._user_id] + self.components(), outputs=self._settings_state, + ).then( + lambda: gr.Tabs(selected="chat-tab"), + outputs=self._app.tabs, ) self._components["reasoning.use"].change( self.change_reasoning_mode, diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py index 9bab73186..c42b8ad66 100644 --- a/libs/ktem/ktem/utils/render.py +++ b/libs/ktem/ktem/utils/render.py @@ -49,7 +49,13 @@ def collapsible(header, content, open: bool = False) -> str: def table(text: str) -> str: """Render table from markdown format into HTML""" text = replace_mardown_header(text) - return markdown.markdown(text, extensions=["markdown.extensions.tables"]) + return markdown.markdown( + text, + extensions=[ + "markdown.extensions.tables", + "markdown.extensions.fenced_code", + ], + ) @staticmethod def preview(