Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switches to chat conventions #7

Merged
merged 15 commits into from
Sep 11, 2024
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: mall
Title: Run multiple 'Large Language Model' predictions against a table, or
vectors
Version: 0.0.0.9003
Version: 0.0.0.9004
Authors@R:
person("Edgar", "Ruiz", , "first.last@example.com", role = c("aut", "cre"))
Description: Run multiple 'Large Language Model' predictions against a table. The
Expand All @@ -16,6 +16,7 @@ Imports:
cli,
dplyr,
glue,
jsonlite,
ollamar,
rlang
Suggests:
Expand Down
9 changes: 5 additions & 4 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ S3method(llm_sentiment,data.frame)
S3method(llm_summarize,"tbl_Spark SQL")
S3method(llm_summarize,data.frame)
S3method(llm_translate,data.frame)
S3method(m_backend_generate,mall_ollama)
S3method(m_backend_generate,mall_simulate_llm)
S3method(m_backend_prompt,mall_defaults)
S3method(m_backend_submit,mall_ollama)
S3method(m_backend_submit,mall_simulate_llm)
S3method(print,mall_defaults)
export(llm_classify)
export(llm_custom)
Expand All @@ -25,15 +25,16 @@ export(llm_vec_extract)
export(llm_vec_sentiment)
export(llm_vec_summarize)
export(llm_vec_translate)
export(m_backend_generate)
export(m_backend_prompt)
export(m_backend_submit)
import(cli)
import(glue)
import(rlang)
importFrom(dplyr,bind_cols)
importFrom(dplyr,mutate)
importFrom(dplyr,tibble)
importFrom(ollamar,generate)
importFrom(jsonlite,fromJSON)
importFrom(ollamar,chat)
importFrom(ollamar,list_models)
importFrom(ollamar,test_connection)
importFrom(utils,menu)
3 changes: 2 additions & 1 deletion R/llm-classify.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ llm_vec_classify <- function(x,
labels,
additional_prompt = "") {
llm_vec_prompt(
x = x, prompt_label = "classify",
x = x,
prompt_label = "classify",
additional_prompt = additional_prompt,
labels = labels,
valid_resps = labels
Expand Down
17 changes: 13 additions & 4 deletions R/llm-custom.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
llm_custom <- function(
.data,
col,
prompt,
prompt = "",
pred_name = ".pred",
valid_resps = "") {
UseMethod("llm_custom")
Expand All @@ -24,7 +24,7 @@ llm_custom <- function(
#' @export
llm_custom.data.frame <- function(.data,
col,
prompt,
prompt = "",
pred_name = ".pred",
valid_resps = NULL) {
mutate(
Expand All @@ -39,9 +39,18 @@ llm_custom.data.frame <- function(.data,

#' @rdname llm_custom
#' @export
llm_vec_custom <- function(x, prompt, valid_resps = NULL) {
llm_vec_custom <- function(x, prompt = "", valid_resps = NULL) {
llm_use(.silent = TRUE, force = FALSE)
resp <- m_backend_generate(defaults_get(), x, prompt)
if (!inherits(prompt, "list")) {
p_split <- strsplit(prompt, "\\{\\{x\\}\\}")[[1]]
if (length(p_split) == 1 && p_split == prompt) {
content <- glue("{prompt}\n{{x}}")
} else {
content <- prompt
}
prompt <- list(list(role = "user", content = content))
}
resp <- m_backend_submit(defaults_get(), x, prompt)
if (!is.null(valid_resps)) {
errors <- !resp %in% valid_resps
resp[errors] <- NA
Expand Down
8 changes: 6 additions & 2 deletions R/llm-extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ llm_extract.data.frame <- function(.data,
resp <- map(
resp,
\(x) ({
x <- trimws(strsplit(x, "\\|")[[1]])
x <- strsplit(x, "\\|")[[1]]
names(x) <- clean_names(labels)
x
})
Expand Down Expand Up @@ -76,10 +76,14 @@ llm_extract.data.frame <- function(.data,
llm_vec_extract <- function(x,
labels = c(),
additional_prompt = "") {
llm_vec_prompt(
resp <- llm_vec_prompt(
x = x,
prompt_label = "extract",
labels = labels,
additional_prompt = additional_prompt
)
map_chr(
resp,
\(x) paste0(as.character(fromJSON(x, flatten = TRUE)), collapse = "|")
)
}
114 changes: 73 additions & 41 deletions R/m-backend-prompt.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @rdname m_backend_generate
#' @rdname m_backend_submit
#' @export
m_backend_prompt <- function(backend, additional) {
UseMethod("m_backend_prompt")
Expand All @@ -9,60 +9,92 @@ m_backend_prompt.mall_defaults <- function(backend, additional = "") {
list(
sentiment = function(options) {
options <- paste0(options, collapse = ", ")
x <- glue(paste(
"You are a helpful sentiment engine.",
"Return only one of the following answers: {options}.",
"No capitalization. No explanations.",
additional,
"The answer is based on the following text:"
))
list(
list(
role = "user",
content = glue(paste(
"You are a helpful sentiment engine.",
"Return only one of the following answers: {options}.",
"No capitalization. No explanations.",
"{additional}",
"The answer is based on the following text:\n{{x}}"
))
)
)
},
summarize = function(max_words) {
glue(paste(
"You are a helpful summarization engine.",
"Your answer will contain no no capitalization and no explanations.",
"Return no more than {max_words} words.",
additional,
"The answer is the summary of the following text:"
))
list(
list(
role = "user",
content = glue(paste(
"You are a helpful summarization engine.",
"Your answer will contain no no capitalization and no explanations.",
"Return no more than {max_words} words.",
"{additional}",
"The answer is the summary of the following text:\n{{x}}"
))
)
)
},
classify = function(labels) {
labels <- paste0(labels, collapse = ", ")
glue(paste(
"You are a helpful classification engine.",
"Determine if the text refers to one of the following: {labels}.",
"No capitalization. No explanations.",
additional,
"The answer is based on the following text:"
))
list(
list(
role = "user",
content = glue(paste(
"You are a helpful classification engine.",
"Determine if the text refers to one of the following: {labels}.",
"No capitalization. No explanations.",
"{additional}",
"The answer is based on the following text:\n{{x}}"
))
)
)
},
extract = function(labels) {
no_labels <- length(labels)
labels <- paste0(labels, collapse = ", ")
glue(paste(
"You are a helpful text extraction engine.",
"Extract the {labels} being referred to on the text.",
"I expect {no_labels} item(s) exactly.",
"No capitalization. No explanations.",
"Return the response in a simple pipe separated list, no headers.",
additional,
"The answer is based on the following text:"
))
col_labels <- paste0(labels, collapse = ", ")
json_labels <- paste0("\"", labels, "\":your answer", collapse = ",")
json_labels <- paste0("{{", json_labels, "}}")
plural <- ifelse(no_labels > 1, "s", "")
list(
list(
role = "system",
content = "You only speak simple JSON. Do not write normal text."
),
list(
role = "user",
content = glue(paste(
"You are a helpful text extraction engine.",
"Extract the {col_labels} being referred to on the text.",
"I expect {no_labels} item{plural} exactly.",
"No capitalization. No explanations.",
"You will use this JSON this format exclusively: {json_labels} .",
"{additional}",
"The answer is based on the following text:\n{{x}}"
))
)
)
},
translate = function(language) {
glue(paste(
"You are a helpful translation engine.",
"You will return only the translation text, no explanations.",
"The target language to translate to is: {language}.",
additional,
"The answer is the summary of the following text:"
))
list(
list(
role = "user",
content = glue(paste(
"You are a helpful translation engine.",
"You will return only the translation text, no explanations.",
"The target language to translate to is: {language}.",
"{additional}",
"The answer is the summary of the following text:\n{{x}}"
))
)
)
}
)
}

get_prompt <- function(label, ..., .additional = "") {
defaults <- m_backend_prompt(defaults_get(), .additional)
defaults <- m_backend_prompt(defaults_get(), additional = .additional)
fn <- defaults[[label]]
fn(...)
}
Expand All @@ -75,5 +107,5 @@ llm_vec_prompt <- function(x,
...) {
llm_use(.silent = TRUE, force = FALSE)
prompt <- get_prompt(prompt_label, ..., .additional = additional_prompt)
llm_vec_custom(x, prompt, valid_resps)
llm_vec_custom(x, prompt, valid_resps = valid_resps)
}
16 changes: 8 additions & 8 deletions R/m-backend-generate.R → R/m-backend-submit.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,37 @@
#'
#' @param backend An `mall_defaults` object
#' @param x The body of the text to be submitted to the LLM
#' @param base_prompt The instructions to the LLM about what to do with `x`
#' @param prompt A list of chat messages, each a list with `role` and `content`, to submit to the LLM. Any `{x}` in a message is replaced with the text in `x`
#' @param additional Additional text to insert into the base prompts returned by `m_backend_prompt`
#'
#' @returns `m_backend_generate` does not return an object. `m_backend_prompt`
#' @returns `m_backend_submit` returns a character vector with one LLM
#' response per element of `x`. `m_backend_prompt` returns a list of functions
#' that contain the base prompts.
#'
#' @keywords internal
#' @export
m_backend_generate <- function(backend, x, base_prompt) {
UseMethod("m_backend_generate")
m_backend_submit <- function(backend, x, prompt) {
UseMethod("m_backend_submit")
}

#' @export
m_backend_generate.mall_ollama <- function(backend, x, base_prompt) {
m_backend_submit.mall_ollama <- function(backend, x, prompt) {
args <- as.list(backend)
args$backend <- NULL
map_chr(
x,
\(x) {
.args <- c(
prompt = glue("{base_prompt}\n{x}"),
messages = list(map(prompt, \(i) map(i, \(j) glue(j, x = x)))),
output = "text",
args
)
exec("generate", !!!.args)
exec("chat", !!!.args)
}
)
}

#' @export
m_backend_generate.mall_simulate_llm <- function(backend, x, base_prompt) {
m_backend_submit.mall_simulate_llm <- function(backend, x, base_prompt) {
args <- backend
class(args) <- "list"
if (args$model == "pipe") {
Expand Down
3 changes: 2 additions & 1 deletion R/mall.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#' @importFrom ollamar generate test_connection list_models
#' @importFrom ollamar chat test_connection list_models
#' @importFrom dplyr mutate tibble bind_cols
#' @importFrom utils menu
#' @importFrom jsonlite fromJSON
#' @import rlang
#' @import glue
#' @import cli
Expand Down
4 changes: 2 additions & 2 deletions man/llm_custom.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 9 additions & 9 deletions man/m_backend_generate.Rd → man/m_backend_submit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions tests/testthat/test-llm-classify.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
test_that("Classify works", {
test_text <- "this is a test"
llm_use("simulate_llm", "echo", .silent = TRUE)
llm_use("simulate_llm", "echo", .silent = TRUE)
expect_equal(
llm_vec_classify(test_text, labels = test_text),
test_text
Expand All @@ -14,14 +14,14 @@ test_that("Classify works", {
llm_classify(data.frame(x = test_text), x, labels = test_text),
data.frame(x = test_text, .classify = test_text)
)

expect_equal(
llm_classify(
data.frame(x = test_text),
x,
labels = test_text,
data.frame(x = test_text),
x,
labels = test_text,
pred_name = "new"
),
),
data.frame(x = test_text, new = test_text)
)
})
Loading