Skip to content

Commit

Permalink
feat: improvements in test synthesization (#1621)
Browse files Browse the repository at this point in the history
PR 2 of improvements in test generation

---------

Co-authored-by: Jin Lin Tham <jltham18@gmail.com>
  • Loading branch information
shahules786 and jltham authored Nov 7, 2024
1 parent 5f74eb5 commit d840b16
Show file tree
Hide file tree
Showing 32 changed files with 1,336 additions and 1,151 deletions.
6 changes: 3 additions & 3 deletions docs/getstarted/rag_testset_generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ query_distribution = default_query_distribution(generator_llm)
```
```
[
(AbstractQuerySynthesizer(llm=generator_llm), 0.25),
(ComparativeAbstractQuerySynthesizer(llm=generator_llm), 0.25),
(SpecificQuerySynthesizer(llm=generator_llm), 0.5),
(SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
(MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
(MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
]
```

Expand Down
4 changes: 2 additions & 2 deletions docs/references/testset_schema.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
members:
- BaseScenario

::: ragas.testset.synthesizers.specific_query.SpecificQueryScenario
::: ragas.testset.synthesizers.single_hop.specific.SingleHopSpecificQuerySynthesizer
options:
show_root_heading: True
show_root_full_path: False

::: ragas.testset.synthesizers.abstract_query.AbstractQueryScenario
::: ragas.testset.synthesizers.multi_hop.specific.MultiHopSpecificQuerySynthesizer
options:
show_root_heading: True
show_root_full_path: False
2 changes: 2 additions & 0 deletions src/ragas/metrics/_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class DistanceMeasure(Enum):
LEVENSHTEIN = "levenshtein"
HAMMING = "hamming"
JARO = "jaro"
JARO_WINKLER = "jaro_winkler"


@dataclass
Expand Down Expand Up @@ -77,6 +78,7 @@ def __post_init__(self):
DistanceMeasure.LEVENSHTEIN: distance.Levenshtein,
DistanceMeasure.HAMMING: distance.Hamming,
DistanceMeasure.JARO: distance.Jaro,
DistanceMeasure.JARO_WINKLER: distance.JaroWinkler,
}

def init(self, run_config: RunConfig):
Expand Down
100 changes: 84 additions & 16 deletions src/ragas/testset/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,15 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return self.__repr__()

def find_clusters(
self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
def find_indirect_clusters(
self,
relationship_condition: t.Callable[[Relationship], bool] = lambda _: True,
depth_limit: int = 3,
) -> t.List[t.Set[Node]]:
"""
Finds clusters of nodes in the knowledge graph based on a relationship condition.
Finds indirect clusters of nodes in the knowledge graph based on a relationship condition.
Here if A -> B -> C -> D, then A, B, C, and D form a cluster. If there's also a path A -> B -> C -> E,
it will form a separate cluster.
Parameters
----------
Expand All @@ -223,31 +227,95 @@ def find_clusters(
A list of sets, where each set contains nodes that form a cluster.
"""
clusters = []
visited = set()
visited_paths = set()

relationships = [
rel for rel in self.relationships if relationship_condition(rel)
]

def dfs(node: Node, cluster: t.Set[Node]):
visited.add(node)
def dfs(node: Node, cluster: t.Set[Node], depth: int, path: t.Tuple[Node, ...]):
if depth >= depth_limit or path in visited_paths:
return
visited_paths.add(path)
cluster.add(node)

for rel in relationships:
if rel.source == node and rel.target not in visited:
dfs(rel.target, cluster)
# if the relationship is bidirectional, we need to check the reverse
neighbor = None
if rel.source == node and rel.target not in cluster:
neighbor = rel.target
elif (
rel.bidirectional
and rel.target == node
and rel.source not in visited
and rel.source not in cluster
):
dfs(rel.source, cluster)
neighbor = rel.source

if neighbor is not None:
dfs(neighbor, cluster.copy(), depth + 1, path + (neighbor,))

# Add completed path-based cluster
if len(cluster) > 1:
clusters.append(cluster)

for node in self.nodes:
if node not in visited:
cluster = set()
dfs(node, cluster)
if len(cluster) > 1:
initial_cluster = set()
dfs(node, initial_cluster, 0, (node,))

# Remove duplicates by converting clusters to frozensets
unique_clusters = [
set(cluster) for cluster in set(frozenset(c) for c in clusters)
]

return unique_clusters

def find_direct_clusters(
self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
) -> t.Dict[Node, t.List[t.Set[Node]]]:
"""
Finds direct clusters of nodes in the knowledge graph based on a relationship condition.
Here if A->B, and A->C, then A, B, and C form a cluster.
Parameters
----------
relationship_condition : Callable[[Relationship], bool], optional
A function that takes a Relationship and returns a boolean, by default lambda _: True
Returns
-------
List[Set[Node]]
A list of sets, where each set contains nodes that form a cluster.
"""

clusters = []
relationships = [
rel for rel in self.relationships if relationship_condition(rel)
]
for node in self.nodes:
cluster = set()
cluster.add(node)
for rel in relationships:
if rel.bidirectional:
if rel.source == node:
cluster.add(rel.target)
elif rel.target == node:
cluster.add(rel.source)
else:
if rel.source == node:
cluster.add(rel.target)

if len(cluster) > 1:
if cluster not in clusters:
clusters.append(cluster)

return clusters
# Remove subsets from clusters
unique_clusters = []
for cluster in clusters:
if not any(cluster < other for other in clusters):
unique_clusters.append(cluster)
clusters = unique_clusters

cluster_dict = {}
for cluster in clusters:
cluster_dict.update({cluster.pop(): cluster})

return cluster_dict
38 changes: 38 additions & 0 deletions src/ragas/testset/graph_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import typing as t

from ragas.testset.graph import KnowledgeGraph, Node


def get_child_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]:
"""
Get the child nodes of a given node up to a specified level.
Parameters
----------
node : Node
The node to get the children of.
graph : KnowledgeGraph
The knowledge graph containing the node.
level : int
The maximum level to which child nodes are searched.
Returns
-------
List[Node]
The list of child nodes up to the specified level.
"""
children = []

# Helper function to perform depth-limited search for child nodes
def dfs(current_node: Node, current_level: int):
if current_level > level:
return
for rel in graph.relationships:
if rel.source == current_node and rel.type == "child":
children.append(rel.target)
dfs(rel.target, current_level + 1)

# Start DFS from the initial node at level 0
dfs(node, 1)

return children
31 changes: 11 additions & 20 deletions src/ragas/testset/synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,29 @@
import typing as t

from ragas.llms import BaseRagasLLM

from .abstract_query import (
AbstractQuerySynthesizer,
ComparativeAbstractQuerySynthesizer,
from ragas.testset.synthesizers.multi_hop import (
MultiHopAbstractQuerySynthesizer,
MultiHopSpecificQuerySynthesizer,
)
from ragas.testset.synthesizers.single_hop.specific import (
SingleHopSpecificQuerySynthesizer,
)

from .base import BaseSynthesizer
from .base_query import QuerySynthesizer
from .specific_query import SpecificQuerySynthesizer

QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]]


def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution:
"""
Default query distribution for the test set.
By default, 25% of the queries are generated using `AbstractQuerySynthesizer`,
25% are generated using `ComparativeAbstractQuerySynthesizer`, and 50% are
generated using `SpecificQuerySynthesizer`.
"""
""" """
return [
(AbstractQuerySynthesizer(llm=llm), 0.25),
(ComparativeAbstractQuerySynthesizer(llm=llm), 0.25),
(SpecificQuerySynthesizer(llm=llm), 0.5),
(SingleHopSpecificQuerySynthesizer(llm=llm), 0.5),
(MultiHopAbstractQuerySynthesizer(llm=llm), 0.25),
(MultiHopSpecificQuerySynthesizer(llm=llm), 0.25),
]


__all__ = [
"BaseSynthesizer",
"QuerySynthesizer",
"AbstractQuerySynthesizer",
"ComparativeAbstractQuerySynthesizer",
"SpecificQuerySynthesizer",
"default_query_distribution",
]
Loading

0 comments on commit d840b16

Please sign in to comment.