diff --git a/common/common.cpp b/common/common.cpp
index 57d03a5789edd..6a8973d9b2759 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -472,6 +472,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
+    if (arg == "--attention") {
+        CHECK_ARG
+        std::string value(argv[i]);
+        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
+        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL; }
+        else { invalid_param = true; }
+        return true;
+    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         CHECK_ARG
         params.defrag_thold = std::stof(argv[i]);
@@ -1454,8 +1462,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                 "For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead" });
 
     options.push_back({ "embedding" });
-    options.push_back({ "embedding", "       --pooling {none,mean,cls}",
+    options.push_back({ "embedding", "       --pooling {none,mean,cls,last}",
                                      "pooling type for embeddings, use model default if unspecified" });
+    options.push_back({ "embedding", "       --attention {causal,non-causal}",
+                                     "attention type for embeddings, use model default if unspecified" });
 
     options.push_back({ "context hacking" });
     options.push_back({ "*",           "    --rope-scaling {none,linear,yarn}",
@@ -2144,6 +2154,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
+    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
diff --git a/common/common.h b/common/common.h
index 0486ba3800ed7..a22f3fc893f42 100644
--- a/common/common.h
+++ b/common/common.h
@@ -99,6 +99,7 @@ struct gpt_params {
     enum llama_split_mode        split_mode        = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type      pooling_type      = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
+    enum llama_attention_type    attention_type    = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
 
     // sampling parameters
     struct llama_sampling_params sparams;
diff --git a/include/llama.h b/include/llama.h
index cafeafb85dbc7..a8eb0a88a8f5f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -179,6 +179,12 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };
 
+    enum llama_attention_type {
+        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
+        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
+        LLAMA_ATTENTION_TYPE_NON_CAUSAL  = 1,
+    };
+
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -296,6 +302,7 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
         enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
+        enum llama_attention_type    attention_type;    // attention type to use for embeddings
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;   // RoPE base frequency, 0 = from model
diff --git a/src/llama.cpp b/src/llama.cpp
index 988ed4fdfc55d..3ca40bcef0924 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -12715,7 +12715,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_mean);
@@ -12747,7 +12747,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -12768,7 +12768,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
+    if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
         const int64_t n_tokens = batch.n_tokens;
 
         GGML_ASSERT(lctx.inp_cls);
@@ -12990,14 +12990,15 @@ static int llama_decode_internal(
     std::vector<llama_seq_id *>            seq_id_arr;
     std::vector<std::vector<llama_seq_id>> seq_id;
 
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
     // count outputs
-    if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) {
-        n_outputs = n_tokens_all;
-    } else if (batch_all.logits) {
+    if (batch_all.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
             n_outputs += batch_all.logits[i] != 0;
         }
-    } else if (lctx.logits_all) {
+    } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
     } else {
         // keep last output only
@@ -13043,7 +13044,7 @@ static int llama_decode_internal(
         {
             int32_t n_outputs_new = 0;
 
-            if (u_batch.logits) {
+            if (u_batch.logits && !embd_pooled) {
                 for (uint32_t i = 0; i < n_tokens; i++) {
                     n_outputs_new += u_batch.logits[i] != 0;
                 }
@@ -17202,6 +17203,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch             =*/ GGML_DEFAULT_N_THREADS,
        /*.rope_scaling_type           =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type                =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
+        /*.attention_type              =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base              =*/ 0.0f,
         /*.rope_freq_scale             =*/ 0.0f,
         /*.yarn_ext_factor             =*/ -1.0f,
@@ -17448,7 +17450,6 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17458,6 +17459,12 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
+        cparams.causal_attn = hparams.causal_attn;
+    } else {
+        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
    }
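
Usage note (not part of the patch): a minimal sketch of how the new attention_type parameter is exercised through the C API. The model path "model.gguf" and the mean-pooling choice are placeholders; all functions called are existing llama.h API.

#include "llama.h"

int main(void) {
    llama_model_params mparams = llama_model_default_params();
    struct llama_model * model = llama_load_model_from_file("model.gguf" /* placeholder */, mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings     = true;
    cparams.pooling_type   = LLAMA_POOLING_TYPE_MEAN;
    // override the GGUF default and run the model with bidirectional attention;
    // LLAMA_ATTENTION_TYPE_UNSPECIFIED (the default) keeps hparams.causal_attn
    cparams.attention_type = LLAMA_ATTENTION_TYPE_NON_CAUSAL;

    struct llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, llama_decode(), then read the pooled embedding for a
    // sequence via llama_get_embeddings_seq(ctx, seq_id) ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}

From the command line the same override is exposed by the new flag added above, e.g. "--pooling mean --attention non-causal" on any example that uses the common argument parser.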