diff --git a/cohere/client.py b/cohere/client.py index fe6c5ed81..8f4ce36a7 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -650,6 +650,7 @@ def rerank( model: str, top_n: Optional[int] = None, max_chunks_per_doc: Optional[int] = None, + snippet_extraction: Literal["disabled", "short", "medium", "long"] = "disabled", ) -> Reranking: """Returns an ordered list of documents ordered by their relevance to the provided query @@ -659,7 +660,16 @@ def rerank( model (str): The model to use for re-ranking top_n (int): (optional) The number of results to return, defaults to returning all results max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document + snippet_extraction (str): (optional) Snippet extraction mode: + - "disabled": do not run snippet extraction + - "short": return short snippets + - "medium": return medium-length snippets + - "long": return long snippets """ + if snippet_extraction not in {"disabled", "short", "medium", "long"}: + raise CohereError( + message='invalid `snippet_extraction` mode, must be one of "disabled", "short", "medium", "long"' + ) parsed_docs = [] for doc in documents: if isinstance(doc, str): @@ -678,6 +688,7 @@ def rerank( "top_n": top_n, "return_documents": False, "max_chunks_per_doc": max_chunks_per_doc, + "snippet_extraction": snippet_extraction, } reranking = Reranking(self._request(cohere.RERANK_URL, json=json_body)) diff --git a/cohere/client_async.py b/cohere/client_async.py index d64657391..c9800200f 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -441,6 +441,7 @@ async def rerank( model: str, top_n: Optional[int] = None, max_chunks_per_doc: Optional[int] = None, + snippet_extraction: Literal["disabled", "short", "medium", "long"] = "disabled", ) -> Reranking: """Returns an ordered list of documents ordered by their relevance to the provided query @@ -450,7 +451,16 @@ async def rerank( model (str): The model to use for re-ranking top_n (int): (optional) The number of results to return, defaults to returning all results max_chunks_per_doc (int): (optional) The maximum number of chunks derived from a document + snippet_extraction (str): (optional) Snippet extraction mode: + - "disabled": do not run snippet extraction + - "short": return short snippets + - "medium": return medium-length snippets + - "long": return long snippets """ + if snippet_extraction not in {"disabled", "short", "medium", "long"}: + raise CohereError( + message='invalid `snippet_extraction` mode, must be one of "disabled", "short", "medium", "long"' + ) parsed_docs = [] for doc in documents: if isinstance(doc, str): @@ -469,6 +479,7 @@ async def rerank( "top_n": top_n, "return_documents": False, "max_chunks_per_doc": max_chunks_per_doc, + "snippet_extraction": snippet_extraction, } reranking = Reranking(await self._request(cohere.RERANK_URL, json=json_body)) for rank in reranking.results: diff --git a/cohere/responses/rerank.py b/cohere/responses/rerank.py index 99319e41b..1e22b3b29 100644 --- a/cohere/responses/rerank.py +++ b/cohere/responses/rerank.py @@ -9,23 +9,44 @@ """ +class RerankSnippet(NamedTuple("Snippet", [("text", str), ("start_index", int), ("end_index", int)])): + """ + Returned by co.rerank, + object which contains `text`, `start_index` and `end_index` + """ + + def __repr__(self) -> str: + return f"RerankSnippet" + + class RerankResult(CohereObject): def __init__( - self, document: Dict[str, Any] = None, index: int = None, relevance_score: float = None, *args, **kwargs + self, + document: Dict[str, Any] = None, + index: int = None, + relevance_score: float = None, + snippets: List[RerankSnippet] = None, + *args, + **kwargs, ) -> None: super().__init__(*args, **kwargs) self.document = document + self.snippets = snippets self.index = index self.relevance_score = relevance_score def __repr__(self) -> str: score = self.relevance_score index = self.index - if self.document is None: - return f"RerankResult" - else: - text = self.document["text"] - return f"RerankResult" + document_repr = "" + if self.document is not None: + document_repr = f", document['text']: {self.document['text']}" + + snippet_repr = "" + if self.snippets is not None: + snippet_repr = f", snippets: {self.snippets}" + + return f"RerankResult" class Reranking(CohereObject): @@ -40,10 +61,23 @@ def __init__( def _results(self, response: Dict[str, Any]) -> List[RerankResult]: results = [] for res in response["results"]: - if "document" in res.keys(): - results.append(RerankResult(res["document"], res["index"], res["relevance_score"])) + document = res.get("document") + + if res.get("snippets") is not None: + snippets = [ + RerankSnippet( + text=snippet["text"], start_index=snippet["start_index"], end_index=snippet["end_index"] + ) + for snippet in res["snippets"] + ] else: - results.append(RerankResult(index=res["index"], relevance_score=res["relevance_score"])) + snippets = None + + results.append( + RerankResult( + document=document, index=res["index"], relevance_score=res["relevance_score"], snippets=snippets + ) + ) return results def __str__(self) -> str: