From 959338a374bcd291d57f3e768240e12b63ca7e7a Mon Sep 17 00:00:00 2001
From: Daniel Tscherwinka
Date: Tue, 22 Oct 2024 10:26:37 +0200
Subject: [PATCH] feat: add chat feature

---
 src/chat.rs          | 145 +++++++++++++++++++++++++++++++++++++++++++
 src/lib.rs           |   9 ++-
 tests/integration.rs |  25 +++++++-
 3 files changed, 174 insertions(+), 5 deletions(-)
 create mode 100644 src/chat.rs

diff --git a/src/chat.rs b/src/chat.rs
new file mode 100644
index 0000000..360c74f
--- /dev/null
+++ b/src/chat.rs
@@ -0,0 +1,145 @@
+use std::borrow::Cow;
+
+use serde::{Deserialize, Serialize};
+
+use crate::Task;
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+pub enum Role {
+    System,
+    User,
+    Assistant,
+}
+
+#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
+pub struct Message<'a> {
+    pub role: Role,
+    pub content: Cow<'a, str>,
+}
+
+pub struct TaskChat<'a> {
+    /// The list of messages comprising the conversation so far.
+    pub messages: Vec<Message<'a>>,
+    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
+    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is split
+    /// into tokens. Usually there are more tokens than words. The total number of tokens of prompt
+    /// and maximum_tokens depends on the model.
+    /// If maximum_tokens is set to None, no outside limit is imposed on the number of generated tokens.
+    /// The model will generate tokens until it generates one of the specified stop_sequences or it
+    /// reaches its technical limit, which usually is its context window.
+    pub maximum_tokens: Option<u32>,
+    /// A higher temperature encourages the model to produce less probable outputs ("be more creative").
+    /// Values are expected to be between 0 and 1. Try high values for a more random ("creative")
+    /// response.
+    pub temperature: Option<f64>,
+    /// Introduces random sampling for generated tokens by randomly selecting the next token from
+    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
+    /// top_p. Set to 0 to get the same behaviour as `None`.
+    pub top_p: Option<f64>,
+}
+
+impl<'a> TaskChat<'a> {
+    /// Creates a new TaskChat containing one message with the given role and content.
+    /// All optional TaskChat attributes are left unset.
+    pub fn new(role: Role, content: impl Into<Cow<'a, str>>) -> Self {
+        TaskChat {
+            messages: vec![Message {
+                role,
+                content: content.into(),
+            }],
+            maximum_tokens: None,
+            temperature: None,
+            top_p: None,
+        }
+    }
+
+    /// Pushes a new Message to this TaskChat.
+    pub fn append_message(mut self, role: Role, content: impl Into<Cow<'a, str>>) -> Self {
+        self.messages.push(Message {
+            role,
+            content: content.into(),
+        });
+        self
+    }
+
+    /// Sets the maximum_tokens attribute of this TaskChat.
+    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
+        self.maximum_tokens = Some(maximum_tokens);
+        self
+    }
+
+    /// Sets the temperature attribute of this TaskChat.
+    pub fn with_temperature(mut self, temperature: f64) -> Self {
+        self.temperature = Some(temperature);
+        self
+    }
+
+    /// Sets the top_p attribute of this TaskChat.
+    pub fn with_top_p(mut self, top_p: f64) -> Self {
+        self.top_p = Some(top_p);
+        self
+    }
+}
+
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+pub struct Choice<'a> {
+    pub message: Message<'a>,
+    pub finish_reason: String,
+}
+
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+pub struct ChatResponse<'a> {
+    pub choices: Vec<Choice<'a>>,
+}
+#[derive(Serialize)]
+struct ChatBody<'a> {
+    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
+    pub model: &'a str,
+    /// The list of messages comprising the conversation so far.
+    messages: &'a [Message<'a>],
+    /// Limits the number of tokens that are generated for the completion.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub maximum_tokens: Option<u32>,
+    /// Controls the randomness of the model. Lower values will make the model more deterministic and higher values will make it more random.
+    /// Mathematically, the temperature is used to divide the logits before sampling. A temperature of 0 will always return the most likely token.
+    /// When no value is provided, the default value of 1 will be used.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f64>,
+    /// "nucleus" parameter to dynamically adjust the number of choices for each predicted token based on the cumulative probabilities. It specifies a probability threshold, below which all less likely tokens are filtered out.
+    /// When no value is provided, the default value of 1 will be used.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f64>,
+}
+
+impl<'a> ChatBody<'a> {
+    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
+        Self {
+            model,
+            messages: &task.messages,
+            maximum_tokens: task.maximum_tokens,
+            temperature: task.temperature,
+            top_p: task.top_p,
+        }
+    }
+}
+
+impl<'a> Task for TaskChat<'a> {
+    type Output = Choice<'a>;
+
+    type ResponseBody = ChatResponse<'a>;
+
+    fn build_request(
+        &self,
+        client: &reqwest::Client,
+        base: &str,
+        model: &str,
+    ) -> reqwest::RequestBuilder {
+        let body = ChatBody::new(model, self);
+        client.post(format!("{base}/chat/completions")).json(&body)
+    }
+
+    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
+        response.choices.pop().unwrap()
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 2ab4dbc..044251f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -23,6 +23,7 @@
 //! }
 //! ```
 
+mod chat;
 mod completion;
 mod detokenization;
 mod explanation;
@@ -31,7 +32,6 @@ mod image_preprocessing;
 mod prompt;
 mod semantic_embedding;
 mod tokenization;
-
 use std::time::Duration;
 
 use http::HttpClient;
@@ -39,6 +39,7 @@ use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
 use tokenizers::Tokenizer;
 
 pub use self::{
+    chat::{Role, TaskChat},
     completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
     detokenization::{DetokenizationOutput, TaskDetokenization},
     explanation::{
@@ -305,7 +306,11 @@ impl Client {
         .await
     }
 
-    pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> {
+    pub async fn tokenizer_by_model(
+        &self,
+        model: &str,
+        api_token: Option<String>,
+    ) -> Result<Tokenizer, Error> {
         self.http_client.tokenizer_by_model(model, api_token).await
     }
 }
diff --git a/tests/integration.rs b/tests/integration.rs
index 11ab979..33a76b1 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -2,8 +2,8 @@ use std::{fs::File, io::BufReader, sync::OnceLock};
 
 use aleph_alpha_client::{
     cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
-    PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
-    TaskBatchSemanticEmbedding, TaskCompletion, TaskDetokenization, TaskExplanation,
+    PromptGranularity, Role, Sampling, SemanticRepresentation, Stopping, Task,
+    TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
     TaskSemanticEmbedding, TaskTokenization, TextScore,
 };
 use dotenv::dotenv;
@@ -18,6 +18,24 @@ fn api_token() -> &'static str {
     })
 }
 
+#[tokio::test]
+async fn chat_with_pharia_1_7b_base() {
+    // When
+    let task = TaskChat::new(Role::System, "Instructions").append_message(Role::User, "Question");
+
+    let model = "pharia-1-llm-7b-control";
+    let client = Client::with_authentication(api_token()).unwrap();
+    let response = client
+        .output_of(&task.with_model(model), &How::default())
+        .await
+        .unwrap();
+
+    eprintln!("{:?}", response.message);
+
+    // Then
+    assert!(!response.message.content.is_empty())
+}
+
 #[tokio::test]
 async fn completion_with_luminous_base() {
     // When
@@ -538,4 +556,5 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {
 
     // Then
     assert_eq!(128_000, tokenizer.get_vocab_size(true));
-}
\ No newline at end of file
+}
+
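
Usage sketch (not part of the diff): a minimal example of the chat API this patch adds, mirroring the integration test above. Only `TaskChat`, `Role`, the existing `Client`/`How`, and the `Task::with_model`/`output_of` plumbing exercised in the test are assumed; the prompt text, sampling values, model name, and the `#[tokio::main]` scaffolding are illustrative assumptions.

use aleph_alpha_client::{Client, How, Role, TaskChat};

#[tokio::main]
async fn main() {
    // Build a conversation: a system instruction followed by a user message.
    // The builder methods consume and return the task, so they can be chained.
    let task = TaskChat::new(Role::System, "You are a helpful assistant.")
        .append_message(Role::User, "What is the capital of France?")
        // Optional sampling parameters; unset fields are skipped when the body is serialized.
        .with_maximum_tokens(64)
        .with_temperature(0.7);

    // AA_API_TOKEN is assumed to hold a valid API token, as in the integration tests.
    let token = std::env::var("AA_API_TOKEN").expect("AA_API_TOKEN must be set");
    let client = Client::with_authentication(token.as_str()).unwrap();

    // `with_model` pairs the task with a model; `output_of` posts to /chat/completions
    // and yields the last `Choice` of the response.
    let choice = client
        .output_of(&task.with_model("pharia-1-llm-7b-control"), &How::default())
        .await
        .unwrap();

    println!("{}", choice.message.content);
}

Because the builder methods take `self` by value, a conversation can be extended turn by turn by re-binding the returned `TaskChat`, and `ChatBody` only serializes the optional fields that were actually set via `skip_serializing_if`.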