Skip to content

Commit

Permalink
feat(llms.json_load): Recursively load json lists (#593)
Browse files Browse the repository at this point in the history
Slightly broken JSON is protected against by the function
`ragas.llms.json_load.JsonLoader._find_outermost_json`. However, I've
found that for many metrics, GPT-4 often returns slightly broken JSON
lists, for which this function returns only the first valid JSON value. Here
we wrap `_find_outermost_json` with `_load_all_jsons`, which calls it
recursively to load the full JSON list.

E.g. the expected output for `'{"1":"2"}, ,, {"3":"4"}]'` is `[{'1': '2'},
{'3': '4'}]`

---------

Co-authored-by: jjmachan <jamesjithin97@gmail.com>
  • Loading branch information
pberger514 and jjmachan authored Feb 15, 2024
1 parent 3834fe5 commit 27e1c24
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/ragas/llms/json_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def _safe_load(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None):
retry = 0
while retry <= self.max_retries:
try:
start, end = self._find_outermost_json(text)
return json.loads(text[start:end])
_json = self._load_all_jsons(text)
return _json[0] if len(_json) == 1 else _json
except ValueError:
from ragas.llms.prompt import PromptValue

Expand All @@ -104,8 +104,8 @@ async def _asafe_load(
retry = 0
while retry <= self.max_retries:
try:
start, end = self._find_outermost_json(text)
return json.loads(text[start:end])
_json = self._load_all_jsons(text)
return _json[0] if len(_json) == 1 else _json
except ValueError:
from ragas.llms.prompt import PromptValue

Expand All @@ -126,7 +126,7 @@ async def safe_load(
callbacks: Callbacks = None,
is_async: bool = True,
run_config: RunConfig = RunConfig(),
):
) -> t.Union[t.Dict, t.List]:
if is_async:
_asafe_load_with_retry = add_async_retry(self._asafe_load, run_config)
return await _asafe_load_with_retry(text=text, llm=llm, callbacks=callbacks)
Expand All @@ -141,6 +141,16 @@ async def safe_load(
safe_load,
)

def _load_all_jsons(self, text: str) -> t.List:
    """Extract and parse every JSON value embedded in *text*.

    LLMs frequently return several JSON objects separated by commas or
    garbage instead of one well-formed list (e.g. ``'{"1":"2"}, ,, {"3":"4"}]'``).
    ``_find_outermost_json`` locates only the first parseable span, so we
    repeatedly extract that span and strip it from the text until no
    further span is found.

    Args:
        text: raw LLM output possibly containing one or more JSON values.

    Returns:
        A list of the parsed JSON values, in order of appearance.

    Raises:
        ValueError: (via ``json.JSONDecodeError``) when the first probe
            finds no JSON at all, or a located span fails to parse —
            callers rely on this to trigger their fix-and-retry loop.
    """
    results: t.List = []
    while True:
        start, end = self._find_outermost_json(text)
        # After at least one successful parse, a (-1, -1) sentinel means
        # the text is exhausted. On the very first probe we fall through
        # so json.loads('') raises ValueError, preserving the original
        # contract that "no JSON found" is an error, not an empty list.
        if results and (start, end) == (-1, -1):
            return results
        results.append(json.loads(text[start:end]))
        # Slice out exactly the span we just parsed. str.replace(..., 1)
        # would delete the *first* identical occurrence, which may sit
        # earlier in the garbage prefix and leave this span behind.
        text = text[:start] + text[end:]

def _find_outermost_json(self, text):
stack = []
start_index = -1
Expand Down
1 change: 1 addition & 0 deletions src/ragas/metrics/_context_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ async def _ascore(
await json_loader.safe_load(item, self.llm, is_async=is_async)
for item in responses
]
json_responses = t.cast(t.List[t.Dict], json_responses)
score = self._calculate_average_precision(json_responses)
return score

Expand Down
1 change: 1 addition & 0 deletions src/ragas/metrics/_faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ async def _ascore(
is_async=is_async,
)

assert isinstance(statements, dict), "Invalid JSON response"
p = self._create_nli_prompt(row, statements.get("statements", []))
nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
json_output = await json_loader.safe_load(
Expand Down

0 comments on commit 27e1c24

Please sign in to comment.