-
Notifications
You must be signed in to change notification settings - Fork 745
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: improvements in test synthesization (#1621)
PR 2 of improvements in test generation --------- Co-authored-by: Jin Lin Tham <jltham18@gmail.com>
- Loading branch information
1 parent
5f74eb5
commit d840b16
Showing
32 changed files
with
1,336 additions
and
1,151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import typing as t | ||
|
||
from ragas.testset.graph import KnowledgeGraph, Node | ||
|
||
|
||
def get_child_nodes(node: Node, graph: KnowledgeGraph, level: int = 1) -> t.List[Node]: | ||
""" | ||
Get the child nodes of a given node up to a specified level. | ||
Parameters | ||
---------- | ||
node : Node | ||
The node to get the children of. | ||
graph : KnowledgeGraph | ||
The knowledge graph containing the node. | ||
level : int | ||
The maximum level to which child nodes are searched. | ||
Returns | ||
------- | ||
List[Node] | ||
The list of child nodes up to the specified level. | ||
""" | ||
children = [] | ||
|
||
# Helper function to perform depth-limited search for child nodes | ||
def dfs(current_node: Node, current_level: int): | ||
if current_level > level: | ||
return | ||
for rel in graph.relationships: | ||
if rel.source == current_node and rel.type == "child": | ||
children.append(rel.target) | ||
dfs(rel.target, current_level + 1) | ||
|
||
# Start DFS from the initial node at level 0 | ||
dfs(node, 1) | ||
|
||
return children |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,29 @@ | ||
import typing as t | ||
|
||
from ragas.llms import BaseRagasLLM | ||
|
||
from .abstract_query import ( | ||
AbstractQuerySynthesizer, | ||
ComparativeAbstractQuerySynthesizer, | ||
from ragas.testset.synthesizers.multi_hop import ( | ||
MultiHopAbstractQuerySynthesizer, | ||
MultiHopSpecificQuerySynthesizer, | ||
) | ||
from ragas.testset.synthesizers.single_hop.specific import ( | ||
SingleHopSpecificQuerySynthesizer, | ||
) | ||
|
||
from .base import BaseSynthesizer | ||
from .base_query import QuerySynthesizer | ||
from .specific_query import SpecificQuerySynthesizer | ||
|
||
QueryDistribution = t.List[t.Tuple[BaseSynthesizer, float]] | ||
|
||
|
||
def default_query_distribution(llm: BaseRagasLLM) -> QueryDistribution: | ||
""" | ||
Default query distribution for the test set. | ||
By default, 25% of the queries are generated using `AbstractQuerySynthesizer`, | ||
25% are generated using `ComparativeAbstractQuerySynthesizer`, and 50% are | ||
generated using `SpecificQuerySynthesizer`. | ||
""" | ||
""" """ | ||
return [ | ||
(AbstractQuerySynthesizer(llm=llm), 0.25), | ||
(ComparativeAbstractQuerySynthesizer(llm=llm), 0.25), | ||
(SpecificQuerySynthesizer(llm=llm), 0.5), | ||
(SingleHopSpecificQuerySynthesizer(llm=llm), 0.5), | ||
(MultiHopAbstractQuerySynthesizer(llm=llm), 0.25), | ||
(MultiHopSpecificQuerySynthesizer(llm=llm), 0.25), | ||
] | ||
|
||
|
||
__all__ = [ | ||
"BaseSynthesizer", | ||
"QuerySynthesizer", | ||
"AbstractQuerySynthesizer", | ||
"ComparativeAbstractQuerySynthesizer", | ||
"SpecificQuerySynthesizer", | ||
"default_query_distribution", | ||
] |
Oops, something went wrong.