diff --git a/R/fit.R b/R/fit.R index c2a1f2bf..4e15e167 100644 --- a/R/fit.R +++ b/R/fit.R @@ -561,6 +561,9 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var #' unconstrain_draws <- function(files = NULL, draws = NULL, format = getOption("cmdstanr_draws_format", "draws_array")) { + if (!(format %in% valid_draws_formats())) { + stop("Invalid draws format requested!", call. = FALSE) + } if (!is.null(files) || !is.null(draws)) { if (!is.null(files) && !is.null(draws)) { stop("Either a list of CSV files or a draws object can be passed, not both", @@ -582,6 +585,8 @@ unconstrain_draws <- function(files = NULL, draws = NULL, } draws <- maybe_convert_draws_format(private$draws_, "draws_matrix") } + + chains <- posterior::nchains(draws) model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"] model_variables <- self$runset$args$model_variables @@ -598,7 +603,9 @@ unconstrain_draws <- function(files = NULL, draws = NULL, unconstrained <- private$model_methods_env_$unconstrain_draws(private$model_methods_env_$model_ptr_, draws) uncon_names <- private$model_methods_env_$unconstrained_param_names(private$model_methods_env_$model_ptr_, FALSE, FALSE) names(unconstrained) <- repair_variable_names(uncon_names) - maybe_convert_draws_format(unconstrained, format, .nchains = posterior::nchains(draws)) + unconstrained$.nchains <- chains + + do.call(function(...) { create_draws_format(format, ...) }, unconstrained) } CmdStanFit$set("public", name = "unconstrain_draws", value = unconstrain_draws) diff --git a/R/utils.R b/R/utils.R index d03184ab..0fb77245 100644 --- a/R/utils.R +++ b/R/utils.R @@ -426,6 +426,19 @@ maybe_convert_draws_format <- function(draws, format, ...) { ) } +create_draws_format <- function(format, ...) { + format <- sub("^draws_", "", format) + switch( + format, + "array" = posterior::draws_array(...), + "df" = posterior::draws_df(...), + "data.frame" = posterior::draws_df(...), + "list" = posterior::draws_list(...), + "matrix" = posterior::draws_matrix(...), + "rvars" = posterior::draws_rvars(...), + stop("Invalid draws format.", call. = FALSE) + ) +} # convert draws for external packages ------------------------------------------ diff --git a/inst/include/model_methods.cpp b/inst/include/model_methods.cpp index e4931462..262ddfc7 100644 --- a/inst/include/model_methods.cpp +++ b/inst/include/model_methods.cpp @@ -127,15 +127,24 @@ Eigen::VectorXd unconstrain_variables(SEXP ext_model_ptr, Eigen::VectorXd variab } // [[Rcpp::export]] -Eigen::MatrixXd unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) { +Rcpp::List unconstrain_draws(SEXP ext_model_ptr, Eigen::MatrixXd variables) { Rcpp::XPtr ptr(ext_model_ptr); - Eigen::MatrixXd unconstrained_draws(variables.cols(), variables.rows()); + // Need to do this for the first row to get the correct size of the unconstrained draws + Eigen::VectorXd unconstrained_draw1; + ptr->unconstrain_array(variables.row(0).transpose(), unconstrained_draw1, &Rcpp::Rcout); + std::vector unconstrained_draws(unconstrained_draw1.size()); + for (auto&& unconstrained_par : unconstrained_draws) { + unconstrained_par.resize(variables.rows()); + } + for (int i = 0; i < variables.rows(); i++) { Eigen::VectorXd unconstrained_variables; ptr->unconstrain_array(variables.transpose().col(i), unconstrained_variables, &Rcpp::Rcout); - unconstrained_draws.col(i) = unconstrained_variables; + for (int j = 0; j < unconstrained_variables.size(); j++) { + unconstrained_draws[j](i) = unconstrained_variables(j); + } } - return unconstrained_draws.transpose(); + return Rcpp::wrap(unconstrained_draws); } // [[Rcpp::export]]