Skip to content

Commit

Permalink
Fix incorrect sizing for unconstrain_draws (#983)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrjohns authored May 24, 2024
1 parent 499aa23 commit f141d06
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
9 changes: 8 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,9 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#'
unconstrain_draws <- function(files = NULL, draws = NULL,
format = getOption("cmdstanr_draws_format", "draws_array")) {
if (!(format %in% valid_draws_formats())) {
stop("Invalid draws format requested!", call. = FALSE)
}
if (!is.null(files) || !is.null(draws)) {
if (!is.null(files) && !is.null(draws)) {
stop("Either a list of CSV files or a draws object can be passed, not both",
Expand All @@ -582,6 +585,8 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
}
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
}

chains <- posterior::nchains(draws)

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
model_variables <- self$runset$args$model_variables
Expand All @@ -598,7 +603,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws)
uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE)
names(unconstrained) <- repair_variable_names(uncon_names)
maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws))
unconstrained$.nchains <- chains

do.call(function(...) { create_draws_format(format, ...) }, unconstrained)
}
CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws)

Expand Down
13 changes: 13 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,19 @@ maybe_convert_draws_format <- function(draws, format, ...) {
)
}

create_draws_format <- function(format, ...) {
format <- sub("^draws_", "", format)
switch(
format,
"array" = posterior::draws_array(...),
"df" = posterior::draws_df(...),
"data.frame" = posterior::draws_df(...),
"list" = posterior::draws_list(...),
"matrix" = posterior::draws_matrix(...),
"rvars" = posterior::draws_rvars(...),
stop("Invalid draws format.", call. = FALSE)
)
}

# convert draws for external packages ------------------------------------------

Expand Down
17 changes: 13 additions & 4 deletions inst/include/model_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,24 @@ Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variab
}

// [[Rcpp::export]]
Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
Rcpp::List unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) {
Rcpp::XPtr<stan::model::model_base> ptr(ext_model_ptr);
Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows());
// Need to do this for the first row to get the correct size of the unconstrained draws
Eigen::VectorXd unconstrained_draw1;
ptr->unconstrain_array(variables.row(0).transpose(), unconstrained_draw1, &Rcpp::Rcout);
std::vector<Eigen::VectorXd> unconstrained_draws(unconstrained_draw1.size());
for (auto&& unconstrained_par : unconstrained_draws) {
unconstrained_par.resize(variables.rows());
}

for (int i = 0; i < variables.rows(); i++) {
Eigen::VectorXd unconstrained_variables;
ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout);
unconstrained_draws.col(i) = unconstrained_variables;
for (int j = 0; j < unconstrained_variables.size(); j++) {
unconstrained_draws[j](i) = unconstrained_variables(j);
}
}
return unconstrained_draws.transpose();
return Rcpp::wrap(unconstrained_draws);
}

// [[Rcpp::export]]
Expand Down

0 comments on commit f141d06

Please sign in to comment.