Fix/llm bugs empty extraction (#1533)
* Add llm singleton and check for empty extraction

* Semver

* Tests and spellcheck

* Move the singletons to a proper place

* Leftover print

* Ruff
AlonsoGuevara authored Dec 18, 2024
1 parent f7cd155 commit cfe2082
Showing 4 changed files with 79 additions and 2 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241218221915558063.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction"
}
19 changes: 19 additions & 0 deletions graphrag/index/flows/extract_graph.py
@@ -52,6 +52,18 @@ async def extract_graph(
         num_threads=extraction_num_threads,
     )
 
+    if not _validate_data(entity_dfs):
+        error_msg = "Entity Extraction failed. No entities detected during extraction."
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
+    if not _validate_data(relationship_dfs):
+        error_msg = (
+            "Entity Extraction failed. No relationships detected during extraction."
+        )
+        callbacks.error(error_msg)
+        raise ValueError(error_msg)
+
     merged_entities = _merge_entities(entity_dfs)
     merged_relationships = _merge_relationships(relationship_dfs)
 
@@ -145,3 +157,10 @@ def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
{"name": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])


def _validate_data(df_list: list[pd.DataFrame]) -> bool:
"""Validate that the dataframe list is valid. At least one dataframe must contain data."""
return any(
len(df) > 0 for df in df_list
) # Check for len, not .empty, as the dfs have schemas in some cases
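A quick sketch of how the new guard behaves, using hypothetical dataframes (the helper is condensed to one line here, and the column names and rows are made up for illustration):

import pandas as pd

def _validate_data(df_list: list[pd.DataFrame]) -> bool:
    """At least one dataframe in the list must contain rows."""
    return any(len(df) > 0 for df in df_list)

# A dataframe can carry a schema (columns) while holding zero rows,
# which is why the guard counts rows rather than testing for absence.
schema_only = pd.DataFrame(columns=["title", "type"])
assert not _validate_data([schema_only])  # extraction produced nothing

with_rows = pd.DataFrame([{"title": "MICROSOFT", "type": "organization"}])
assert _validate_data([schema_only, with_rows])  # one non-empty df suffices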
18 changes: 16 additions & 2 deletions graphrag/index/llm/load_llm.py
@@ -24,6 +24,7 @@
 import graphrag.config.defaults as defs
 from graphrag.config.enums import LLMType
 from graphrag.config.models.llm_parameters import LLMParameters
+from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
 
 from .mock_llm import MockChatLLM
 
@@ -110,6 +111,10 @@ def load_llm(
     chat_only=False,
 ) -> ChatLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = ChatLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = config.type
 
@@ -119,7 +124,9 @@
             raise ValueError(msg)
 
         loader = loaders[llm_type]
-        return loader["load"](on_error, create_cache(cache, name), config)
+        llm_instance = loader["load"](on_error, create_cache(cache, name), config)
+        ChatLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -134,15 +141,21 @@ def load_llm_embeddings(
     chat_only=False,
 ) -> EmbeddingsLLM:
     """Load the LLM for the entity extraction chain."""
+    singleton_llm = EmbeddingsLLMSingleton().get_llm(name)
+    if singleton_llm is not None:
+        return singleton_llm
+
     on_error = _create_error_handler(callbacks)
     llm_type = llm_config.type
     if llm_type in loaders:
         if chat_only and not loaders[llm_type]["chat"]:
             msg = f"LLM type {llm_type} does not support chat"
             raise ValueError(msg)
-        return loaders[llm_type]["load"](
+        llm_instance = loaders[llm_type]["load"](
             on_error, create_cache(cache, name), llm_config or {}
         )
+        EmbeddingsLLMSingleton().set_llm(name, llm_instance)
+        return llm_instance
 
     msg = f"Unknown LLM type {llm_type}"
     raise ValueError(msg)
@@ -198,6 +211,7 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
         n=config.n,
         temperature=config.temperature,
     )
+
     if azure:
         if config.api_base is None:
             msg = "Azure OpenAI Chat LLM requires an API base"
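Both loaders now share the same load-or-reuse flow: consult the per-name singleton first, and only build and register a new instance on a miss. A condensed sketch of that flow, where build_chat_llm is a hypothetical stand-in for the loader lookup in the diff above:

from graphrag.index.llm.manager import ChatLLMSingleton

def load_llm_sketch(name: str, build_chat_llm):
    # Reuse an instance already registered under this name, if any.
    existing = ChatLLMSingleton().get_llm(name)
    if existing is not None:
        return existing
    # Cache miss: build once, register for later calls, then return.
    instance = build_chat_llm(name)  # hypothetical builder, not from the commit
    ChatLLMSingleton().set_llm(name, instance)
    return instance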
40 changes: 40 additions & 0 deletions graphrag/index/llm/manager.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LLM Manager singleton."""

from functools import cache

from fnllm import ChatLLM, EmbeddingsLLM


@cache
class ChatLLMSingleton:
"""A singleton class for the chat LLM instances."""

def __init__(self):
self.llm_dict = {}

def set_llm(self, name, llm):
"""Add an LLM to the dictionary."""
self.llm_dict[name] = llm

def get_llm(self, name) -> ChatLLM | None:
"""Get an LLM from the dictionary."""
return self.llm_dict.get(name)


@cache
class EmbeddingsLLMSingleton:
"""A singleton class for the embeddings LLM instances."""

def __init__(self):
self.llm_dict = {}

def set_llm(self, name, llm):
"""Add an LLM to the dictionary."""
self.llm_dict[name] = llm

def get_llm(self, name) -> EmbeddingsLLM | None:
"""Get an LLM from the dictionary."""
return self.llm_dict.get(name)
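The singleton behavior comes from applying functools.cache to the class itself: the decorator memoizes the zero-argument call, so every ChatLLMSingleton() (or EmbeddingsLLMSingleton()) returns the one cached instance, and all callers share the same llm_dict. A self-contained check with a toy class standing in for the fnllm-typed ones above:

from functools import cache

@cache
class Registry:
    """Toy stand-in for the singleton classes above."""

    def __init__(self):
        self.items = {}

# cache memoizes the constructor call by its (empty) argument tuple,
# so repeated calls hand back the same object with shared state.
a = Registry()
b = Registry()
assert a is b

a.items["default_chat"] = "llm-instance"
assert b.items["default_chat"] == "llm-instance"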
