Graph collapse (#1464)
* Refactor graph creation

* Semver

* Spellcheck

* Update integ pipeline

* Fix cast

* Improve pandas chaining

* Cleaner apply

* Use list comprehensions

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
natoverse and AlonsoGuevara authored Dec 5, 2024
1 parent 756f5c3 commit d17dfd0
Showing 61 changed files with 444 additions and 1,192 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/minor-20241203220552914273.json
@@ -0,0 +1,4 @@
{
    "type": "minor",
    "description": "Refactor graph creation."
}
2 changes: 0 additions & 2 deletions docs/config/yaml.md
@@ -230,8 +230,6 @@ This is the base LLM configuration section. Other steps may override this config

- `embeddings` **bool** - Export embeddings snapshots to parquet.
- `graphml` **bool** - Export graph snapshots to GraphML.
- `raw_entities` **bool** - Export raw entity snapshots to JSON.
- `top_level_nodes` **bool** - Export top-level-node snapshots to JSON.
- `transient` **bool** - Export transient workflow tables snapshots to parquet.

### encoding_model
3 changes: 0 additions & 3 deletions graphrag/config/create_graphrag_config.py
@@ -409,9 +409,6 @@ def hydrate_parallelization_params(
        ):
            snapshots_model = SnapshotsConfig(
                graphml=reader.bool("graphml") or defs.SNAPSHOTS_GRAPHML,
                raw_entities=reader.bool("raw_entities") or defs.SNAPSHOTS_RAW_ENTITIES,
                top_level_nodes=reader.bool("top_level_nodes")
                or defs.SNAPSHOTS_TOP_LEVEL_NODES,
                embeddings=reader.bool("embeddings") or defs.SNAPSHOTS_EMBEDDINGS,
                transient=reader.bool("transient") or defs.SNAPSHOTS_TRANSIENT,
            )
2 changes: 0 additions & 2 deletions graphrag/config/defaults.py
@@ -79,8 +79,6 @@
REPORTING_TYPE = ReportingType.file
REPORTING_BASE_DIR = "logs"
SNAPSHOTS_GRAPHML = False
SNAPSHOTS_RAW_ENTITIES = False
SNAPSHOTS_TOP_LEVEL_NODES = False
SNAPSHOTS_EMBEDDINGS = False
SNAPSHOTS_TRANSIENT = False
STORAGE_BASE_DIR = "output"
2 changes: 0 additions & 2 deletions graphrag/config/init_content.py
@@ -115,8 +115,6 @@
snapshots:
  graphml: false
  raw_entities: false
  top_level_nodes: false
  embeddings: false
  transient: false
2 changes: 0 additions & 2 deletions graphrag/config/input_models/snapshots_config_input.py
@@ -11,6 +11,4 @@ class SnapshotsConfigInput(TypedDict):

    embeddings: NotRequired[bool | str | None]
    graphml: NotRequired[bool | str | None]
    raw_entities: NotRequired[bool | str | None]
    top_level_nodes: NotRequired[bool | str | None]
    transient: NotRequired[bool | str | None]
8 changes: 0 additions & 8 deletions graphrag/config/models/snapshots_config.py
@@ -19,14 +19,6 @@ class SnapshotsConfig(BaseModel):
        description="A flag indicating whether to take snapshots of GraphML.",
        default=defs.SNAPSHOTS_GRAPHML,
    )
    raw_entities: bool = Field(
        description="A flag indicating whether to take snapshots of raw entities.",
        default=defs.SNAPSHOTS_RAW_ENTITIES,
    )
    top_level_nodes: bool = Field(
        description="A flag indicating whether to take snapshots of top-level nodes.",
        default=defs.SNAPSHOTS_TOP_LEVEL_NODES,
    )
    transient: bool = Field(
        description="A flag indicating whether to take snapshots of transient tables.",
        default=defs.SNAPSHOTS_TRANSIENT,
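Note: with `raw_entities` and `top_level_nodes` removed, `SnapshotsConfig` keeps only three switches, each defaulting to `False` per the `defaults.py` change above. A minimal sketch of the reduced surface (import path assumed from this repository layout):

```python
from graphrag.config.models.snapshots_config import SnapshotsConfig

# graphml, embeddings, and transient are the only remaining flags; each
# defaults to False (SNAPSHOTS_* constants in graphrag/config/defaults.py).
cfg = SnapshotsConfig(graphml=True)
print(cfg.graphml, cfg.embeddings, cfg.transient)  # True False False
```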
6 changes: 2 additions & 4 deletions graphrag/index/create_pipeline_config.py
@@ -220,7 +220,6 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
            config={
                "snapshot_graphml": settings.snapshots.graphml,
                "snapshot_transient": settings.snapshots.transient,
                "snapshot_raw_entities": settings.snapshots.raw_entities,
                "entity_extract": {
                    **settings.entity_extraction.parallelization.model_dump(),
                    "async_mode": settings.entity_extraction.async_mode,
@@ -236,11 +235,9 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
                        settings.root_dir,
                    ),
                },
                "embed_graph_enabled": settings.embed_graph.enabled,
                "cluster_graph": {
                    "strategy": settings.cluster_graph.resolved_strategy()
                },
                "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
            },
        ),
        PipelineWorkflowReference(
@@ -255,7 +252,8 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
            name=create_final_nodes,
            config={
                "layout_graph_enabled": settings.umap.enabled,
                "snapshot_top_level_nodes": settings.snapshots.top_level_nodes,
                "embed_graph_enabled": settings.embed_graph.enabled,
                "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
            },
        ),
    ]
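Note: the embedding settings move out of the entity-graph workflow and into `create_final_nodes`, so graph embedding is now configured where the final nodes are produced. A hedged sketch of the resulting config shape (the `SimpleNamespace` stub stands in for a loaded `GraphRagConfig`; the strategy values are hypothetical):

```python
from types import SimpleNamespace

# Stand-in for a loaded GraphRagConfig (values are hypothetical).
settings = SimpleNamespace(
    umap=SimpleNamespace(enabled=False),
    embed_graph=SimpleNamespace(
        enabled=True,
        resolved_strategy=lambda: {"type": "node2vec", "num_walks": 10},
    ),
)

# Mirrors the create_final_nodes workflow config assembled in the diff above.
final_nodes_config = {
    "layout_graph_enabled": settings.umap.enabled,
    "embed_graph_enabled": settings.embed_graph.enabled,
    "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()},
}
print(final_nodes_config)
```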
164 changes: 97 additions & 67 deletions graphrag/index/flows/create_base_entity_graph.py
@@ -4,7 +4,9 @@
"""All the steps to create the base entity graph."""

from typing import Any, cast
from uuid import uuid4

import networkx as nx
import pandas as pd
from datashaper import (
    AsyncType,
@@ -13,12 +15,10 @@

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.merge_graphs import merge_graphs
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.snapshot_rows import snapshot_rows
from graphrag.index.operations.summarize_descriptions import (
    summarize_descriptions,
)
@@ -30,23 +30,20 @@ async def create_base_entity_graph(
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    extraction_strategy: dict[str, Any] | None = None,
    extraction_num_threads: int = 4,
    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
    entity_types: list[str] | None = None,
    node_merge_config: dict[str, Any] | None = None,
    edge_merge_config: dict[str, Any] | None = None,
    summarization_strategy: dict[str, Any] | None = None,
    summarization_num_threads: int = 4,
    embedding_strategy: dict[str, Any] | None = None,
    snapshot_graphml_enabled: bool = False,
    snapshot_raw_entities_enabled: bool = False,
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
) -> None:
    """All the steps to create the base entity graph."""
    # this returns a graph for each text unit, to be merged later
    entities, entity_graphs = await extract_entities(
    entity_dfs, relationship_dfs = await extract_entities(
        text_units,
        callbacks,
        cache,
@@ -55,89 +52,122 @@
        strategy=extraction_strategy,
        async_mode=extraction_async_mode,
        entity_types=entity_types,
        to="entities",
        num_threads=extraction_num_threads,
    )

    merged_graph = merge_graphs(
        entity_graphs,
        callbacks,
        node_operations=node_merge_config,
        edge_operations=edge_merge_config,
    )
    merged_entities = _merge_entities(entity_dfs)
    merged_relationships = _merge_relationships(relationship_dfs)

    summarized = await summarize_descriptions(
        merged_graph,
    entity_summaries, relationship_summaries = await summarize_descriptions(
        merged_entities,
        merged_relationships,
        callbacks,
        cache,
        strategy=summarization_strategy,
        num_threads=summarization_num_threads,
    )

    clustered = cluster_graph(
        summarized,
        callbacks,
        column="entity_graph",
    base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)

    graph = create_graph(base_relationship_edges)

    base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)

    communities = cluster_graph(
        graph,
        strategy=clustering_strategy,
        to="clustered_graph",
        level_to="level",
    )

    if embedding_strategy:
        clustered["embeddings"] = await embed_graph(
            clustered,
            callbacks,
            column="clustered_graph",
            strategy=embedding_strategy,
        )
    base_communities = _prep_communities(communities)

    if snapshot_raw_entities_enabled:
        await snapshot(
            entities,
            name="raw_extracted_entities",
            storage=storage,
            formats=["json"],
        )
    await runtime_storage.set("base_entity_nodes", base_entity_nodes)
    await runtime_storage.set("base_relationship_edges", base_relationship_edges)
    await runtime_storage.set("base_communities", base_communities)

    if snapshot_graphml_enabled:
        # todo: extract graphs at each level, and add in meta like descriptions
        await snapshot_graphml(
            merged_graph,
            name="merged_graph",
            graph,
            name="graph",
            storage=storage,
        )
        await snapshot_graphml(
            summarized,
            name="summarized_graph",

    if snapshot_transient_enabled:
        await snapshot(
            base_entity_nodes,
            name="base_entity_nodes",
            storage=storage,
            formats=["parquet"],
        )
        await snapshot_rows(
            clustered,
            column="clustered_graph",
            base_name="clustered_graph",
        await snapshot(
            base_relationship_edges,
            name="base_relationship_edges",
            storage=storage,
            formats=[{"format": "text", "extension": "graphml"}],
            formats=["parquet"],
        )
        if embedding_strategy:
            await snapshot_rows(
                clustered,
                column="entity_graph",
                base_name="embedded_graph",
                storage=storage,
                formats=[{"format": "text", "extension": "graphml"}],
            )

    final_columns = ["level", "clustered_graph"]
    if embedding_strategy:
        final_columns.append("embeddings")

    output = cast(pd.DataFrame, clustered[final_columns])

    if snapshot_transient_enabled:
        await snapshot(
            output,
            name="create_base_entity_graph",
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )

    return output

def _merge_entities(entity_dfs) -> pd.DataFrame:
    all_entities = pd.concat(entity_dfs, ignore_index=True)
    return (
        all_entities.groupby(["name", "type"], sort=False)
        .agg({"description": list, "source_id": list})
        .reset_index()
    )


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
    return (
        all_relationships.groupby(["source", "target"], sort=False)
        .agg({"description": list, "source_id": list, "weight": "sum"})
        .reset_index()
    )


def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
    degrees_df = _compute_degree(graph)
    entities.drop(columns=["description"], inplace=True)
    nodes = (
        entities.merge(summaries, on="name", how="left")
        .merge(degrees_df, on="name")
        .drop_duplicates(subset="name")
        .rename(columns={"name": "title", "source_id": "text_unit_ids"})
    )
    nodes = nodes.loc[nodes["title"].notna()].reset_index()
    nodes["human_readable_id"] = nodes.index
    nodes["id"] = nodes["human_readable_id"].apply(lambda _x: str(uuid4()))
    return nodes


def _prep_edges(relationships, summaries) -> pd.DataFrame:
    edges = (
        relationships.drop(columns=["description"])
        .drop_duplicates(subset=["source", "target"])
        .merge(summaries, on=["source", "target"], how="left")
        .rename(columns={"source_id": "text_unit_ids"})
    )
    edges["human_readable_id"] = edges.index
    edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
    return edges


def _prep_communities(communities) -> pd.DataFrame:
    base_communities = pd.DataFrame(
        communities, columns=cast(Any, ["level", "community", "title"])
    )
    base_communities = base_communities.explode("title")
    base_communities["community"] = base_communities["community"].astype(int)
    return base_communities


def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
    return pd.DataFrame([
        {"name": node, "degree": int(degree)} for node, degree in graph.degree
    ])  # type: ignore
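Note: the heart of this refactor is replacing the `merge_graphs` verb with plain pandas, then deriving node degree straight from the networkx graph. A self-contained sketch of that pattern on toy data (`create_graph` is approximated with `nx.from_pandas_edgelist`, which is an assumption; `graphrag` itself is not imported):

```python
import networkx as nx
import pandas as pd

# Two per-text-unit extraction results, shaped like extract_entities output.
entity_dfs = [
    pd.DataFrame({"name": ["A", "B"], "type": ["person", "org"],
                  "description": ["a1", "b1"], "source_id": ["t1", "t1"]}),
    pd.DataFrame({"name": ["A"], "type": ["person"],
                  "description": ["a2"], "source_id": ["t2"]}),
]
relationship_dfs = [
    pd.DataFrame({"source": ["A"], "target": ["B"], "description": ["knows"],
                  "source_id": ["t1"], "weight": [1.0]}),
    pd.DataFrame({"source": ["A"], "target": ["B"], "description": ["works with"],
                  "source_id": ["t2"], "weight": [2.0]}),
]

# Like _merge_entities: one row per (name, type); descriptions and sources collected.
entities = (
    pd.concat(entity_dfs, ignore_index=True)
    .groupby(["name", "type"], sort=False)
    .agg({"description": list, "source_id": list})
    .reset_index()
)

# Like _merge_relationships: one row per (source, target); weights summed.
relationships = (
    pd.concat(relationship_dfs, ignore_index=True)
    .groupby(["source", "target"], sort=False)
    .agg({"description": list, "source_id": list, "weight": "sum"})
    .reset_index()
)

# Like create_graph + _compute_degree: edge list in, per-node degree out.
graph = nx.from_pandas_edgelist(relationships)
degrees = pd.DataFrame([
    {"name": node, "degree": int(degree)} for node, degree in graph.degree
])

print(entities)       # A carries descriptions ["a1", "a2"] and sources ["t1", "t2"]
print(relationships)  # one A-B row with weight 3.0
print(degrees)        # A: 1, B: 1
```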