Skip to content

Commit

Permalink
export stanfit getters
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Jul 13, 2024
1 parent 98886fa commit ae46573
Show file tree
Hide file tree
Showing 36 changed files with 518 additions and 126 deletions.
36 changes: 36 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,41 @@ S3method(coef,dynamitefit)
S3method(confint,dynamitefit)
S3method(fitted,dynamitefit)
S3method(formula,dynamitefit)
S3method(get_algorithm,CmdStanMCMC)
S3method(get_algorithm,CmdStanMCMC_CSV)
S3method(get_algorithm,stanfit)
S3method(get_code,dynamitefit)
S3method(get_code,dynamiteformula)
S3method(get_data,dynamitefit)
S3method(get_data,dynamiteformula)
S3method(get_diagnostics,CmdStanMCMC)
S3method(get_diagnostics,CmdStanMCMC_CSV)
S3method(get_diagnostics,stanfit)
S3method(get_draws,CMdStanMCMC_CSV)
S3method(get_draws,CmdStanMCMC)
S3method(get_draws,stanfit)
S3method(get_elapsed_time,CmdStanMCMC)
S3method(get_elapsed_time,CmdStanMCMC_CSV)
S3method(get_elapsed_time,stanfit)
S3method(get_max_treedepth,CmdStanMCMC)
S3method(get_max_treedepth,CmdStanMCMC_CSV)
S3method(get_max_treedepth,stanfit)
S3method(get_model_code,CmdStanMCMC)
S3method(get_model_code,CmdStanMCMC_CSV)
S3method(get_model_code,stanfit)
S3method(get_nchains,CmdStanMCMC)
S3method(get_nchains,CmdStanMCMC_CSV)
S3method(get_nchains,stanfit)
S3method(get_ndraws,CmdStanMCMC)
S3method(get_ndraws,CmdStanMCMC_CSV)
S3method(get_ndraws,stanfit)
S3method(get_parameter_dims,dynamitefit)
S3method(get_parameter_dims,dynamiteformula)
S3method(get_parameter_names,dynamitefit)
S3method(get_parameter_types,dynamitefit)
S3method(get_pars_oi,CmdStanMCMC)
S3method(get_pars_oi,CmdStanMCMC_CSV)
S3method(get_pars_oi,stanfit)
S3method(get_priors,dynamitefit)
S3method(get_priors,dynamiteformula)
S3method(hmc_diagnostics,dynamitefit)
Expand All @@ -41,11 +68,20 @@ export(aux)
export(dynamice)
export(dynamite)
export(dynamiteformula)
export(get_algorithm)
export(get_code)
export(get_data)
export(get_diagnostics)
export(get_draws)
export(get_elapsed_time)
export(get_max_treedepth)
export(get_model_code)
export(get_nchains)
export(get_ndraws)
export(get_parameter_dims)
export(get_parameter_names)
export(get_parameter_types)
export(get_pars_oi)
export(get_priors)
export(hmc_diagnostics)
export(lags)
Expand Down
2 changes: 1 addition & 1 deletion R/as_data_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ as_data_table_omega <- function(x, draws, n_draws, response, category, ...) {
#' @describeIn as_data_table_default Data Table for a "omega_alpha" Parameter
#' @noRd
as_data_table_omega_alpha <- function(x, draws, n_draws, response,
category, ...) {
category, ...) {
D <- x$stan$model_vars$D
data.table::data.table(
parameter = rep(
Expand Down
38 changes: 19 additions & 19 deletions R/dynamite-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' * The package vignettes
#' * [dynamite::dynamiteformula()] for information on defining models.
#' * [dynamite::dynamite()] for information on fitting models.
#' * \url{https://github.com/ropensci/dynamite/issues/} to submit a bug report
#' * <https://github.com/ropensci/dynamite/issues/> to submit a bug report
#' or a feature request.
#'
#' # Authors
Expand Down Expand Up @@ -56,8 +56,8 @@
#' coefficients vary according to a spline with 20 degrees of freedom.
#'
#' @family examples
#' @source The data was generated according to a script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/gaussian_example.R}
#' @source The data was generated via `gaussian_example.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A data frame with 3000 rows and 5 variables:
#' \describe{
#' \item{y}{The response variable.}
Expand Down Expand Up @@ -95,8 +95,8 @@
#' }
#' Note the very small number of samples due to size restrictions on CRAN.
#' @family examples
#' @source The data was generated according to a script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/gaussian_example_fit.R}
#' @source The data was generated via `gaussian_example_fit.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A `dynamitefit` object.
"gaussian_example_fit"

Expand All @@ -120,8 +120,8 @@
# #' )
# #' }
# #' @family examples
# #' @source The data was generated according to a script in
# #' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/gaussian_simulation_fit.R}
# #' @source The data was generated via to `gaussian_simulation_fit.R` in
# #' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
# #' @format A `dynamitefit` object.
# "gaussian_simulation_fit"

Expand All @@ -131,8 +131,8 @@
#' response variables of different distributions.
#'
#' @family examples
#' @source The data was generated according to a script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/multichannel_example.R}
#' @source The data was generated via `multichannel_example.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A data frame with 3000 rows and 5 variables:
#' \describe{
#' \item{id}{Variable defining individuals (1 to 50).}
Expand Down Expand Up @@ -171,8 +171,8 @@
#' }
#' Note the small number of samples due to size restrictions on CRAN.
#' @family examples
#' @source Script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/multichannel_example_fit.R}
#' @source THe data was generated via `multichannel_example_fit.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A `dynamitefit` object.
"multichannel_example_fit"

Expand All @@ -182,8 +182,8 @@
#' response variables.
#'
#' @family examples
#' @source The data was generated according to a script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/categorical_example.R}
#' @source The data was generated via `categorical_example.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A data frame with 2000 rows and 5 variables:
#' \describe{
#' \item{id}{Variable defining individuals (1 to 100).}
Expand Down Expand Up @@ -216,8 +216,8 @@
#' }
#' Note the small number of samples due to size restrictions on CRAN.
#' @family examples
#' @source Script in
#' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/categorical_example_fit.R}
#' @source The data was generated via `categorical_example_fit.R` in
#' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
#' @format A `dynamitefit` object.
"categorical_example_fit"

Expand All @@ -227,8 +227,8 @@
# #' trajectories are defined by a latent factor and random intercept terms.
# #'
# #' @family examples
# #' @source The data was generated according to a script in
# #' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/latent_factor_example.R}
# #' @source The data was generated via `latent_factor_example.R` in
# #' <https://github.com/ropensci/dynamite/blob/main/data-raw/>
# #' @format A data frame with 2000 rows and 3 variables:
# #' \describe{
# #' \item{y}{A continuos variable.}
Expand Down Expand Up @@ -264,7 +264,7 @@
# #' }
# #' Note the very small number of samples due to size restrictions on CRAN.
# #' @family examples
# #' @source Script in
# #' \url{https://github.com/ropensci/dynamite/blob/main/data-raw/latent_factor_example_fit.R}
# #' @source The data was generated via `latent_factor_example_fit.R` in
# #' <https://github.com/ropensci/dynamite/tree/main/data-raw/>
# #' @format A `dynamitefit` object.
# "latent_factor_example_fit"
32 changes: 18 additions & 14 deletions R/dynamite.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' Bayesian inference. The \pkg{dynamite} package supports a wide range of
#' distributions and allows the user to flexibly customize the priors for the
#' model parameters. The dynamite model is specified using standard \R formula
#' syntax via [dynamite::dynamiteformula()]. For more information and examples,
#' syntax via [dynamiteformula()]. For more information and examples,
#' see 'Details' and the package vignettes.
#'
#' The best-case scalability of `dynamite` in terms of data size should be
Expand All @@ -17,7 +17,7 @@
#' @family fitting
#' @rdname dynamite
#' @param dformula \[`dynamiteformula`]\cr The model formula.
#' See [dynamite::dynamiteformula()] and 'Details'.
#' See [dynamiteformula()] and 'Details'.
#' @param data
#' \[`data.frame`, `tibble::tibble`, or `data.table::data.table`]\cr
#' The data that contains the variables in the model in long format.
Expand All @@ -38,22 +38,23 @@
#' group. In case of name conflicts with `data`, see the `group_var` element
#' of the return object to get the column name of the new variable.
#' @param priors \[`data.frame`]\cr An optional data frame with prior
#' definitions. See [dynamite::get_priors()] and 'Details'.
#' definitions. See [get_priors()] and 'Details'.
#' @param backend \[`character(1)`]\cr Defines the backend interface to Stan,
#' should be either `"rstan"` (the default) or `"cmdstanr"`. Note that
#' `cmdstanr` needs to be installed separately as it is not on CRAN. It also
#' needs the actual `CmdStan` software. See https://mc-stan.org/cmdstanr/ for
#' details.
#' needs the actual `CmdStan` software. See <https://mc-stan.org/cmdstanr/>
#' for details.
#' @param verbose \[`logical(1)`]\cr All warnings and messages are suppressed
#' if set to `FALSE`. Defaults to `TRUE`. Setting this to `FALSE` will also
#' disable checks for perfect collinearity in the model matrix.
#' @param verbose_stan \[`logical(1)`]\cr This is the `verbose` argument for
#' [rstan::sampling()]. Defaults to `FALSE`.
#' @param stanc_options \[`list()`]\cr This is the `stanc_options` argument
#' passed to the compile method of a `CmdStanModel` object via
#' [cmdstanr::cmdstan_model()] when `backend = "cmdstanr"`.
#' Defaults to `list("O0")`. To enable level one compiler optimizations,
#' use `list("O1")`.
#' `cmdstan_model()` when `backend = "cmdstanr"`. Defaults to `list("O0")`.
#' To enable level one compiler optimizations, use `list("O1")`.
#' See <https://mc-stan.org/cmdstanr/reference/cmdstan_model.html>
#' for details.
#' @param threads_per_chain \[`integer(1)`]\cr A Positive integer defining the
#' number of parallel threads to use within each chain. Default is `1`. See
#' [rstan::rstan_options()] and [cmdstanr::sample()] for details.
Expand All @@ -76,11 +77,13 @@
#' combined with `model_code = TRUE`, which adds the Stan model code to the
#' return object.
#' @param ... For `dynamite()`, additional arguments to [rstan::sampling()] or
#' [cmdstanr::sample()], such as `chains` and `cores` (`chains` and
#' `parallel_chains` in `cmdstanr`). For `summary()`, additional arguments to
#' [dynamite::as.data.frame.dynamitefit()]. For `print()`, further arguments
#' to the print method for tibbles (see [tibble::formatting]). Not used for
#' `formula()`.
#' the `$sample()` method of the `CmdStanModel` object
#' (see <https://mc-stan.org/cmdstanr/reference/model-method-sample.html>),
#' such as `chains` and `cores`
#' (`chains` and `parallel_chains` in `cmdstanr`). For `summary()`,
#' additional arguments to [as.data.frame.dynamitefit()]. For `print()`,
#' further arguments to the print method for tibbles
#' (see [tibble::formatting]). Not used for `formula()`.
#' @return `dynamite` returns a `dynamitefit` object which is a list containing
#' the following components:
#'
Expand Down Expand Up @@ -321,7 +324,8 @@ dynamite_check <- function(dformula, data, time, group, priors, verbose,
)
stopifnot_(
checkmate::test_string(x = custom_stan_model, null.ok = TRUE),
"Argument {.arg custom_stan_model} must be a single {.cls character} string."
"Argument {.arg custom_stan_model}
must be a single {.cls character} string."
)
stopifnot_(
!isTRUE(grepl("\\.stan$", custom_stan_model, perl = TRUE)) ||
Expand Down
3 changes: 2 additions & 1 deletion R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ dynamitefamily <- function(name, link) {
)
stopifnot_(
is_supported_link(name, link),
"{.val {link}} is not a supported link function for a {.val {name}} channel."
"{.val {link}} is not a supported link function
for a {.val {name}} channel."
)
structure(
list(name = name, link = link),
Expand Down
5 changes: 2 additions & 3 deletions R/lags.R
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,8 @@ parse_lags <- function(dformula, data, group_var, time_var, verbose) {
#' are stochastic.
#' @param increment \[`logical()`]\cr A vector indicating whether to add
#' the new lag term or not (e.g.,, whether it was already present or not).
#' @param type \[`character(1)`]\cr Either `"fixed"`, `"varying"`, or `"random"`.
#' @param type \[`character(1)`]\cr
#' Either `"fixed"`, `"varying"`, or `"random"`.
#' @param lhs \[`character()`]\cr A vector of the new lagged variable names.
#' @noRd
parse_new_lags <- function(dformula, channels_stoch, increment, type, lhs) {
Expand Down Expand Up @@ -594,5 +595,3 @@ prepare_lagged_response <- function(dformula, lag_map,
}
y
}


6 changes: 3 additions & 3 deletions R/lfo.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
#' the LFO computations to the console.
#' @param k_threshold \[`numeric(1)`]\cr Threshold for the Pareto k estimate
#' triggering refit. Default is 0.7.
#' @param ... Additional arguments passed to [rstan::sampling()] or
#' [cmdstanr::sample()], such as `chains` and `cores` (`parallel_chains` in
#' `cmdstanr`).
#' @param ... Additional arguments passed to [rstan::sampling()] or the
#' `$sample()` method of the `CmdStanModel` object, such as `chains` and
#' `cores` (`parallel_chains` in `cmdstanr`).
#' @return An `lfo` object which is a `list` with the following components:
#'
#' * `ELPD`\cr Expected log predictive density estimate.
Expand Down
2 changes: 1 addition & 1 deletion R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ prepare_channel_default <- function(y, Y, channel, sampling,
nzchar(category),
priors[priors$response == y & priors$category == category, ],
priors[priors$response == y, ]
)
)
channel$prior_distr <- list()
types <- priors$type
loop_types <- intersect(
Expand Down
12 changes: 6 additions & 6 deletions R/print.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ print.dynamitefit <- function(x, full_diagnostics = FALSE, ...) {
if (mcmc_algorithm) {
min_ess <- which.min(sumr$ess_bulk)
cat("\nSmallest bulk-ESS: ", round(sumr$ess_bulk[min_ess]), " (",
sumr$variable[min_ess], ")",
sep = ""
sumr$variable[min_ess], ")",
sep = ""
)
min_ess <- which.min(sumr$ess_tail)
cat("\nSmallest tail-ESS: ", round(sumr$ess_tail[min_ess]), " (",
sumr$variable[min_ess], ")",
sep = ""
sumr$variable[min_ess], ")",
sep = ""
)
max_rhat <- which.max(sumr$rhat)
cat("\nLargest Rhat: ", round(sumr$rhat[max_rhat], 3), " (",
sumr$variable[max_rhat], ")",
sep = ""
sumr$variable[max_rhat], ")",
sep = ""
)
runtimes <- get_elapsed_time(x$stanfit)
if (nrow(runtimes) > 2L) {
Expand Down
Loading

0 comments on commit ae46573

Please sign in to comment.