diff --git a/src/ragas/testset/testset_generator.py b/src/ragas/testset/testset_generator.py index 51ba8ab50..66f7633fa 100644 --- a/src/ragas/testset/testset_generator.py +++ b/src/ragas/testset/testset_generator.py @@ -57,7 +57,7 @@ "conditional": "_condition_question", } -DataRow = namedtuple("DataRow", ["question", "context", "answer", "question_type"]) +DataRow = namedtuple("DataRow", ["question", "ground_truth_context", "ground_truth", "question_type", "episode_done"]) @dataclass @@ -71,21 +71,14 @@ class TestDataset: def to_pandas(self) -> pd.DataFrame: data_samples = [] for data in self.test_data: - is_conv = len(data.context) > 1 - question_type = data.question_type - data = [ - { - "question": qstn, - "context": ctx, - "answer": ans, - "question_type": question_type, - "episode_done": True, - } - for qstn, ctx, ans in zip(data.question, data.context, data.answer) - ] - if is_conv: - data[0].update({"episode_done": False}) - data_samples.extend(data) + data = { + "question": data.question, + "ground_truth_context": data.ground_truth_context, + "ground_truth": data.ground_truth, + "question_type": data.question_type, + "episode_done": data.episode_done, + } + data_samples.append(data) return pd.DataFrame.from_records(data_samples) @@ -260,10 +253,10 @@ def _remove_nodes( return available_indices def _generate_doc_nodes_map( - self, documenet_nodes: t.List[BaseNode] + self, document_nodes: t.List[BaseNode] ) -> t.Dict[str, t.List[BaseNode]]: doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list) - for node in documenet_nodes: + for node in document_nodes: if node.ref_doc_id: doc_nodes_map[node.ref_doc_id].append(node) @@ -398,10 +391,13 @@ def generate( is_valid_question = self._filter_question(question) if is_valid_question: context = self._generate_context(question, text_chunk) + is_conv = len(context) > 1 answer = self._generate_answer(question, context) - samples.append( - DataRow(question.split("\n"), context, answer, evolve_type) - ) + for i, (qstn, ctx, ans) in enumerate(zip(question.split("\n"), context, answer)): + episode_done = False if is_conv and i==0 else True + samples.append( + DataRow(qstn, [ctx], [ans], evolve_type, episode_done) + ) count += 1 pbar.update(count)