Commit

Flow cleanup (#1510)
* Move snapshots out of flows into verbs

* Move degree compute out of extract_graph

* Move entity/relationship df merging into extract

* Move "title" to extraction source

* Move text_unit_ids agg closer to extraction

* Move data definition

* Update test data

* Semver

* Update smoke tests

* Fix empty degree field and update smoke tests and verb data

* Move extractors (#1516)

* Consolidate graph embedding and umap

* Consolidate claim extraction

* Consolidate graph extractor

* Move graph utils

* Move summarizers

* Semver

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>

* Fix syntax typo

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
natoverse and AlonsoGuevara authored Dec 19, 2024
1 parent d0543d1 commit c1c09ba
Showing 33 changed files with 142 additions and 172 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241212190223784600.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Streamline flows."
}
14 changes: 1 addition & 13 deletions graphrag/index/flows/compute_communities.py
@@ -9,15 +9,11 @@

from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def compute_communities(
def compute_communities(
base_relationship_edges: pd.DataFrame,
storage: PipelineStorage,
clustering_strategy: dict[str, Any],
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
).explode("title")
base_communities["community"] = base_communities["community"].astype(int)

if snapshot_transient_enabled:
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)

return base_communities
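The snapshot logic removed above does not disappear; per the commit message it moves up into the calling verb. A minimal sketch of what that caller might look like, assuming the verb keeps the storage object and snapshot_transient_enabled flag that were dropped from the flow (the wrapper name and argument order are illustrative):

from graphrag.index.flows.compute_communities import compute_communities
from graphrag.index.operations.snapshot import snapshot


async def compute_communities_verb(
    edges, storage, clustering_strategy, snapshot_transient_enabled=False
):
    # The flow is now a plain, synchronous DataFrame transform.
    base_communities = compute_communities(edges, clustering_strategy)
    # Writing the transient parquet snapshot is the verb's job after this commit.
    if snapshot_transient_enabled:
        await snapshot(
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )
    return base_communities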
20 changes: 2 additions & 18 deletions graphrag/index/flows/create_base_text_units.py
@@ -15,18 +15,14 @@
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.storage.pipeline_storage import PipelineStorage


async def create_base_text_units(
def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
storage: PipelineStorage,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ def create_base_text_units(
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)

output = cast(
"pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
)

if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_text_units",
storage=storage,
formats=["parquet"],
)

return output
return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))


# TODO: would be nice to inline this completely in the main method with pandas
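The flow now returns the filtered frame directly instead of snapshotting a transient copy. A standalone illustration of that closing pandas pattern, using made-up data:

import pandas as pd

chunked = pd.DataFrame({"id": [1, 2, 3], "text": ["alpha", None, "gamma"]})

# Chunks that produced no text are dropped and the index is rebuilt, matching
# chunked[chunked["text"].notna()].reset_index(drop=True) in the flow.
output = chunked[chunked["text"].notna()].reset_index(drop=True)
print(output)
#    id   text
# 0   1  alpha
# 1   3  gamma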
19 changes: 12 additions & 7 deletions graphrag/index/flows/create_final_nodes.py
@@ -10,6 +10,7 @@
VerbCallbacks,
)

from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
from graphrag.index.operations.layout_graph.layout_graph import layout_graph
@@ -37,15 +38,19 @@ def create_final_nodes(
layout_strategy,
embeddings=graph_embeddings,
)
nodes = base_entity_nodes.merge(
layout, left_on="title", right_on="label", how="left"
)

joined = nodes.merge(base_communities, on="title", how="left")
joined["level"] = joined["level"].fillna(0).astype(int)
joined["community"] = joined["community"].fillna(-1).astype(int)
degrees = compute_degree(graph)

return joined.loc[
nodes = (
base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
.merge(degrees, on="title", how="left")
.merge(base_communities, on="title", how="left")
)
nodes["level"] = nodes["level"].fillna(0).astype(int)
nodes["community"] = nodes["community"].fillna(-1).astype(int)
# disconnected nodes and those with no community even at level 0 can be missing degree
nodes["degree"] = nodes["degree"].fillna(0).astype(int)
return nodes.loc[
:,
[
"id",
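Degree is now joined onto the nodes from a separate degrees table, so entities that never appear in the relationship graph come out of the left join with a missing degree. A small sketch, with toy data, of why the new fillna(0) is needed:

import networkx as nx
import pandas as pd

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph([("a", "b"), ("b", "c")])
base_entity_nodes = pd.DataFrame({"title": ["a", "b", "c", "d"]})  # "d" is disconnected

degrees = compute_degree(graph)  # columns: title, degree
nodes = base_entity_nodes.merge(degrees, on="title", how="left")
# "d" has no row in the degree table, so the join leaves NaN; fill it with 0.
nodes["degree"] = nodes["degree"].fillna(0).astype(int)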
9 changes: 7 additions & 2 deletions graphrag/index/flows/create_final_relationships.py
@@ -5,20 +5,25 @@

import pandas as pd

from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
)
from graphrag.index.operations.create_graph import create_graph


def create_final_relationships(
base_relationship_edges: pd.DataFrame,
base_entity_nodes: pd.DataFrame,
) -> pd.DataFrame:
"""All the steps to transform final relationships."""
relationships = base_relationship_edges

graph = create_graph(base_relationship_edges)
degrees = compute_degree(graph)

relationships["combined_degree"] = compute_edge_combined_degree(
relationships,
base_entity_nodes,
degrees,
node_name_column="title",
node_degree_column="degree",
edge_source_column="source",
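Edge combined degree is now computed from the graph's own degree table rather than from base_entity_nodes. An illustrative pandas equivalent with made-up data (the real work is done by compute_edge_combined_degree):

import pandas as pd

relationships = pd.DataFrame({"source": ["a", "b"], "target": ["b", "c"]})
degrees = pd.DataFrame({"title": ["a", "b", "c"], "degree": [1, 2, 1]})

# combined_degree = degree(source) + degree(target) for each edge
source_deg = degrees.rename(columns={"title": "source", "degree": "source_degree"})
target_deg = degrees.rename(columns={"title": "target", "degree": "target_degree"})
merged = relationships.merge(source_deg, on="source").merge(target_deg, on="target")
merged["combined_degree"] = merged["source_degree"] + merged["target_degree"]
# a--b edge: 1 + 2 = 3; b--c edge: 2 + 1 = 3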
93 changes: 13 additions & 80 deletions graphrag/index/flows/extract_graph.py
@@ -6,41 +6,33 @@
from typing import Any
from uuid import uuid4

import networkx as nx
import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.storage.pipeline_storage import PipelineStorage


async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities(
entities, relationships = await extract_entities(
text_units,
callbacks,
cache,
@@ -52,87 +44,38 @@ async def extract_graph(
num_threads=extraction_num_threads,
)

if not _validate_data(entity_dfs):
if not _validate_data(entities):
error_msg = "Entity Extraction failed. No entities detected during extraction."
callbacks.error(error_msg)
raise ValueError(error_msg)

if not _validate_data(relationship_dfs):
if not _validate_data(relationships):
error_msg = (
"Entity Extraction failed. No relationships detected during extraction."
)
callbacks.error(error_msg)
raise ValueError(error_msg)

merged_entities = _merge_entities(entity_dfs)
merged_relationships = _merge_relationships(relationship_dfs)

entity_summaries, relationship_summaries = await summarize_descriptions(
merged_entities,
merged_relationships,
entities,
relationships,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)

base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)

graph = create_graph(base_relationship_edges)

base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)

if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
base_relationship_edges = _prep_edges(relationships, relationship_summaries)

if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
base_entity_nodes = _prep_nodes(entities, entity_summaries)

return (base_entity_nodes, base_relationship_edges)


def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["name", "type"], sort=False)
.agg({"description": list, "source_id": list})
.reset_index()
)


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg({"description": list, "source_id": list, "weight": "sum"})
.reset_index()
)


def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
degrees_df = _compute_degree(graph)
def _prep_nodes(entities, summaries) -> pd.DataFrame:
entities.drop(columns=["description"], inplace=True)
nodes = (
entities.merge(summaries, on="name", how="left")
.merge(degrees_df, on="name")
.drop_duplicates(subset="name")
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
subset="title"
)
nodes = nodes.loc[nodes["title"].notna()].reset_index()
nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
relationships.drop(columns=["description"])
.drop_duplicates(subset=["source", "target"])
.merge(summaries, on=["source", "target"], how="left")
.rename(columns={"source_id": "text_unit_ids"})
)
edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
return edges


def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
{"name": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])


def _validate_data(df_list: list[pd.DataFrame]) -> bool:
"""Validate that the dataframe list is valid. At least one dataframe must contain data."""
return any(
len(df) > 0 for df in df_list
) # Check for len, not .empty, as the dfs have schemas in some cases
def _validate_data(df: pd.DataFrame) -> bool:
"""Validate that the dataframe has data."""
return len(df) > 0
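Because extract_entities now returns already-merged DataFrames, the validator takes a single frame instead of a list of per-text-unit frames, and an empty frame that merely carries a schema no longer counts as data. A quick standalone check using assumed example frames:

import pandas as pd


def _validate_data(df: pd.DataFrame) -> bool:
    """Validate that the dataframe has data."""
    return len(df) > 0


empty_with_schema = pd.DataFrame(columns=["title", "type"])
assert _validate_data(empty_with_schema) is False  # a schema alone no longer counts
assert _validate_data(pd.DataFrame({"title": ["a"], "type": ["org"]})) is True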
3 changes: 1 addition & 2 deletions graphrag/index/flows/generate_text_embeddings.py
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
strategy=text_embed_config["strategy"],
)

data = data.loc[:, ["id", "embedding"]]

if snapshot_embeddings_enabled is True:
data = data.loc[:, ["id", "embedding"]]
await snapshot(
data,
name=f"embeddings.{name}",
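The id/embedding projection now runs only on the snapshot path, so downstream consumers otherwise keep the full frame produced by the embedding step. A standalone illustration with a made-up frame:

import pandas as pd

data = pd.DataFrame({
    "id": ["doc-1"],
    "text": ["some chunk"],
    "embedding": [[0.1, 0.2, 0.3]],
})

# Only the parquet snapshot needs the narrow id/embedding view.
snapshot_view = data.loc[:, ["id", "embedding"]]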
15 changes: 15 additions & 0 deletions graphrag/index/operations/compute_degree.py
@@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing create_graph definition."""

import networkx as nx
import pandas as pd


def compute_degree(graph: nx.Graph) -> pd.DataFrame:
"""Create a new DataFrame with the degree of each node in the graph."""
return pd.DataFrame([
{"title": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])
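A usage sketch of the new operation on a toy graph (assuming graphrag is importable); the output columns match what create_final_nodes and create_final_relationships now merge on:

import networkx as nx

from graphrag.index.operations.compute_degree import compute_degree

graph = nx.Graph()
graph.add_edges_from([("alpha", "beta"), ("beta", "gamma")])

degrees = compute_degree(graph)
print(degrees)
#    title  degree
# 0  alpha       1
# 1   beta       2
# 2  gamma       1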
29 changes: 27 additions & 2 deletions graphrag/index/operations/extract_entities/extract_entities.py
@@ -37,7 +37,7 @@ async def extract_entities(
async_mode: AsyncType = AsyncType.AsyncIO,
entity_types=DEFAULT_ENTITY_TYPES,
num_threads: int = 4,
) -> tuple[list[pd.DataFrame], list[pd.DataFrame]]:
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Extract entities from a piece of text.
@@ -138,7 +138,10 @@ async def run_strategy(row):
entity_dfs.append(pd.DataFrame(result[0]))
relationship_dfs.append(pd.DataFrame(result[1]))

return (entity_dfs, relationship_dfs)
entities = _merge_entities(entity_dfs)
relationships = _merge_relationships(relationship_dfs)

return (entities, relationships)


def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
@@ -162,3 +165,25 @@ def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
case _:
msg = f"Unknown strategy: {strategy_type}"
raise ValueError(msg)


def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["title", "type"], sort=False)
.agg(description=("description", list), text_unit_ids=("source_id", list))
.reset_index()
)


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg(
description=("description", list),
text_unit_ids=("source_id", list),
weight=("weight", "sum"),
)
.reset_index()
)
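The per-text-unit entity frames are now merged inside extract_entities itself, grouping on the renamed title column and collecting source ids under the new text_unit_ids column. A standalone illustration with toy frames:

import pandas as pd

entity_dfs = [
    pd.DataFrame({"title": ["alpha"], "type": ["org"], "description": ["first mention"], "source_id": ["tu-1"]}),
    pd.DataFrame({"title": ["alpha"], "type": ["org"], "description": ["second mention"], "source_id": ["tu-2"]}),
]

all_entities = pd.concat(entity_dfs, ignore_index=True)
entities = (
    all_entities.groupby(["title", "type"], sort=False)
    .agg(description=("description", list), text_unit_ids=("source_id", list))
    .reset_index()
)
# One row: title="alpha", type="org",
# description=["first mention", "second mention"], text_unit_ids=["tu-1", "tu-2"]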
@@ -106,7 +106,7 @@ async def run_extract_entities(
)

entities = [
({"name": item[0], **(item[1] or {})})
({"title": item[0], **(item[1] or {})})
for item in graph.nodes(data=True)
if item is not None
]
@@ -58,7 +58,7 @@ async def run( # noqa RUF029 async is required for interface

return EntityExtractionResult(
entities=[
{"type": entity_type, "name": name}
{"type": entity_type, "title": name}
for name, entity_type in entity_map.items()
],
relationships=[],
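The two strategy-level changes above rename the extracted entity key from "name" to "title" so it matches the column the flows merge and group on. A small sketch of the resulting dictionary shape, using a toy graph:

import networkx as nx

graph = nx.Graph()
graph.add_node("alpha", type="org", description="an example entity")

entities = [
    {"title": node, **(data or {})}
    for node, data in graph.nodes(data=True)
]
# [{'title': 'alpha', 'type': 'org', 'description': 'an example entity'}]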