Skip to content

Commit

Permalink
add dynamite object permutation field, add tests for dynamice and cus…
Browse files Browse the repository at this point in the history
…tom stan model, run-extended
  • Loading branch information
santikka committed Apr 4, 2024
1 parent 7687bfc commit 69d4c62
Show file tree
Hide file tree
Showing 36 changed files with 237 additions and 90 deletions.
10 changes: 9 additions & 1 deletion R/dynamice.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,13 @@ dynamice <- function(dformula, data, time, group = NULL,
stanfit <- rstan::read_stan_csv(filenames)
stanfit@stanmodel <- methods::new("stanmodel", model_code = tmp$model_code)
}
# TODO does this work in this case?
n_draws <- ifelse_(
is.null(stanfit),
0L,
(stanfit@sim$n_save[1L] - stanfit@sim$warmup2[1L]) *
stanfit@sim$chains
)
# TODO return object? How is this going to work with update?
structure(
list(
Expand All @@ -152,8 +159,9 @@ dynamice <- function(dformula, data, time, group = NULL,
time_var = time,
priors = priors,
backend = backend,
permutation = sample(n_draws),
imputed = imputed,
call = tmp$call, # TODO?
imputed = imputed
),
class = "dynamitefit"
)
Expand Down
7 changes: 7 additions & 0 deletions R/dynamite.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,12 @@ dynamite <- function(dformula, data, time, group = NULL,
# copy so that get_data can still return the full stan_input via debug
stan_input_out <- stan_input
stan_input_out$sampling_vars <- NULL
n_draws <- ifelse_(
is.null(stanfit),
0L,
(stanfit@sim$n_save[1L] - stanfit@sim$warmup2[1L]) *
stanfit@sim$chains
)
out <- structure(
list(
stanfit = stanfit,
Expand All @@ -248,6 +254,7 @@ dynamite <- function(dformula, data, time, group = NULL,
time_var = time,
priors = rbindlist_(stan_input$priors),
backend = backend,
permutation = sample(n_draws),
call = dynamite_call
),
class = "dynamitefit"
Expand Down
17 changes: 10 additions & 7 deletions R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#' different channels are combined, i.e., all channels of are left out.
#' @param thin \[`integer(1)`]\cr Use only every `thin` posterior sample when
#' computing LOO. This can be beneficial with when the model object contains
#' large number of samples. Default is `NULL` which is equal to `thin = 1`.
#' large number of samples. Default is `1` meaning that all samples are used.
#' @param ... Ignored.
#' @return An output from [loo::loo()] or a list of such outputs (if
#' `separate_channels` was `TRUE`).
Expand All @@ -33,7 +33,7 @@
#' }
#' }
#'
loo.dynamitefit <- function(x, separate_channels = FALSE, thin = NULL, ...) {
loo.dynamitefit <- function(x, separate_channels = FALSE, thin = 1L, ...) {
stopifnot_(
is.null(x$imputed),
"Leave-one-out cross-validation is not supported for models estimated using
Expand All @@ -47,8 +47,11 @@ loo.dynamitefit <- function(x, separate_channels = FALSE, thin = NULL, ...) {
checkmate::test_flag(x = separate_channels),
"Argument {.arg separate_channels} must be a single {.cls logical} value."
)
if (is.null(thin)) thin <- 1L
# compute loglik for all posterior samples even with thin != NULL
stopifnot_(
checkmate::test_int(x = thin, lower = 1L, upper = ndraws(x)),
"Argument {.arg thin} must be a single positive {.cls integer}."
)
# compute loglik for all posterior samples even with thin > 1
out <- initialize_predict(
x,
newdata = NULL,
Expand All @@ -68,7 +71,7 @@ loo.dynamitefit <- function(x, separate_channels = FALSE, thin = NULL, ...) {
n_chains <- x$stanfit@sim$chains
n_draws <- ndraws(x) %/% n_chains
idx_draws <- seq.int(1L, n_draws * n_chains, by = thin)
loo_ <- function(ll, n_draws, n_chains, thin) {
loo_ <- function(ll, n_draws, n_chains) {
ll <- t(matrix(ll, ncol = n_draws * n_chains)[, idx_draws])
reff <- loo::relative_eff(
exp(ll),
Expand All @@ -86,13 +89,13 @@ loo.dynamitefit <- function(x, separate_channels = FALSE, thin = NULL, ...) {
),
by = "variable"
)
lapply(ll, function(x) loo_(x$value, n_draws, n_chains, thin))
lapply(ll, function(x) loo_(x$value, n_draws, n_chains))
} else {
temp <- out[, .SD, .SDcols = patterns("_loglik$")]
ll <- temp[is.finite(rowSums(temp))][,
rowSums(.SD),
.SDcols = names(temp)
]
loo_(ll, n_draws, n_chains, thin)
loo_(ll, n_draws, n_chains)
}
}
6 changes: 3 additions & 3 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ predict_ <- function(object, simulated, storage, observed,
env = list(n_new = n_new, n_draws = n_draws)
]
simulated[,
(".draw") := rep(seq.int(1L, n_draws), n_new),
env = list(n_new = n_new, n_draws = n_draws)
(".draw") := rep(seq.int(1L, n_draws), n_new),
env = list(n_new = n_new, n_draws = n_draws)
]
idx <- which(draw_time == u_time[1L]) + (fixed - 1L) * n_draws
n_sim <- n_draws
Expand Down Expand Up @@ -439,7 +439,7 @@ predict_ <- function(object, simulated, storage, observed,
idx_draws <- ifelse_(
identical(n_draws, ndraws(object)),
seq_len(n_draws),
sample.int(ndraws(object), n_draws)
object$permutation[seq_len(n_draws)]
)
eval_envs <- prepare_eval_envs(
object,
Expand Down
5 changes: 4 additions & 1 deletion R/predict_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,10 @@ prepare_eval_envs <- function(object, simulated, observed,
type, eval_type, idx_draws,
new_levels, group_var) {
#samples <- rstan::extract(object$stanfit)
samples <- lapply(posterior::as_draws_rvars(object$stanfit), posterior::draws_of)
samples <- lapply(
posterior::as_draws_rvars(object$stanfit),
posterior::draws_of
)
channel_vars <- object$stan$channel_vars
channel_group_vars <- object$stan$channel_group_vars
cg <- attr(object$dformulas$all, "channel_groups")
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.
8 changes: 4 additions & 4 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ install.packages("dynamite", repos = "https://ropensci.r-universe.dev")
A single-channel model with time-invariant effect of `z`, time-varying effect of `x`, lagged value of the response variable `y` and a group-specific random intercepts:

```{r, echo = FALSE}
library(dynamite)
library("dynamite")
ggplot2::theme_set(ggplot2::theme_bw())
```

```{r, eval = FALSE}
set.seed(1)
library(dynamite)
library("dynamite")
gaussian_example_fit <- dynamite(
obs(y ~ -1 + z + varying(~ x + lag(y)) + random(~1), family = "gaussian") +
splines(df = 20),
Expand All @@ -72,7 +72,7 @@ gaussian_example_fit <- dynamite(

```{r, echo = FALSE}
set.seed(1)
library(dynamite)
library("dynamite")
gaussian_example_fit <- update(gaussian_example_fit,
iter = 2000, warmup = 1000, thin = 1,
chains = 2, cores = 2, refresh = 0
Expand Down Expand Up @@ -102,7 +102,7 @@ plot(gaussian_example_fit, type = "beta")
Posterior predictive samples for the first 4 groups (samples based on the posterior distribution of model parameters and observed data on first time point):

```{r, warning=FALSE, fig.width = 9, fig.height = 4}
library(ggplot2)
library("ggplot2")
pred <- predict(gaussian_example_fit, n_draws = 100)
pred |>
dplyr::filter(id < 5) |>
Expand Down
32 changes: 18 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ group-specific random intercepts:

``` r
set.seed(1)
library(dynamite)
library("dynamite")
gaussian_example_fit <- dynamite(
obs(y ~ -1 + z + varying(~ x + lag(y)) + random(~1), family = "gaussian") +
splines(df = 20),
Expand All @@ -94,25 +94,29 @@ gaussian_example_fit
#> Grouping variable: id (Number of groups: 50)
#> Time index variable: time (Number of time points: 30)
#>
#> Smallest bulk-ESS: 557 (sigma_nu_y_alpha)
#> Smallest tail-ESS: 1032 (sigma_nu_y_alpha)
#> Largest Rhat: 1.006 (alpha_y[28])
#> NUTS sampler diagnostics:
#>
#> No divergences, saturated max treedepths or low E-BFMIs.
#>
#> Smallest bulk-ESS: 661 (sigma_nu_y_alpha)
#> Smallest tail-ESS: 1058 (sigma_nu_y_alpha)
#> Largest Rhat: 1.003 (sigma_y)
#>
#> Elapsed time (seconds):
#> warmup sample
#> chain:1 5.169 2.753
#> chain:2 4.897 1.763
#> chain:1 5.479 3.373
#> chain:2 5.966 3.770
#>
#> Summary statistics of the time- and group-invariant parameters:
#> # A tibble: 6 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <num> <num> <num> <num> <num> <num> <num> <num> <num>
#> 1 beta_y_z 1.97 1.97 0.0121 0.0124 1.95 1.99 1.00 2122. 1385.
#> 2 sigma_nu_y… 0.0944 0.0938 0.0112 0.0113 0.0774 0.114 0.999 557. 1032.
#> 3 sigma_y 0.198 0.198 0.00368 0.00382 0.192 0.204 1.00 2169. 1398.
#> 4 tau_alpha_y 0.209 0.202 0.0497 0.0453 0.143 0.298 1.00 1237. 1419.
#> 5 tau_y_x 0.362 0.353 0.0674 0.0650 0.268 0.485 1.00 2177. 1670.
#> 6 tau_y_y_la… 0.106 0.103 0.0216 0.0206 0.0770 0.146 1.00 1936. 1144.
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 beta_y_z 1.97 1.97 0.0116 0.0112 1.95 1.99 1.00 2815. 1434.
#> 2 sigma_nu_y… 0.0944 0.0933 0.0114 0.0107 0.0780 0.114 1.00 661. 1058.
#> 3 sigma_y 0.198 0.198 0.00373 0.00362 0.192 0.204 1.00 2580. 1254.
#> 4 tau_alpha_y 0.212 0.205 0.0483 0.0432 0.146 0.301 1.00 1731. 1606.
#> 5 tau_y_x 0.364 0.355 0.0740 0.0648 0.266 0.494 1.00 2812. 1504.
#> 6 tau_y_y_la… 0.107 0.105 0.0219 0.0213 0.0781 0.148 1.00 2387. 1682.
```

Posterior estimates of time-varying effects:
Expand Down Expand Up @@ -144,7 +148,7 @@ the posterior distribution of model parameters and observed data on
first time point):

``` r
library(ggplot2)
library("ggplot2")
pred <- predict(gaussian_example_fit, n_draws = 100)
pred |>
dplyr::filter(id < 5) |>
Expand Down
2 changes: 1 addition & 1 deletion data-raw/categorical_example_fit.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Code to create `categorical_example_fit` object

library(dynamite)
library("dynamite")

set.seed(1)
categorical_example_fit <- dynamite(
Expand Down
2 changes: 1 addition & 1 deletion data-raw/gaussian_example_fit.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Code to create `gaussian_example_fit` object

library(dynamite)
library("dynamite")

# Note the very small number of post-warmup iterations due to the data size
# restrictions in CRAN.
Expand Down
1 change: 1 addition & 0 deletions data-raw/gaussian_simulation_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ gaussian_simulation_fit <- dynamite(
group = "id",
chains = 1,
iter = 1,
refresh = 0,
algorithm = "Fixed_param",
init = list(init),
)
Expand Down
2 changes: 1 addition & 1 deletion data-raw/multichannel_example_fit.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Code to create `multichannel_example_fit` object

library(dynamite)
library("dynamite")

# Note the very small number of post-warmup iterations due to the data size
# restrictions in CRAN.
Expand Down
Binary file modified data/categorical_example_fit.rda
Binary file not shown.
Binary file modified data/gaussian_example_fit.rda
Binary file not shown.
Binary file modified data/gaussian_simulation_fit.rda
Binary file not shown.
Binary file modified data/multichannel_example_fit.rda
Binary file not shown.
6 changes: 3 additions & 3 deletions man/categorical_example.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/categorical_example_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/dynamite-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified man/figures/README-unnamed-chunk-10-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-7-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-8-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-9-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions man/gaussian_example.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/gaussian_example_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/gaussian_simulation_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/lags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/lfactor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/loo.dynamitefit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/multichannel_example.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 69d4c62

Please sign in to comment.