feat: add chat feature
Daniel-AAlpha authored and pacman82 committed Oct 22, 2024
1 parent 75e0b76 commit 959338a
Showing 3 changed files with 174 additions and 5 deletions.
145 changes: 145 additions & 0 deletions src/chat.rs
@@ -0,0 +1,145 @@
use std::borrow::Cow;

use serde::{Deserialize, Serialize};

use crate::Task;

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Role {
    System,
    User,
    Assistant,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct Message<'a> {
    pub role: Role,
    pub content: Cow<'a, str>,
}

pub struct TaskChat<'a> {
    /// The list of messages comprising the conversation so far.
    pub messages: Vec<Message<'a>>,
    /// The maximum number of tokens to be generated. Completion will terminate after the maximum
    /// number of tokens is reached. Increase this value to allow for longer outputs. A text is
    /// split into tokens; usually there are more tokens than words. The maximum supported total of
    /// prompt tokens and maximum_tokens depends on the model.
    /// If maximum_tokens is set to None, no outside limit is imposed on the number of generated
    /// tokens. The model will generate tokens until it generates one of the specified
    /// stop_sequences or it reaches its technical limit, which usually is its context window.
    pub maximum_tokens: Option<u32>,
    /// A higher temperature encourages the model to produce less probable outputs ("be more
    /// creative"). Values are expected to be between 0 and 1. Try higher values for a more random
    /// ("creative") response.
    pub temperature: Option<f64>,
    /// Introduces random sampling for generated tokens by randomly selecting the next token from
    /// the smallest possible set of tokens whose cumulative probability exceeds the probability
    /// top_p. Set to 0 to get the same behaviour as `None`.
    pub top_p: Option<f64>,
}

impl<'a> TaskChat<'a> {
    /// Creates a new TaskChat containing one message with the given role and content.
    /// All optional TaskChat attributes are left unset.
    pub fn new(role: Role, content: impl Into<Cow<'a, str>>) -> Self {
        TaskChat {
            messages: vec![Message {
                role,
                content: content.into(),
            }],
            maximum_tokens: None,
            temperature: None,
            top_p: None,
        }
    }

    /// Pushes a new Message to this TaskChat.
    pub fn append_message(mut self, role: Role, content: impl Into<Cow<'a, str>>) -> Self {
        self.messages.push(Message {
            role,
            content: content.into(),
        });
        self
    }

    /// Sets the maximum_tokens attribute of this TaskChat.
    pub fn with_maximum_tokens(mut self, maximum_tokens: u32) -> Self {
        self.maximum_tokens = Some(maximum_tokens);
        self
    }

    /// Sets the temperature attribute of this TaskChat.
    pub fn with_temperature(mut self, temperature: f64) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the top_p attribute of this TaskChat.
    pub fn with_top_p(mut self, top_p: f64) -> Self {
        self.top_p = Some(top_p);
        self
    }
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct Choice<'a> {
    pub message: Message<'a>,
    pub finish_reason: String,
}

#[derive(Deserialize, Debug, PartialEq, Eq)]
pub struct ChatResponse<'a> {
    pub choices: Vec<Choice<'a>>,
}

#[derive(Serialize)]
struct ChatBody<'a> {
    /// Name of the model tasked with completing the prompt. E.g. `luminous-base`.
    pub model: &'a str,
    /// The list of messages comprising the conversation so far.
    messages: &'a [Message<'a>],
    /// Limits the number of tokens which are generated for the completion.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub maximum_tokens: Option<u32>,
    /// Controls the randomness of the model. Lower values will make the model more deterministic
    /// and higher values will make it more random. Mathematically, the temperature is used to
    /// divide the logits before sampling. A temperature of 0 will always return the most likely
    /// token. When no value is provided, the default value of 1 will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// "nucleus" parameter to dynamically adjust the number of choices for each predicted token
    /// based on the cumulative probabilities. It specifies a probability threshold, below which
    /// all less likely tokens are filtered out. When no value is provided, the default value of 1
    /// will be used.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
}

impl<'a> ChatBody<'a> {
    pub fn new(model: &'a str, task: &'a TaskChat) -> Self {
        Self {
            model,
            messages: &task.messages,
            maximum_tokens: task.maximum_tokens,
            temperature: task.temperature,
            top_p: task.top_p,
        }
    }
}

impl<'a> Task for TaskChat<'a> {
    type Output = Choice<'a>;

    type ResponseBody = ChatResponse<'a>;

    fn build_request(
        &self,
        client: &reqwest::Client,
        base: &str,
        model: &str,
    ) -> reqwest::RequestBuilder {
        let body = ChatBody::new(model, self);
        client.post(format!("{base}/chat/completions")).json(&body)
    }

    fn body_to_output(&self, mut response: Self::ResponseBody) -> Self::Output {
        response.choices.pop().unwrap()
    }
}
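
For context, here is a minimal sketch of how the new chat task could be driven from client code. It mirrors the integration test added further down in this commit; the API token placeholder, prompt texts, and token limit are illustrative and not part of the change.

use aleph_alpha_client::{Client, How, Role, TaskChat};

#[tokio::main]
async fn main() {
    // Build a conversation from a system instruction and a user question, capping the
    // answer length via the builder.
    let task = TaskChat::new(Role::System, "You are a helpful assistant.")
        .append_message(Role::User, "What is the capital of France?")
        .with_maximum_tokens(64);

    // Placeholder token; the integration test reads it from the environment instead.
    let client = Client::with_authentication("AA_API_TOKEN").unwrap();
    let choice = client
        .output_of(&task.with_model("pharia-1-llm-7b-control"), &How::default())
        .await
        .unwrap();

    // `choice` is the last Choice popped from the ChatResponse; print the assistant's reply.
    println!("{}", choice.message.content);
}
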
9 changes: 7 additions & 2 deletions src/lib.rs
@@ -23,6 +23,7 @@
//! }
//! ```

mod chat;
mod completion;
mod detokenization;
mod explanation;
@@ -31,14 +32,14 @@ mod image_preprocessing;
mod prompt;
mod semantic_embedding;
mod tokenization;

use std::time::Duration;

use http::HttpClient;
use semantic_embedding::{BatchSemanticEmbeddingOutput, SemanticEmbeddingOutput};
use tokenizers::Tokenizer;

pub use self::{
    chat::{Role, TaskChat},
    completion::{CompletionOutput, Sampling, Stopping, TaskCompletion},
    detokenization::{DetokenizationOutput, TaskDetokenization},
    explanation::{
@@ -305,7 +306,11 @@ impl Client {
.await
}

    pub async fn tokenizer_by_model(&self, model: &str, api_token: Option<String>) -> Result<Tokenizer, Error> {
    pub async fn tokenizer_by_model(
        &self,
        model: &str,
        api_token: Option<String>,
    ) -> Result<Tokenizer, Error> {
        self.http_client.tokenizer_by_model(model, api_token).await
    }
}
25 changes: 22 additions & 3 deletions tests/integration.rs
@@ -2,8 +2,8 @@ use std::{fs::File, io::BufReader, sync::OnceLock};

use aleph_alpha_client::{
    cosine_similarity, Client, Granularity, How, ImageScore, ItemExplanation, Modality, Prompt,
    PromptGranularity, Sampling, SemanticRepresentation, Stopping, Task,
    TaskBatchSemanticEmbedding, TaskCompletion, TaskDetokenization, TaskExplanation,
    PromptGranularity, Role, Sampling, SemanticRepresentation, Stopping, Task,
    TaskBatchSemanticEmbedding, TaskChat, TaskCompletion, TaskDetokenization, TaskExplanation,
    TaskSemanticEmbedding, TaskTokenization, TextScore,
};
use dotenv::dotenv;
@@ -18,6 +18,24 @@ fn api_token() -> &'static str {
    })
}

#[tokio::test]
async fn chat_with_pharia_1_7b_base() {
    // When
    let task = TaskChat::new(Role::System, "Instructions").append_message(Role::User, "Question");

    let model = "pharia-1-llm-7b-control";
    let client = Client::with_authentication(api_token()).unwrap();
    let response = client
        .output_of(&task.with_model(model), &How::default())
        .await
        .unwrap();

    eprintln!("{:?}", response.message);

    // Then
    assert!(!response.message.content.is_empty())
}

#[tokio::test]
async fn completion_with_luminous_base() {
// When
@@ -538,4 +556,5 @@ async fn fetch_tokenizer_for_pharia_1_llm_7b() {

    // Then
    assert_eq!(128_000, tokenizer.get_vocab_size(true));
}
