Skip to content

Commit

Permalink
fix llh return object and add jupyter visualization (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanderland authored Nov 13, 2023
1 parent 07a3863 commit 55666c6
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion cohere/responses/loglikelihood.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Dict, List, NamedTuple, Optional

from cohere.responses.base import CohereObject, _df_html, _escape_html

TokenLogLikelihood = NamedTuple("TokenLogLikelihood", [("encoded", int), ("decoded", str), ("log_likelihood", float)])


class LogLikelihoods:
class LogLikelihoods(CohereObject):
@staticmethod
def token_list_from_dict(token_list: Optional[List[Dict]]):
if token_list is None:
Expand All @@ -13,3 +15,21 @@ def token_list_from_dict(token_list: Optional[List[Dict]]):
def __init__(self, prompt_tokens: List[TokenLogLikelihood], completion_tokens: List[TokenLogLikelihood]):
self.prompt_tokens = self.token_list_from_dict(prompt_tokens)
self.completion_tokens = self.token_list_from_dict(completion_tokens)

def visualize(self, **kwargs):
import pandas as pd

dfs = []
for lbl, tokens in [("prompt_tokens", self.prompt_tokens), ("completion_tokens", self.completion_tokens)]:
if tokens is not None:
dfs.append(
pd.DataFrame.from_dict(
{
lbl + ".decoded": [_escape_html(t.decoded) for t in tokens],
lbl + ".encoded": [t.encoded for t in tokens],
lbl + ".log_likelihood": [t.log_likelihood for t in tokens],
},
orient="index",
)
)
return _df_html(pd.concat(dfs, axis=0).fillna(""), style={"font-size": "90%"})

0 comments on commit 55666c6

Please sign in to comment.