Skip to content

Commit

Permalink
squash! llama : return nullptr from llama_grammar_init
Browse files Browse the repository at this point in the history
Add checks for nullptr when calling llama_grammar_init.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
  • Loading branch information
danbev committed Jun 25, 2024
1 parent 6189bce commit 35d9680
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
12 changes: 10 additions & 2 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_

std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());

result->grammar = llama_grammar_init(
struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
result->grammar = grammar;
}

result->prev.resize(params.n_prev);
Expand Down Expand Up @@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());

ctx->grammar = llama_grammar_init(
struct llama_grammar * grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
ctx->grammar = grammar;
}

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
Expand Down
4 changes: 3 additions & 1 deletion examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ int main(int argc, char** argv) {
auto grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));

if (grammar == nullptr) {
throw std::runtime_error("Failed to initialize llama_grammar");
}
// Read the input file
std::string input_str;
{
Expand Down
4 changes: 4 additions & 0 deletions tests/test-llama-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ int main()
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
if (grammar == nullptr)
{
throw std::runtime_error("Failed to initialize llama_grammar");
}

std::vector<std::vector<llama_grammar_element>> expected_stacks = {
{
Expand Down

0 comments on commit 35d9680

Please sign in to comment.