Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional snippets to rerank response #288

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ def rerank(
model: str,
top_n: Optional[int] = None,
max_chunks_per_doc: Optional[int] = None,
snippet_extraction: Literal["disabled", "short", "medium", "long"] = "disabled",
) -> Reranking:
"""Returns an ordered list of documents ordered by their relevance to the provided query

Expand All @@ -659,7 +660,16 @@ def rerank(
model (str): The model to use for re-ranking
top_n (int): (optional) The number of results to return, defaults to returning all results
max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document
snippet_extraction (str): (optional) Snippet extraction mode:
- "disabled": do not run snippet extraction
- "short": return short snippets
- "medium": return medium-length snippets
- "long": return long snippets
"""
if snippet_extraction not in {"disabled", "short", "medium", "long"}:
raise CohereError(
message='invalid `snippet_extraction` mode, must be one of "disabled", "short", "medium", "long"'
)
parsed_docs = []
for doc in documents:
if isinstance(doc, str):
Expand All @@ -678,6 +688,7 @@ def rerank(
"top_n": top_n,
"return_documents": False,
"max_chunks_per_doc": max_chunks_per_doc,
"snippet_extraction": snippet_extraction,
}

reranking = Reranking(self._request(cohere.RERANK_URL, json=json_body))
Expand Down
11 changes: 11 additions & 0 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ async def rerank(
model: str,
top_n: Optional[int] = None,
max_chunks_per_doc: Optional[int] = None,
snippet_extraction: Literal["disabled", "short", "medium", "long"] = "disabled",
) -> Reranking:
"""Returns an ordered list of documents ordered by their relevance to the provided query

Expand All @@ -450,7 +451,16 @@ async def rerank(
model (str): The model to use for re-ranking
top_n (int): (optional) The number of results to return, defaults to returning all results
max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document
snippet_extraction (str): (optional) Snippet extraction mode:
- "disabled": do not run snippet extraction
- "short": return short snippets
- "medium": return medium-length snippets
- "long": return long snippets
"""
if snippet_extraction not in {"disabled", "short", "medium", "long"}:
raise CohereError(
message='invalid `snippet_extraction` mode, must be one of "disabled", "short", "medium", "long"'
)
parsed_docs = []
for doc in documents:
if isinstance(doc, str):
Expand All @@ -469,6 +479,7 @@ async def rerank(
"top_n": top_n,
"return_documents": False,
"max_chunks_per_doc": max_chunks_per_doc,
"snippet_extraction": snippet_extraction,
}
reranking = Reranking(await self._request(cohere.RERANK_URL, json=json_body))
for rank in reranking.results:
Expand Down
52 changes: 43 additions & 9 deletions cohere/responses/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,44 @@
"""


class RerankSnippet(NamedTuple("Snippet", [("text", str), ("start_index", int), ("end_index", int)])):
    """A snippet attached to a rerank result returned by co.rerank.

    Fields:
        text (str): the snippet text.
        start_index (int): start index reported for the snippet.
        end_index (int): end index reported for the snippet.
    """

    # A namedtuple subclass gets a per-instance ``__dict__`` unless it declares
    # empty slots; the fields already live in the tuple itself.
    __slots__ = ()

    def __repr__(self) -> str:
        return f"RerankSnippet<text: {self.text}, start_index: {self.start_index}, end_index: {self.end_index}>"


class RerankResult(CohereObject):
def __init__(
self, document: Dict[str, Any] = None, index: int = None, relevance_score: float = None, *args, **kwargs
self,
document: Dict[str, Any] = None,
index: int = None,
relevance_score: float = None,
snippets: List[RerankSnippet] = None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.document = document
self.snippets = snippets
self.index = index
self.relevance_score = relevance_score

def __repr__(self) -> str:
score = self.relevance_score
index = self.index
if self.document is None:
return f"RerankResult<index: {index}, relevance_score: {score}>"
else:
text = self.document["text"]
return f"RerankResult<document['text']: {text}, index: {index}, relevance_score: {score}>"
document_repr = ""
if self.document is not None:
document_repr = f", document['text']: {self.document['text']}"

snippet_repr = ""
if self.snippets is not None:
snippet_repr = f", snippets: {self.snippets}"

return f"RerankResult<index: {index}, relevance_score: {score}{document_repr}{snippet_repr}>"


class Reranking(CohereObject):
Expand All @@ -40,10 +61,23 @@ def __init__(
def _results(self, response: Dict[str, Any]) -> List[RerankResult]:
results = []
for res in response["results"]:
if "document" in res.keys():
results.append(RerankResult(res["document"], res["index"], res["relevance_score"]))
document = res.get("document")

if res.get("snippets") is not None:
snippets = [
RerankSnippet(
text=snippet["text"], start_index=snippet["start_index"], end_index=snippet["end_index"]
)
for snippet in res["snippets"]
]
else:
results.append(RerankResult(index=res["index"], relevance_score=res["relevance_score"]))
snippets = None

results.append(
RerankResult(
document=document, index=res["index"], relevance_score=res["relevance_score"], snippets=snippets
)
)
return results

def __str__(self) -> str:
Expand Down
Loading