Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/ropensci/dynamite
Browse files Browse the repository at this point in the history
  • Loading branch information
helske committed May 8, 2024
2 parents ca8bec3 + 5a0fd2c commit 06a8d48
Show file tree
Hide file tree
Showing 13 changed files with 197 additions and 105 deletions.
14 changes: 8 additions & 6 deletions R/dynamiteformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,14 @@ get_dag <- function(x, project = FALSE, covariates = FALSE,
lag_dep_pa <- lag_dep[lag_dep$resp == resp[i], ]
lag_dep_ch <- lag_dep[lag_dep$var == resp[i], ]
lag_dep_new <- vector(mode = "list", length = nrow(lag_dep_ch))
for (j in seq_len(nrow(lag_dep_ch))) {
lag_dep_new[[j]] <- data.frame(
var = c(contemp_pa, lag_dep_pa$var),
order = c(rep(0L, k), lag_dep_pa$order) + lag_dep_ch$order[j],
resp = lag_dep_ch$resp[j]
)
if (nrow(lag_dep_pa) > 0L) {
for (j in seq_len(nrow(lag_dep_ch))) {
lag_dep_new[[j]] <- data.frame(
var = c(contemp_pa, lag_dep_pa$var),
order = c(rep(0L, k), lag_dep_pa$order) + lag_dep_ch$order[j],
resp = lag_dep_ch$resp[j]
)
}
}
lag_dep <- rbind(
lag_dep[lag_dep$resp != resp[i] & lag_dep$var != resp[i], ],
Expand Down
4 changes: 2 additions & 2 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ plot_varying <- function(coefs, level, alpha, scales, n_params) {
)
}
title <- glue::glue(
"Posterior means and {100 * (1 - 2 * level)} %",
"intervals of the {title_spec}"
"Posterior means and {100 * (1 - 2 * level)} ",
"% intervals of the {title_spec}"
)
# avoid NSE notes from R CMD check
time <- mean <- category <- parameter <- NULL
Expand Down
3 changes: 0 additions & 3 deletions R/prepare_stan_input.R
Original file line number Diff line number Diff line change
Expand Up @@ -479,8 +479,6 @@ initialize_multivariate_channel <- function(y, y_cg, y_name, cg_idx,
list(channel = channel, sampling = sampling)
}



#' Default channel preparation
#'
#' Computes default channel-specific variables for Stan sampling,
Expand Down Expand Up @@ -1254,7 +1252,6 @@ prepare_channel_student <- function(y, Y, channel, sampling,
out
}


#' Raise an error if factor type is not supported by a family
#'
#' @param y \[`character(1)`]\cr Response variable the error is related to.
Expand Down
27 changes: 14 additions & 13 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,20 @@ extract_vectorizable_priors <- function(priors, y) {
prepare_common_priors <- function(priors, M, shrinkage, P,
correlated_nu, correlated_lf) {
common_priors <- NULL
if (shrinkage) {
common_priors <- ifelse_(
is.null(priors),
data.frame(
parameter = "xi",
response = "",
prior = "normal(0, 1)",
type = "xi",
category = ""
),
priors[priors$type == "xi", ]
)
}
# Shrinkage feature removed for now
#if (shrinkage) {
# common_priors <- ifelse_(
# is.null(priors),
# data.frame(
# parameter = "xi",
# response = "",
# prior = "normal(0, 1)",
# type = "xi",
# category = ""
# ),
# priors[priors$type == "xi", ]
# )
#}
if (M > 1L && correlated_nu) {
common_priors <- ifelse_(
is.null(priors),
Expand Down
27 changes: 16 additions & 11 deletions R/stanblocks_families.R
Original file line number Diff line number Diff line change
Expand Up @@ -704,13 +704,6 @@ loglik_lines_gaussian <- function(y, obs, idt, default, ...) {

loglik_lines_multinomial <- function(idt, cvars, cgvars, backend,
threading, ...) {
stopifnot_(
stan_version(backend) >= "2.24",
c(
"Multinomial family is not supported for this version of {.pkg {backend}}.",
`i` = "Please install a newer version of {.pkg {backend}}."
)
)
cgvars$categories <- cgvars$y
cgvars$y <- cgvars$y_cg
cgvars$multinomial <- TRUE
Expand Down Expand Up @@ -2142,9 +2135,12 @@ model_lines_categorical <- function(y, idt, obs, family, priors,
onlyif(has_fixed || has_varying, c("J_{y}", "K_{y}")),
onlyif(has_X, "X")
)
likelihood <- glue::glue(
"target += reduce_sum({distr}_loglik_{y}_lpmf, {seq1T}, grainsize, ",
"{fun_args});"
likelihood <- paste_rows(
paste0(
"target += reduce_sum({distr}_loglik_{y}_lpmf, {seq1T}, grainsize, ",
"{fun_args});"
),
.indent = idt(1)
)
} else {
likelihood <- loglik_lines_categorical(
Expand Down Expand Up @@ -2211,7 +2207,16 @@ model_lines_gaussian <- function(y, obs, idt, priors,
paste_rows(priors, model_text, .parse = FALSE)
}

model_lines_multinomial <- function(cvars, cgvars, idt, threading, ...) {
model_lines_multinomial <- function(cvars, cgvars, idt, backend,
threading, ...) {
stopifnot_(
stan_version(backend) >= "2.24",
c(
"Multinomial family is not supported for
this version of {.pkg {backend}}.",
`i` = "Please install a newer version of {.pkg {backend}}."
)
)
cgvars$priors <- lapply(
cgvars$y[-1L],
function(s) {
Expand Down
7 changes: 4 additions & 3 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The `dynamite` package is developed with the support of the Research Council of

## Installation

You can install the most recent stable version of `dynmite` from [CRAN](https://cran.r-project.org/package=dynamite) or the development version from [R-universe](https://r-universe.dev/search/) by running one the following lines:
You can install the most recent stable version of `dynamite` from [CRAN](https://cran.r-project.org/package=dynamite) or the development version from [R-universe](https://r-universe.dev/search/) by running one the following lines:

```{r, eval = FALSE}
install.packages("dynamite")
Expand Down Expand Up @@ -73,7 +73,8 @@ gaussian_example_fit <- dynamite(
```{r, echo = FALSE}
set.seed(1)
library("dynamite")
gaussian_example_fit <- update(gaussian_example_fit,
gaussian_example_fit <- update(
gaussian_example_fit,
iter = 2000, warmup = 1000, thin = 1,
chains = 2, cores = 2, refresh = 0
)
Expand All @@ -99,7 +100,7 @@ Traceplots and density plots for time-invariant parameters:
plot(gaussian_example_fit, plot_type = "trace", types = "beta")
```

Posterior predictive samples for the first 4 groups (samples based on the posterior distribution of model parameters and observed data on first time point):
Posterior predictive samples for the first 4 groups (using the samples based on the posterior distribution of the model parameters and observed data on the first time point):

```{r, warning=FALSE, fig.width = 9, fig.height = 4}
library("ggplot2")
Expand Down
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ on DMPMs and the `dynamite` package, see the related

## Installation

You can install the most recent stable version of `dynmite` from
You can install the most recent stable version of `dynamite` from
[CRAN](https://cran.r-project.org/package=dynamite) or the development
version from [R-universe](https://r-universe.dev/search/) by running one
the following lines:
Expand Down Expand Up @@ -105,8 +105,8 @@ print(gaussian_example_fit)
#>
#> Elapsed time (seconds):
#> warmup sample
#> chain:1 5.801 3.542
#> chain:2 5.658 3.544
#> chain:1 5.546 3.396
#> chain:2 5.533 3.524
#>
#> Summary statistics of the time- and group-invariant parameters:
#> # A tibble: 6 × 10
Expand Down Expand Up @@ -144,9 +144,9 @@ plot(gaussian_example_fit, plot_type = "trace", types = "beta")

<img src="man/figures/README-unnamed-chunk-9-1.png" style="display: block; margin: auto;" />

Posterior predictive samples for the first 4 groups (samples based on
the posterior distribution of model parameters and observed data on
first time point):
Posterior predictive samples for the first 4 groups (using the samples
based on the posterior distribution of the model parameters and observed
data on the first time point):

``` r
library("ggplot2")
Expand Down
Loading

0 comments on commit 06a8d48

Please sign in to comment.