Skip to content

Commit

Permalink
Merge pull request #811 from stan-dev/qol-improvements
Browse files Browse the repository at this point in the history
Bugfixes in .stanfunctions, hessian model method, and exposing RNG functions
  • Loading branch information
andrjohns authored Aug 23, 2023
2 parents e4ff5d4 + 7fb88b6 commit 2b04e4f
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
4 changes: 2 additions & 2 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ CmdStanModel <- R6::R6Class(
self$functions <- new.env()
self$functions$compiled <- FALSE
if (!is.null(stan_file)) {
assert_file_exists(stan_file, access = "r", extension = "stan")
assert_file_exists(stan_file, access = "r", extension = c("stan", "stanfunctions"))
checkmate::assert_flag(compile)
private$stan_file_ <- absolute_path(stan_file)
private$stan_code_ <- readLines(stan_file)
Expand Down Expand Up @@ -537,7 +537,7 @@ compile <- function(quiet = TRUE,
compile_hessian_method <- FALSE
}

temp_stan_file <- tempfile(pattern = "model-", fileext = ".stan")
temp_stan_file <- tempfile(pattern = "model-", fileext = paste0(".", tools::file_ext(self$stan_file())))
file.copy(self$stan_file(), temp_stan_file, overwrite = TRUE)
temp_file_no_ext <- strip_ext(temp_stan_file)
tmp_exe <- cmdstan_ext(temp_file_no_ext) # adds .exe on Windows
Expand Down
36 changes: 30 additions & 6 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
package = "cmdstanr", mustWork = TRUE)))

if (hessian) {
code <- c(code,
code <- c("#include <stan/math/mix.hpp>",
code,
readLines(system.file("include", "hessian.cpp",
package = "cmdstanr", mustWork = TRUE)))
}
Expand All @@ -758,9 +759,8 @@ expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
invisible(NULL)
}

initialize_model_pointer <- function(env, data, seed = 0) {
datafile_path <- ifelse(is.null(data), "", data)
ptr_and_rng <- env$model_ptr(datafile_path, seed)
initialize_model_pointer <- function(env, datafile_path, seed = 0) {
ptr_and_rng <- env$model_ptr(ifelse(is.null(datafile_path), "", datafile_path), seed)
env$model_ptr_ <- ptr_and_rng$model_ptr
env$model_rng_ <- ptr_and_rng$base_rng
env$num_upars_ <- env$get_num_upars(env$model_ptr_)
Expand Down Expand Up @@ -863,8 +863,8 @@ prep_fun_cpp <- function(fun_start, fun_end, model_lines) {
fun_body <- gsub("auto", get_plain_rtn(fun_start, fun_end, model_lines), fun_body)
fun_body <- gsub("// [[stan::function]]", "// [[Rcpp::export]]\n", fun_body, fixed = TRUE)
fun_body <- gsub("std::ostream\\*\\s*pstream__\\s*=\\s*nullptr", "", fun_body)
fun_body <- gsub("boost::ecuyer1988& base_rng__", "size_t seed = 0", fun_body, fixed = TRUE)
fun_body <- gsub("base_rng__,", "*(new boost::ecuyer1988(seed)),", fun_body, fixed = TRUE)
fun_body <- gsub("boost::ecuyer1988&\\s*base_rng__", "SEXP base_rng_ptr", fun_body)
fun_body <- gsub("base_rng__,", "*(Rcpp::XPtr<boost::ecuyer1988>(base_rng_ptr).get()),", fun_body, fixed = TRUE)
fun_body <- gsub("pstream__", "&Rcpp::Rcout", fun_body, fixed = TRUE)
fun_body <- paste(fun_body, collapse = "\n")
gsub(pattern = ",\\s*)", replacement = ")", fun_body)
Expand Down Expand Up @@ -904,6 +904,30 @@ compile_functions <- function(env, verbose = FALSE, global = FALSE) {
} else {
rcpp_source_stan(mod_stan_funs, env, verbose)
}

# If an RNG function is exposed, initialise a Boost RNG object stored in the
# environment
rng_funs <- grep("rng\\b", env$fun_names, value = TRUE)
if (length(rng_funs) > 0) {
rng_cpp <- system.file("include", "base_rng.cpp", package = "cmdstanr", mustWork = TRUE)
rcpp_source_stan(paste0(readLines(rng_cpp), collapse="\n"), env, verbose)
env$rng_ptr <- env$base_rng(seed=0)
}

# For all RNG functions, pass the initialised Boost RNG by default
for (fun in rng_funs) {
if (global) {
fun_env <- globalenv()
} else {
fun_env <- env
}
fundef <- get(fun, envir = fun_env)
funargs <- formals(fundef)
funargs$base_rng_ptr <- env$rng_ptr
formals(fundef) <- funargs
assign(fun, fundef, envir = fun_env)
}

env$compiled <- TRUE
invisible(NULL)
}
Expand Down
8 changes: 8 additions & 0 deletions inst/include/base_rng.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include <Rcpp.h>
#include <boost/random/additive_combine.hpp>

// [[Rcpp::export]]
SEXP base_rng(boost::uint32_t seed = 0) {
Rcpp::XPtr<boost::ecuyer1988> rng_ptr(new boost::ecuyer1988(seed));
return rng_ptr;
}
12 changes: 9 additions & 3 deletions tests/testthat/test-model-expose-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ test_that("Functions can be compiled with model", {

test_that("rng functions can be exposed", {
skip_if(os_is_wsl())
function_decl <- "functions { real normal_rng(real mu) { return normal_rng(mu, 1); } }"
function_decl <- "functions { real wrap_normal_rng(real mu, real sigma) { return normal_rng(mu, sigma); } }"
stan_prog <- paste(function_decl,
paste(readLines(testing_stan_file("bernoulli")),
collapse = "\n"),
Expand All @@ -122,11 +122,17 @@ test_that("rng functions can be exposed", {
mod <- cmdstan_model(model, force_recompile = TRUE)
fit <- mod$sample(data = data_list)

set.seed(10)
fit$expose_functions(verbose = TRUE)

expect_equal(
fit$functions$normal_rng(5, seed = 10),
3.8269637967017344771
fit$functions$wrap_normal_rng(5,10),
-4.5298764235381225873
)

expect_equal(
fit$functions$wrap_normal_rng(5,10),
8.1295902610102039887
)
})

Expand Down

0 comments on commit 2b04e4f

Please sign in to comment.