Skip to content

Commit

Permalink
implement loglikelihood endpoint (#277)
Browse files Browse the repository at this point in the history
* implement loglikelihood endpoint
* poetry update
  • Loading branch information
sanderland authored Aug 14, 2023
1 parent f60bd70 commit ee55ca9
Show file tree
Hide file tree
Showing 10 changed files with 280 additions and 177 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## 4.20.0

- [#276] (https://github.com/cohere-ai/cohere-python/pull/276)
- Add support for base_model option in create_custom_model
- [#277] (https://github.com/cohere-ai/cohere-python/pull/277)
- Add support for co.loglikelihood endpoint

## 4.19.1

- [#273] (https://github.com/cohere-ai/cohere-python/pull/273)
Expand Down
1 change: 1 addition & 0 deletions cohere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CHECK_API_KEY_URL = "check-api-key"
TOKENIZE_URL = "tokenize"
DETOKENIZE_URL = "detokenize"
LOGLIKELIHOOD_URL = "loglikelihood"

CLUSTER_JOBS_URL = "cluster-jobs"
EMBED_JOBS_URL = "embed-jobs"
Expand Down
20 changes: 20 additions & 0 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Codebook,
Detokenization,
Generations,
LogLikelihoods,
StreamingGenerations,
Tokens,
)
Expand Down Expand Up @@ -104,6 +105,25 @@ def check_api_key(self) -> Dict[str, bool]:
"""
return {"valid": is_api_key_valid(self.api_key)}

def loglikelihood(
self,
prompt: Optional[str] = None,
completion: Optional[str] = None,
model: Optional[str] = None,
) -> LogLikelihoods:
"""Calculates the token log-likelihood for a provided prompt and completion.
Using this endpoint instead of co.generate with max_tokens=0 will guarantee that any required tokens such as <EOP_TOKEN>
are correctly inserted, and makes it easier to retrieve only the completion log-likelihood.
Args:
prompt (str): The prompt
completion (str): (Optional) The completion
model (str): (Optional) The model to use for calculating the log-likelihoods
"""
json_body = {"model": model, "prompt": prompt, "completion": completion}
response = self._request(cohere.LOGLIKELIHOOD_URL, json=json_body)
return LogLikelihoods(response["prompt_tokens"], response["completion_tokens"])

def batch_generate(
self, prompts: List[str], return_exceptions=False, **kwargs
) -> List[Union[Generations, Exception]]:
Expand Down
11 changes: 11 additions & 0 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Generations,
LabelPrediction,
Language,
LogLikelihoods,
PreferenceRating,
Reranking,
StreamingGenerations,
Expand Down Expand Up @@ -143,6 +144,16 @@ async def check_api_key(self) -> Dict[str, bool]:
"""
return {"valid": is_api_key_valid(self.api_key)}

async def loglikelihood(
self,
prompt: Optional[str] = None,
completion: Optional[str] = None,
model: Optional[str] = None,
) -> LogLikelihoods:
json_body = {"model": model, "prompt": prompt, "completion": completion}
response = await self._request(cohere.LOGLIKELIHOOD_URL, json=json_body)
return LogLikelihoods(response["prompt_tokens"], response["completion_tokens"])

async def batch_generate(
self, prompts: List[str], return_exceptions=False, **kwargs
) -> List[Union[Exception, Generations]]:
Expand Down
1 change: 1 addition & 0 deletions cohere/responses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
PreferenceRating,
)
from cohere.responses.generation import Generation, Generations, StreamingGenerations
from cohere.responses.loglikelihood import LogLikelihoods
from cohere.responses.rerank import RerankDocument, Reranking, RerankResult
from cohere.responses.summarize import SummarizeResponse
from cohere.responses.tokenize import Detokenization, Tokens
15 changes: 15 additions & 0 deletions cohere/responses/loglikelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Dict, List, NamedTuple, Optional

TokenLogLikelihood = NamedTuple("TokenLogLikelihood", [("encoded", int), ("decoded", str), ("log_likelihood", float)])


class LogLikelihoods:
@staticmethod
def token_list_from_dict(token_list: Optional[List[Dict]]):
if token_list is None:
return None
return [TokenLogLikelihood(**token) for token in token_list]

def __init__(self, prompt_tokens: List[TokenLogLikelihood], completion_tokens: List[TokenLogLikelihood]):
self.prompt_tokens = self.token_list_from_dict(prompt_tokens)
self.completion_tokens = self.token_list_from_dict(completion_tokens)
352 changes: 176 additions & 176 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.19.3"
version = "4.20.0"
description = ""
authors = ["Cohere"]
readme = "README.md"
Expand Down
31 changes: 31 additions & 0 deletions tests/async/test_async_loglikelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from cohere.responses import LogLikelihoods

TEST_MODEL = "command-light"


@pytest.mark.asyncio
async def test_basic_llh_async(async_client):
resp = await async_client.loglikelihood(model=TEST_MODEL, prompt="co:here", completion="co:where?")
assert isinstance(resp, LogLikelihoods)
assert isinstance(resp.prompt_tokens[0].encoded, int)
assert isinstance(resp.prompt_tokens[0].decoded, str)
assert isinstance(resp.prompt_tokens[1].log_likelihood, float)

assert resp.prompt_tokens[0].decoded == "<BOS_TOKEN>"
assert resp.prompt_tokens[-1].decoded == "<EOP_TOKEN>"

assert isinstance(resp.completion_tokens[0].encoded, int)
assert isinstance(resp.completion_tokens[0].decoded, str)
assert isinstance(resp.completion_tokens[0].log_likelihood, float)

assert resp.completion_tokens[-1].decoded == "<EOS_TOKEN>"


@pytest.mark.asyncio
async def test_only_prompt_async_llh(async_client):
resp = await async_client.loglikelihood(model=TEST_MODEL, prompt="co:here")
assert isinstance(resp, LogLikelihoods)
assert isinstance(resp.prompt_tokens[0].encoded, int)
assert resp.completion_tokens is None
17 changes: 17 additions & 0 deletions tests/sync/test_loglikelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import unittest

from utils import get_api_key

import cohere
from cohere.responses import LogLikelihoods

TEST_MODEL = "command-light"
API_KEY = get_api_key()
co = cohere.Client(API_KEY)


class TestEmbed(unittest.TestCase):
def test_basic_llh(self):
resp = co.loglikelihood(model=TEST_MODEL, prompt="co:here", completion="co:where?")
assert isinstance(resp, LogLikelihoods)
assert isinstance(resp.prompt_tokens[0].encoded, int)

0 comments on commit ee55ca9

Please sign in to comment.