Skip to content

Commit

Permalink
Merge pull request #843 from stan-dev/fix-old-array-syntax
Browse files Browse the repository at this point in the history
Compatibility fixes for cmdstan 2.33+
  • Loading branch information
andrjohns authored Sep 13, 2023
2 parents 5e2551c + 1caa732 commit b29374e
Show file tree
Hide file tree
Showing 19 changed files with 61 additions and 188 deletions.
2 changes: 1 addition & 1 deletion R/example.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ print_example_program <-
#' stan_program <- "
#' data {
#' int<lower=0> N;
#' int<lower=0,upper=1> y[N];
#' array[N] int<lower=0,upper=1> y;
#' }
#' parameters {
#' real<lower=0,upper=1> theta;
Expand Down
4 changes: 2 additions & 2 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1158,7 +1158,7 @@ CmdStanFit$set("public", name = "return_codes", value = return_codes)
#' mcmc_program <- write_stan_file(
#' 'data {
#' int<lower=0> N;
#' int<lower=0,upper=1> y[N];
#' array[N] int<lower=0,upper=1> y;
#' }
#' parameters {
#' real<lower=0,upper=1> theta;
Expand All @@ -1169,7 +1169,7 @@ CmdStanFit$set("public", name = "return_codes", value = return_codes)
#' }
#' }
#' generated quantities {
#' int y_rep[N];
#' array[N] int y_rep;
#' profile("gq") {
#' y_rep = bernoulli_rng(rep_vector(theta, N));
#' }
Expand Down
10 changes: 5 additions & 5 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ CmdStanModel$set("public", name = "variables", value = variables)
#' file <- write_stan_file("
#' data {
#' int N;
#' int y[N];
#' array[N] int y;
#' }
#' parameters {
#' // should have <lower=0> but omitting to demonstrate pedantic mode
Expand Down Expand Up @@ -932,7 +932,7 @@ CmdStanModel$set("public", name = "check_syntax", value = check_syntax)
#' file <- write_stan_file("
#' data {
#' int N;
#' int y[N];
#' array[N] int y;
#' }
#' parameters {
#' real lambda;
Expand Down Expand Up @@ -1659,7 +1659,7 @@ CmdStanModel$set("public", name = "variational", value = variational)
#' mcmc_program <- write_stan_file(
#' "data {
#' int<lower=0> N;
#' int<lower=0,upper=1> y[N];
#' array[N] int<lower=0,upper=1> y;
#' }
#' parameters {
#' real<lower=0,upper=1> theta;
Expand All @@ -1678,13 +1678,13 @@ CmdStanModel$set("public", name = "variational", value = variational)
#' gq_program <- write_stan_file(
#' "data {
#' int<lower=0> N;
#' int<lower=0,upper=1> y[N];
#' array[N] int<lower=0,upper=1> y;
#' }
#' parameters {
#' real<lower=0,upper=1> theta;
#' }
#' generated quantities {
#' int y_rep[N] = bernoulli_rng(rep_vector(theta, N));
#' array[N] int y_rep = bernoulli_rng(rep_vector(theta, N));
#' }"
#' )
#'
Expand Down
19 changes: 17 additions & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,20 @@ get_standalone_hpp <- function(stan_file, stancflags) {

get_function_name <- function(fun_start, fun_end, model_lines) {
fun_string <- paste(model_lines[(fun_start+1):fun_end], collapse = " ")
fun_name <- gsub("auto ", "", fun_string, fixed = TRUE)
types <- c(
"auto",
"int",
"double",
"Eigen::Matrix<(.*)>",
"std::vector<(.*)>"
)
pattern <- paste0(
# Only match if the type occurs at start of string
"^(\\s*)?(",
paste0(types, collapse="|"),
# Only match if type followed by a function name and opening bracket
")\\s*(?=\\w*\\()")
fun_name <- gsub(pattern, "", fun_string, perl = TRUE)
sub("\\(.*", "", fun_name, perl = TRUE)
}

Expand Down Expand Up @@ -864,7 +877,9 @@ get_plain_rtn <- function(fun_start, fun_end, model_lines) {
# that instantiates an RNG
prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
fun_body <- paste(model_lines[fun_start:fun_end], collapse = " ")
fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body)
if (cmdstan_version() < "2.33") {
fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body)
}
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
Expand Down
1 change: 0 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ articles:
- cmdstanr-internals
- posterior
- r-markdown
- deprecations
- profiling
- articles-online-only/opencl

Expand Down
6 changes: 3 additions & 3 deletions man/CmdStanGQ.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/fit-method-profiles.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-check_syntax.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/model-method-format.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/write_stan_file.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ test_that("process_data() corrrectly casts integers and floating point numbers",

stan_file <- write_stan_file("
data {
int<lower=0> k[3,3];
array[3,3] int<lower=0> k;
}
")
mod <- cmdstan_model(stan_file, compile = FALSE)
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-example.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ test_that("cmdstanr_example works", {
stan_program <- "
data {
int<lower=0> N;
int<lower=0,upper=1> y[N];
array[N] int<lower=0,upper=1> y;
}
parameters {
real<lower=0,upper=1> theta;
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-fit-mle.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ test_that("time is reported after optimization", {

test_that("no error when checking estimates after failure", {
fit <- cmdstanr_example("schools", method = "optimize", seed = 123) # optim ålways fails for this
expect_silent(fit$summary()) # no error
expect_error(fit$summary(), "Fitting failed. Unable to retrieve the draws.")
})

test_that("draws() works for different formats", {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-fit-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ test_that("sig_figs works with all methods", {
m <- "data {
int<lower=0> N;
int<lower=0> K;
int<lower=0,upper=1> y[N];
array[N] int<lower=0,upper=1> y;
matrix[N, K] X;
}
parameters {
Expand Down
41 changes: 15 additions & 26 deletions tests/testthat/test-model-compile.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ test_that("compile() works with pedantic=TRUE", {
}
")
expect_message(
mod_pedantic_warn <- cmdstan_model(stan_file, pedantic = TRUE),
mod_pedantic_warn <- cmdstan_model(stan_file, pedantic = TRUE, force_recompile = TRUE),
"The parameter x was declared but was not used",
fixed = TRUE
)
Expand Down Expand Up @@ -387,13 +387,10 @@ test_that("check_syntax() works with pedantic=TRUE", {
fixed = TRUE
)

expect_output(
expect_message(
mod_pedantic_warn$check_syntax(pedantic = TRUE),
"The parameter x was declared but was not used",
fixed = TRUE
),
regexp = NA
expect_message(
mod_pedantic_warn$check_syntax(pedantic = TRUE),
"The parameter x was declared but was not used",
fixed = TRUE
)
})

Expand Down Expand Up @@ -424,14 +421,14 @@ test_that("check_syntax() works with pedantic=TRUE", {
"
stan_file <- write_stan_file(model_code)
mod_dep_warning <- cmdstan_model(stan_file, compile = FALSE)
expect_message(
expect_error(
mod_dep_warning$compile(),
"deprecated in the Stan language",
"An error occured during compilation! See the message above for more information.",
fixed = TRUE
)
expect_message(
expect_error(
mod_dep_warning$check_syntax(),
"deprecated in the Stan language",
"Syntax error found! See the message above for more information.",
fixed = TRUE
)
})
Expand Down Expand Up @@ -690,13 +687,9 @@ test_that("format() works", {
stan_file_tmp <- write_stan_file(code)
mod_1 <- cmdstan_model(stan_file_tmp, compile = FALSE)

expect_output(
expect_message(
mod_1$format(),
"is deprecated",
fixed = TRUE
),
"target += normal_log(y, 0, 1);",
expect_error(
mod_1$format(),
"Syntax error found! See the message above for more information.",
fixed = TRUE
)

Expand All @@ -710,13 +703,9 @@ test_that("format() works", {
"target += normal_lpdf(y | 0, 1);",
fixed = TRUE
)
expect_output(
expect_message(
mod_1$format(canonicalize = list("includes")),
"is deprecated",
fixed = TRUE
),
"target += normal_log(y, 0, 1);",
expect_error(
mod_1$format(),
"Syntax error found! See the message above for more information.",
fixed = TRUE
)

Expand Down
3 changes: 2 additions & 1 deletion tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ test_that("Exposing functions with precompiled model gives meaningful error", {
parameters { real x; }
model { x ~ std_normal(); }
")
mod1 <- cmdstan_model(stan_file, compile_standalone = TRUE)
mod1 <- cmdstan_model(stan_file, compile_standalone = TRUE,
force_recompile = TRUE)
expect_equal(7.5, mod1$functions$a_plus_b(5, 2.5))

mod2 <- cmdstan_model(stan_file)
Expand Down
8 changes: 4 additions & 4 deletions tests/testthat/test-model-sample_mpi.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ test_that("sample_mpi() works", {
skip_if(!mpi_toolchain_present())
mpi_file <- write_stan_file("
functions {
vector test(vector beta, vector theta, real[] x, int[] y) {
vector test(vector beta, vector theta, array[] real x, array[] int y) {
return theta;
}
}
transformed data {
vector[4] a;
vector[5] b[4] = {[1,1,1,1,1]', [2,2,2,2,2]', [3,3,3,3,3]', [4,4,4,4,4]'};
real x[4,4];
int y[4,4];
array[4] vector[5] b = {[1,1,1,1,1]', [2,2,2,2,2]', [3,3,3,3,3]', [4,4,4,4,4]'};
array[4,4] real x;
array[4,4] int y;
}
parameters {
real beta;
Expand Down
Loading

0 comments on commit b29374e

Please sign in to comment.