Skip to content

Commit

Permalink
Make testgenerator output compatible with evaluate (#302)
Browse files Browse the repository at this point in the history
* Changed context from str to List[str] so that it is consistent with
eval. Now output of TestDataset can be used for evaluation.
*  Changed typo in _generate_doc_nodes_map
* Changed TestDataset class to reflect the changes in test set
generation. Drawback is episode_done will be True in all cases as data
is changed at the level above.
  • Loading branch information
tinomaxthayil authored Nov 21, 2023
1 parent 1cfbaa1 commit 3d29c44
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions src/ragas/testset/testset_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"conditional": "_condition_question",
}

DataRow = namedtuple("DataRow", ["question", "context", "answer", "question_type"])
DataRow = namedtuple("DataRow", ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"])


@dataclass
Expand All @@ -71,21 +71,14 @@ class TestDataset:
def to_pandas(self) -> pd.DataFrame:
data_samples = []
for data in self.test_data:
is_conv = len(data.context) > 1
question_type = data.question_type
data = [
{
"question": qstn,
"context": ctx,
"answer": ans,
"question_type": question_type,
"episode_done": True,
}
for qstn, ctx, ans in zip(data.question, data.context, data.answer)
]
if is_conv:
data[0].update({"episode_done": False})
data_samples.extend(data)
data = {
"question": data.question,
"ground_truth_context": data.ground_truth_context,
"ground_truth": data.ground_truth,
"question_type": data.question_type,
"episode_done": data.episode_done,
}
data_samples.append(data)

return pd.DataFrame.from_records(data_samples)

Expand Down Expand Up @@ -260,10 +253,10 @@ def _remove_nodes(
return available_indices

def _generate_doc_nodes_map(
self, documenet_nodes: t.List[BaseNode]
self, document_nodes: t.List[BaseNode]
) -> t.Dict[str, t.List[BaseNode]]:
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list)
for node in documenet_nodes:
for node in document_nodes:
if node.ref_doc_id:
doc_nodes_map[node.ref_doc_id].append(node)

Expand Down Expand Up @@ -398,10 +391,13 @@ def generate(
is_valid_question = self._filter_question(question)
if is_valid_question:
context = self._generate_context(question, text_chunk)
is_conv = len(context) > 1
answer = self._generate_answer(question, context)
samples.append(
DataRow(question.split("\n"), context, answer, evolve_type)
)
for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)):
episode_done = False if is_conv and i==0 else True
samples.append(
DataRow(qstn, [ctx], [ans], evolve_type, episode_done)
)
count += 1
pbar.update(count)

Expand Down

0 comments on commit 3d29c44

Please sign in to comment.