Skip to content

Commit

Permalink
feat: add graphrag modes (#574) #none
Browse files Browse the repository at this point in the history
* feat: add support for retrieval modes in LightRAG & NanoGraphRAG

* feat: expose custom prompts in LightRAG & NanoGraphRAG

* fix: optimize setting UI

* fix: update non local mode in LightRAG

* fix: update graphRAG mode
  • Loading branch information
taprosoft authored Dec 17, 2024
1 parent c667bf9 commit 1d3c4f4
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 47 deletions.
2 changes: 1 addition & 1 deletion flowsettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@

if USE_NANO_GRAPHRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.NanoGraphRAGIndex")
elif USE_LIGHTRAG:
if USE_LIGHTRAG:
GRAPHRAG_INDEX_TYPES.append("ktem.index.file.graph.LightRAGIndex")

KH_INDEX_TYPES = [
Expand Down
11 changes: 11 additions & 0 deletions libs/ktem/ktem/assets/css/main.css
Original file line number Diff line number Diff line change
Expand Up @@ -204,13 +204,24 @@ mark {
right: 15px;
}

/* prevent overflow of html info panel */
#html-info-panel {
overflow-x: auto !important;
}

#chat-expand-button {
position: absolute;
top: 6px;
right: -10px;
z-index: 1;
}

#save-setting-btn {
width: 150px;
height: 30px;
min-width: 100px !important;
}

#quick-setting-labels {
margin-top: 5px;
margin-bottom: -10px;
Expand Down
5 changes: 5 additions & 0 deletions libs/ktem/ktem/assets/js/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ function run() {
let chat_column = document.getElementById("main-chat-bot");
let conv_column = document.getElementById("conv-settings-panel");

// move setting close button
let setting_tab_nav_bar = document.querySelector("#settings-tab .tab-nav");
let setting_close_button = document.getElementById("save-setting-btn");
setting_tab_nav_bar.appendChild(setting_close_button);

let default_conv_column_min_width = "min(300px, 100%)";
conv_column.style.minWidth = default_conv_column_min_width

Expand Down
20 changes: 19 additions & 1 deletion libs/ktem/ktem/index/file/graph/light_graph_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .lightrag_pipelines import LightRAGIndexingPipeline, LightRAGRetrieverPipeline

Expand All @@ -12,14 +12,32 @@ def _setup_indexing_cls(self):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [LightRAGRetrieverPipeline]

def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
prefix = f"index.options.{self.id}."
striped_settings = {
key[len(prefix) :]: value
for key, value in settings.items()
if key.startswith(prefix)
}
# set the prompts
pipeline.prompts = striped_settings
return pipeline

def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")

retrievers = [
LightRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

Expand Down
100 changes: 82 additions & 18 deletions libs/ktem/ktem/index/file/graph/lightrag_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def llm_func(
if if_cache_return is not None:
return if_cache_return["return"]

output = model(input_messages).text
output = (await model.ainvoke(input_messages)).text

print("-" * 50)
print(output, "\n", "-" * 50)
Expand Down Expand Up @@ -220,7 +220,37 @@ def build_graphrag(working_dir, llm_func, embedding_func):
class LightRAGIndexingPipeline(GraphRAGIndexingPipeline):
"""GraphRAG specific indexing pipeline"""

prompts: dict[str, str] = {}

@classmethod
def get_user_settings(cls) -> dict:
try:
from lightrag.prompt import PROMPTS

blacklist_keywords = ["default", "response", "process"]
return {
prompt_name: {
"name": f"Prompt for '{prompt_name}'",
"value": content,
"component": "text",
}
for prompt_name, content in PROMPTS.items()
if all(
keyword not in prompt_name.lower() for keyword in blacklist_keywords
)
}
except ImportError as e:
print(e)
return {}

def call_graphrag_index(self, graph_id: str, docs: list[Document]):
from lightrag.prompt import PROMPTS

# modify the prompt if it is set in the settings
for prompt_name, content in self.prompts.items():
if prompt_name in PROMPTS:
PROMPTS[prompt_name] = content

_, input_path = prepare_graph_index_path(graph_id)
input_path.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -302,6 +332,19 @@ class LightRAGRetrieverPipeline(BaseFileIndexRetriever):

Index = Param(help="The SQLAlchemy Index table")
file_ids: list[str] = []
search_type: str = "local"

@classmethod
def get_user_settings(cls) -> dict:
return {
"search_type": {
"name": "Search type",
"value": "local",
"choices": ["local", "global", "hybrid"],
"component": "dropdown",
"info": "Whether to use local or global search in the graph.",
}
}

def _build_graph_search(self):
file_id = self.file_ids[0]
Expand All @@ -326,7 +369,8 @@ def _build_graph_search(self):
llm_func=llm_func,
embedding_func=embedding_func,
)
query_params = QueryParam(mode="local", only_need_context=True)
print("search_type", self.search_type)
query_params = QueryParam(mode=self.search_type, only_need_context=True)

return graphrag_func, query_params

Expand Down Expand Up @@ -381,20 +425,40 @@ def run(
return []

graphrag_func, query_params = self._build_graph_search()
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)

documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)

return documents + [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
# only local mode support graph visualization
if query_params.mode == "local":
entities, relationships, sources = asyncio.run(
lightrag_build_local_query_context(graphrag_func, text, query_params)
)
documents = self.format_context_records(entities, relationships, sources)
plot = self.plot_graph(relationships)
documents += [
RetrievedDocument(
text="",
metadata={
"file_name": "GraphRAG",
"type": "plot",
"data": plot,
},
),
]
else:
context = graphrag_func.query(text, query_params)

# account for missing ``` for closing code block
context += "\n```"

documents = [
RetrievedDocument(
text=context,
metadata={
"file_name": "GraphRAG {} Search".format(
query_params.mode.capitalize()
),
"type": "table",
},
)
]

return documents
20 changes: 19 additions & 1 deletion libs/ktem/ktem/index/file/graph/nano_graph_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from ..base import BaseFileIndexRetriever
from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .graph_index import GraphRAGIndex
from .nano_pipelines import NanoGraphRAGIndexingPipeline, NanoGraphRAGRetrieverPipeline

Expand All @@ -12,14 +12,32 @@ def _setup_indexing_cls(self):
def _setup_retriever_cls(self):
self._retriever_pipeline_cls = [NanoGraphRAGRetrieverPipeline]

def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
pipeline = super().get_indexing_pipeline(settings, user_id)
# indexing settings
prefix = f"index.options.{self.id}."
striped_settings = {
key[len(prefix) :]: value
for key, value in settings.items()
if key.startswith(prefix)
}
# set the prompts
pipeline.prompts = striped_settings
return pipeline

def get_retriever_pipelines(
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
_, file_ids, _ = selected
# retrieval settings
prefix = f"index.options.{self.id}."
search_type = settings.get(prefix + "search_type", "local")

retrievers = [
NanoGraphRAGRetrieverPipeline(
file_ids=file_ids,
Index=self._resources["Index"],
search_type=search_type,
)
]

Expand Down
Loading

0 comments on commit 1d3c4f4

Please sign in to comment.