Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix log-likelihood (llh) return object and add Jupyter visualization #281

Merged
merged 8 commits into from
Nov 13, 2023
22 changes: 21 additions & 1 deletion cohere/responses/loglikelihood.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Dict, List, NamedTuple, Optional

from cohere.responses.base import CohereObject, _df_html, _escape_html

class TokenLogLikelihood(NamedTuple):
    # One token together with its log-likelihood, as an immutable record.
    # Field order is part of the interface: positional construction and
    # tuple unpacking both rely on (encoded, decoded, log_likelihood).
    encoded: int  # integer token id
    decoded: str  # text form of the token
    log_likelihood: float  # log-likelihood value reported for this token


class LogLikelihoods:
class LogLikelihoods(CohereObject):
@staticmethod
def token_list_from_dict(token_list: Optional[List[Dict]]):
if token_list is None:
Expand All @@ -13,3 +15,21 @@ def token_list_from_dict(token_list: Optional[List[Dict]]):
def __init__(self, prompt_tokens: List[TokenLogLikelihood], completion_tokens: List[TokenLogLikelihood]):
    """Store the prompt and completion token lists.

    Each argument is passed through ``token_list_from_dict`` and the result
    is kept on the instance as ``prompt_tokens`` / ``completion_tokens``.

    Args:
        prompt_tokens: Token entries for the prompt.
        completion_tokens: Token entries for the completion.
    """
    # NOTE(review): token_list_from_dict is annotated as taking
    # Optional[List[Dict]], and visualize() checks these attributes against
    # None, so the List[TokenLogLikelihood] annotations above look too narrow.
    # Presumably callers pass raw dicts (or None) here -- confirm and adjust.
    self.prompt_tokens = self.token_list_from_dict(prompt_tokens)
    self.completion_tokens = self.token_list_from_dict(completion_tokens)

def visualize(self, **kwargs):
    """Render the prompt and completion token log-likelihoods as a table.

    For every token list that is not ``None`` a frame with three rows
    (``<label>.decoded``, ``<label>.encoded``, ``<label>.log_likelihood``)
    and one column per token is built; the frames are stacked vertically and
    handed to ``_df_html``.

    Args:
        **kwargs: Accepted for call compatibility; not used by this method.

    Returns:
        The result of ``_df_html`` for the combined table (presumably an HTML
        string -- confirm against ``cohere.responses.base``).

    Raises:
        ModuleNotFoundError: If pandas (or one of its dependencies) is not
            installed. This is a subclass of ``ImportError``, so callers that
            already catch ``ImportError`` keep working.
    """
    # Review feedback on this change: "We should handle gracefully if user
    # doesn't have pandas installed". pandas stays a lazy, optional import and
    # a missing install now produces an actionable message instead of a bare
    # import failure.
    try:
        import pandas as pd
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "LogLikelihoods.visualize() requires pandas, which could not be imported; "
            "install it with `pip install pandas`",
            name=err.name,
        ) from err

    token_lists = [
        ("prompt_tokens", self.prompt_tokens),
        ("completion_tokens", self.completion_tokens),
    ]
    # orient="index" turns each dict key into a row, so tokens run left to
    # right. Decoded text is passed through _escape_html before rendering.
    dfs = [
        pd.DataFrame.from_dict(
            {
                lbl + ".decoded": [_escape_html(t.decoded) for t in tokens],
                lbl + ".encoded": [t.encoded for t in tokens],
                lbl + ".log_likelihood": [t.log_likelihood for t in tokens],
            },
            orient="index",
        )
        for lbl, tokens in token_lists
        if tokens is not None
    ]

    # pd.concat raises ValueError on an empty list, which used to happen when
    # both token lists were None; fall back to an empty frame in that case.
    # fillna("") blanks the cells left empty when the lists differ in length.
    table = pd.concat(dfs, axis=0).fillna("") if dfs else pd.DataFrame()
    return _df_html(table, style={"font-size": "90%"})
Loading