fix: remove is_async argument since everything is async now (#1116)
jjmachan authored Jul 22, 2024
1 parent 4427eca commit 89c2664
Showing 27 changed files with 106 additions and 252 deletions.
5 changes: 2 additions & 3 deletions src/experimental/ragas_experimental/testset/__init__.py
@@ -1,7 +1,6 @@
from ragas_experimental.testset.generators import SimpleTestGenerator
from ragas_experimental.testset.generators import QADistribution
from ragas_experimental.testset.generators import QADistribution, SimpleTestGenerator

__all__ = [
"SimpleTestGenerator",
"QADistribution",
]
]
@@ -4,7 +4,6 @@
from dataclasses import dataclass

from langchain_core.documents import Document as LCDocument

from ragas_experimental.testset.graph import Node


@@ -4,10 +4,6 @@
import numpy as np
import tiktoken
from langchain_core.documents import Document as LCDocument

from ragas.llms.base import BaseRagasLLM, llm_factory
from ragas.llms.json_load import json_loader
from ragas.llms.prompt import Prompt
from ragas_experimental.testset.extractors.base import Extractor
from ragas_experimental.testset.extractors.prompts import (
headline_extractor_prompt,
@@ -18,6 +14,10 @@
from ragas_experimental.testset.graph import Node
from ragas_experimental.testset.utils import MODEL_MAX_LENGTHS, merge_dicts

from ragas.llms.base import BaseRagasLLM, llm_factory
from ragas.llms.json_load import json_loader
from ragas.llms.prompt import Prompt


@dataclass
class LLMbasedExtractor(Extractor):
@@ -4,7 +4,6 @@
from dataclasses import dataclass

from langchain_core.documents import Document as LCDocument

from ragas_experimental.testset.extractors.base import Extractor, Regex
from ragas_experimental.testset.graph import Node

@@ -48,7 +48,8 @@ def generate(
docs: t.Sequence[Document],
test_size: int,
distribution: QADistribution,
) -> TestDataset: ...
) -> TestDataset:
...

def generate_with_langchain_docs(
self,
@@ -11,10 +11,8 @@
summary_extractor,
title_extractor,
)
from ragas_experimental.testset.generators import (
QADistribution,
TestGenerator,
)
from ragas_experimental.testset.generators import QADistribution, TestGenerator
from ragas_experimental.testset.generators.base import TestDataset
from ragas_experimental.testset.graph import Node, NodeLevel
from ragas_experimental.testset.questions import (
DEFAULT_DISTRIBUTION,
@@ -29,11 +27,10 @@
)
from ragas_experimental.testset.splitters import HeadlineSplitter
from ragas_experimental.testset.utils import rng
from ragas_experimental.testset.generators.base import TestDataset

from ragas._analytics import TestsetGenerationEvent, track
from ragas.embeddings import embedding_factory
from ragas.executor import Executor
from ragas._analytics import TestsetGenerationEvent, track
from ragas.llms.base import llm_factory
from ragas.utils import check_if_sum_is_close

@@ -29,7 +29,9 @@ def form_relations(
properties
}}
}}
""".format(node_level=node_level.name)
""".format(
node_level=node_level.name
)
results = schema.execute(
query, context={"nodes": nodes, "relationships": relationships}
)
@@ -2,7 +2,6 @@
from dataclasses import dataclass

import numpy as np

from ragas_experimental.testset.graph import Node
from ragas_experimental.testset.relationships.base import Similarity

1 change: 0 additions & 1 deletion src/experimental/ragas_experimental/testset/utils.py
@@ -1,7 +1,6 @@
import json

import numpy as np

from ragas_experimental.testset.graph import Node, NodeLevel, NodeType, Relationship

MODEL_MAX_LENGTHS = {
2 changes: 1 addition & 1 deletion src/ragas/callbacks.py
@@ -12,7 +12,7 @@


def new_group(
name: str, inputs: t.Dict, callbacks: Callbacks, is_async=False
name: str, inputs: t.Dict, callbacks: Callbacks
) -> t.Tuple[CallbackManagerForChainRun, CallbackManagerForChainGroup]:
# start evaluation chain
if isinstance(callbacks, list):
10 changes: 1 addition & 9 deletions src/ragas/evaluation.py
@@ -49,7 +49,6 @@ def evaluate(
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
callbacks: Callbacks = None,
in_ci: bool = False,
is_async: bool = True,
run_config: t.Optional[RunConfig] = None,
raise_exceptions: bool = True,
column_map: t.Optional[t.Dict[str, str]] = None,
@@ -81,11 +80,6 @@
Whether the evaluation is running in CI or not. If set to True then some
metrics will be run to increase the reproducability of the evaluations. This
will increase the runtime and cost of evaluations. Default is False.
is_async: bool
Whether to run the evaluation in async mode or not. If set to True then the
evaluation is run by calling the `metric.ascore` method. In case the llm or
embeddings does not support async then the evaluation can be run in sync mode
with `is_async=False`. Default is False.
run_config: RunConfig, optional
Configuration for runtime settings like timeout and retries. If not provided,
default values are used.
@@ -206,23 +200,21 @@ def evaluate(
# new evaluation chain
row_run_managers = []
evaluation_rm, evaluation_group_cm = new_group(
name="ragas evaluation", inputs={}, callbacks=callbacks, is_async=is_async
name="ragas evaluation", inputs={}, callbacks=callbacks
)
for i, row in enumerate(dataset):
row = t.cast(t.Dict[str, t.Any], row)
row_rm, row_group_cm = new_group(
name=f"row {i}",
inputs=row,
callbacks=evaluation_group_cm,
is_async=is_async,
)
row_run_managers.append((row_rm, row_group_cm))
[
executor.submit(
metric.ascore,
row,
row_group_cm,
is_async,
name=f"{metric.name}-{i}",
thread_timeout=run_config.thread_timeout,
)
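
For callers of evaluate, the change means the keyword simply disappears: metrics are always scored asynchronously through metric.ascore. A minimal, illustrative call-site sketch, not part of this commit; the dataset values and metric choice are placeholders, and a configured provider (e.g. OPENAI_API_KEY for the OpenAI defaults) is assumed:

# Illustrative call-site sketch, not part of this commit.
# Dataset values are placeholders; an LLM/embeddings provider is assumed
# to be configured (e.g. OPENAI_API_KEY for the OpenAI defaults).
from datasets import Dataset

from ragas import evaluate
from ragas.metrics import answer_relevancy, faithfulness

dataset = Dataset.from_dict(
    {
        "question": ["When did Mars Pathfinder launch?"],
        "answer": ["Mars Pathfinder launched in December 1996."],
        "contexts": [["Mars Pathfinder, carrying the Sojourner rover, launched on December 4, 1996."]],
        "ground_truth": ["December 1996"],
    }
)

# Before this commit: evaluate(dataset, metrics=[...], is_async=True)
# After this commit the keyword is gone; evaluation is always async internally.
result = evaluate(dataset, metrics=[faithfulness, answer_relevancy])
print(result)
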
6 changes: 3 additions & 3 deletions src/ragas/llms/prompt.py
@@ -1,9 +1,9 @@
from __future__ import annotations

import ast
import json
import logging
import os
import ast
import typing as t

from langchain_core.messages import BaseMessage, HumanMessage
@@ -233,8 +233,8 @@ def get_all_keys(nested_json):
example_dict[self.output_key] = json_loader._safe_load(example[-1], llm)
if example_dict[self.output_key] == {}:
# Extracting the dictionary part using string slicing
dict_str = example[-1].split('(')[0].strip()
example_dict[self.output_key ] = ast.literal_eval(dict_str)
dict_str = example[-1].split("(")[0].strip()
example_dict[self.output_key] = ast.literal_eval(dict_str)
else:
example_dict[self.output_key] = example[-1]
if self.output_type.lower() == "json":
13 changes: 4 additions & 9 deletions src/ragas/metrics/_answer_correctness.py
@@ -141,7 +141,6 @@ class AnswerCorrectnessClassification(BaseModel):

@dataclass
class AnswerCorrectness(MetricWithLLM, MetricWithEmbeddings):

"""
Measures answer correctness compared to ground truth as a combination of
factuality and semantic similarity.
@@ -211,16 +210,14 @@ def _create_statements_prompt(self, question: str, text: str) -> PromptValue:
)
return prompt_value

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "LLM must be set"

question = row["question"]
statements = {}
for item in ["answer", "ground_truth"]:
p_value = self._create_statements_prompt(question, row[item])
item_statement = await self.llm.generate(
p_value, callbacks=callbacks, is_async=is_async
)
item_statement = await self.llm.generate(p_value, callbacks=callbacks)
statements[item] = await _statements_output_parser.aparse(
item_statement.generations[0][0].text,
p_value,
@@ -247,9 +244,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
ground_truth=ground_truth,
answer=answer,
)
is_statement_present = await self.llm.generate(
p_value, callbacks=callbacks, is_async=is_async
)
is_statement_present = await self.llm.generate(p_value, callbacks=callbacks)
result_text = is_statement_present.generations[0][0].text

answers = await _output_parser.aparse(
@@ -268,7 +263,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
assert self.answer_similarity is not None, "AnswerSimilarity must be set"

similarity_score = await self.answer_similarity.ascore(
row, callbacks=callbacks, is_async=is_async
row, callbacks=callbacks
)

score = np.average(
3 changes: 1 addition & 2 deletions src/ragas/metrics/_answer_relevance.py
@@ -145,15 +145,14 @@ def _create_question_gen_prompt(self, row: t.Dict) -> PromptValue:
ans, ctx = row["answer"], row["contexts"]
return self.question_generation.format(answer=ans, context="\n".join(ctx))

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "LLM is not set"

prompt = self._create_question_gen_prompt(row)
result = await self.llm.generate(
prompt,
n=self.strictness,
callbacks=callbacks,
is_async=is_async,
)

answers = [
4 changes: 1 addition & 3 deletions src/ragas/metrics/_answer_similarity.py
@@ -49,9 +49,7 @@ def __post_init__(self: t.Self):
**self.embeddings.encode_kwargs,
}

async def _ascore(
self: t.Self, row: t.Dict, callbacks: Callbacks, is_async: bool
) -> float:
async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.embeddings is not None, "embeddings must be set"

ground_truth = t.cast(str, row["ground_truth"])
11 changes: 2 additions & 9 deletions src/ragas/metrics/_context_entities_recall.py
@@ -149,7 +149,6 @@ async def get_entities(
self,
text: str,
callbacks: Callbacks,
is_async: bool,
) -> t.Optional[ContextEntitiesResponse]:
assert self.llm is not None, "LLM is not initialized"
p_value = self.context_entity_recall_prompt.format(
@@ -158,7 +157,6 @@
result = await self.llm.generate(
prompt=p_value,
callbacks=callbacks,
is_async=is_async,
)

result_text = result.generations[0][0].text
@@ -174,15 +172,10 @@ async def _ascore(
self,
row: Dict,
callbacks: Callbacks,
is_async: bool,
) -> float:
ground_truth, contexts = row["ground_truth"], row["contexts"]
ground_truth = await self.get_entities(
ground_truth, callbacks=callbacks, is_async=is_async
)
contexts = await self.get_entities(
"\n".join(contexts), callbacks=callbacks, is_async=is_async
)
ground_truth = await self.get_entities(ground_truth, callbacks=callbacks)
contexts = await self.get_entities("\n".join(contexts), callbacks=callbacks)
if ground_truth is None or contexts is None:
return np.nan
return self._compute_score(ground_truth.entities, contexts.entities)
2 changes: 0 additions & 2 deletions src/ragas/metrics/_context_precision.py
@@ -151,7 +151,6 @@ async def _ascore(
self: t.Self,
row: t.Dict,
callbacks: Callbacks,
is_async: bool,
) -> float:
assert self.llm is not None, "LLM is not set"

@@ -161,7 +160,6 @@
results = await self.llm.generate(
hp,
callbacks=callbacks,
is_async=is_async,
n=self.reproducibility,
)
results = [
4 changes: 1 addition & 3 deletions src/ragas/metrics/_context_recall.py
@@ -109,7 +109,6 @@ def dicts(self) -> t.List[t.Dict]:

@dataclass
class ContextRecall(MetricWithLLM):

"""
Estimates context recall by estimating TP and FN using annotated answer and
retrieved context.
@@ -163,13 +162,12 @@ def _compute_score(self, response: t.Any) -> float:

return score

async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> float:
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
assert self.llm is not None, "set LLM before use"
p_value = self._create_context_recall_prompt(row)
results = await self.llm.generate(
p_value,
callbacks=callbacks,
is_async=is_async,
n=self.reproducibility,
)
results = [results.generations[0][i].text for i in range(self.reproducibility)]
6 changes: 1 addition & 5 deletions src/ragas/metrics/_faithfulness.py
@@ -236,9 +236,7 @@ def _compute_score(self, answers: StatementFaithfulnessAnswers):

return score

async def _ascore(
self: t.Self, row: t.Dict, callbacks: Callbacks, is_async: bool
) -> float:
async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
"""
returns the NLI score for each (q, c, a) pair
"""
@@ -248,7 +246,6 @@ async def _ascore(
statements = await self.llm.generate(
p_value,
callbacks=callbacks,
is_async=is_async,
)
statements = await _statements_output_parser.aparse(
statements.generations[0][0].text, p_value, self.llm, self.max_retries
@@ -266,7 +263,6 @@ async def _ascore(
nli_result = await self.llm.generate(
p_value,
callbacks=callbacks,
is_async=is_async,
n=self._reproducibility,
)
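
The same applies when scoring a single row against one metric: ascore no longer takes an is_async flag. A hedged sketch, again not from this commit; the row contents are invented and the metric's LLM is assumed to be wired up via llm_factory(), which requires OPENAI_API_KEY:

# Hedged sketch, not part of this commit: score one row directly with the
# async API. The row values are placeholders; llm_factory() assumes
# OPENAI_API_KEY is set in the environment.
import asyncio

from ragas.llms.base import llm_factory
from ragas.metrics import faithfulness

faithfulness.llm = llm_factory()

row = {
    "question": "When did Mars Pathfinder launch?",
    "answer": "It launched in December 1996.",
    "contexts": ["Mars Pathfinder launched on December 4, 1996."],
}


async def main() -> None:
    # Before this commit: await faithfulness.ascore(row, is_async=True)
    score = await faithfulness.ascore(row)
    print(score)


asyncio.run(main())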

