Skip to content

Commit

Permalink
dag plotting wip
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Mar 19, 2024
1 parent 4ebc53e commit 9218840
Show file tree
Hide file tree
Showing 14 changed files with 414 additions and 38 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ S3method(mcmc_diagnostics,dynamitefit)
S3method(ndraws,dynamitefit)
S3method(nobs,dynamitefit)
S3method(plot,dynamitefit)
S3method(plot,dynamiteformula)
S3method(plot,lfo)
S3method(predict,dynamitefit)
S3method(print,dynamitefit)
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# dynamite 1.4.11

* The package now depends on `data.table` version 1.15.0 or higher.
* Added a `plot` method for `dynamiteformula` objects. This method draws a directed acyclic graph (DAG) of the model structure as a `ggplot` object. The DAG is a snapshot in time that extends into the past and into the future by as many time points as the highest-order lag dependency in the model. Alternatively, setting the argument `tikz = TRUE` returns the DAG as a `character` string in TikZ format. See the documentation for more details.

# dynamite 1.4.10

Expand All @@ -23,7 +24,7 @@

# dynamite 1.4.7

* Added a note on priors vignette regarding default priors for $\tau$ parameters.
* Added a note on priors vignette regarding default priors for `tau` parameters.
* Fixed `mcmc_diagnostics()` function so that HMC diagnostics are checked also for models run with the `cmdstanr` backend.

# dynamite 1.4.6
Expand Down
5 changes: 2 additions & 3 deletions R/dynamite.R
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,7 @@ dynamite_check <- function(dformula, data, time, group, priors, verbose,
dynamite_stan <- function(dformulas, data, data_name, group, time,
priors, backend, verbose, verbose_stan,
stanc_options, threads_per_chain, grainsize,
custom_stan_model, debug,
...) {
custom_stan_model, debug, ...) {
stan_input <- prepare_stan_input(
dformulas$stoch,
data,
Expand Down Expand Up @@ -523,7 +522,7 @@ sampling_info <- function(dformulas, verbose, debug, backend) {
}
}

#' Check Arguments Names Of `...` for Stan Sampling
#' Check Argument Names Of `...` for Stan Sampling
#'
#' @inheritParams dynamite_stan
#' @param dots The `...` arguments of `dynamite` as a `list`
Expand Down
181 changes: 161 additions & 20 deletions R/dynamiteformula.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#' * Gamma: `gamma` (log-link, using mean and shape parameterization).
#' * Beta: `beta` (logit-link, using mean and precision parameterization).
#' * Student t: `student` (identity link, parametrized using degrees of
#' freedon, location and scale)
#' freedom, location and scale)
#'
#' The models in the \pkg{dynamite} package are defined by combining the
#' channel-specific formulas defined via \R formula syntax.
Expand Down Expand Up @@ -224,12 +224,12 @@ dynamiteformula <- function(formula, family) {
#' @param fixed \[`integer()`]\cr Time-invariant covariate indices.
#' @param varying \[`integer()`]\cr Time-varying covariate indices.
#' @param random \[`integer()`]\cr Random effect covariate indices.
#' @param has_fixed_intercept \[`logical(1)`]\cr Does the channel contain fixed
#' intercept?
#' @param has_fixed_intercept \[`logical(1)`]\cr Does the channel contain
#' a fixed intercept?
#' @param has_varying_intercept \[`logical(1)`]\cr Does the channel contain
#' varying intercept?
#' @param has_random_intercept \[`logical(1)`]\cr Does the channel contain random
#' group-level intercept term?
#' a varying intercept?
#' @param has_random_intercept \[`logical(1)`]\cr Does the channel contain
#' a random group-level intercept term?
#' @noRd
dynamitechannel <- function(formula, original = NULL, family, response, name = NULL,
fixed = integer(0L), varying = integer(0L),
Expand Down Expand Up @@ -478,9 +478,14 @@ get_nonlag_terms <- function(x) {
#' @param x A `dynamiteformula` object.
#' @noRd
get_lag_orders <- function(x) {
lapply(x, function(y) {
unique(find_lag_orders(formula_rhs(y$formula)))
tmp <- lapply(x, function(y) {
tmp_ <- find_lag_orders(formula_rhs(y$original))
if (nrow(tmp_) > 0L) {
tmp_$resp <- y$response
}
tmp_
})
unique(rbindlist_(tmp[vapply(tmp, nrow, integer(1L)) > 0]))
}

#' Get Special Type Formula of a Dimension in a `dynamiteformula`
Expand Down Expand Up @@ -535,20 +540,156 @@ get_quoted <- function(x) {
out[[i]] <- list(name = resp[i], expr = formula_rhs(x[[i]]$formula))
}
out
# if (length(resp) > 0L) {
# expr <- lapply(x, function(x) deparse1(formula_rhs(x$formula)))
# quote_str <- paste0(
# "`:=`(",
# paste0(resp, " = ", expr, collapse = ","),
# ")"
# )
# str2lang(quote_str)
# } else {
# NULL
# }
}

#' Get the Markov Blanket of Response Variable
#' Get a Directed Acyclic Graph (DAG) of a `dynamiteformula` Object
#'
#' Constructs the vertices, the edges and a plotting layout of a DAG that
#' shows the model structure over the time points `t - L, ..., t, ..., t + L`,
#' where `L` is `expand` times the highest lag order of the model (or 0 if
#' the model has no lagged dependencies). Vertices are labelled as
#' `var_{t - k}`, `var_{t}` and `var_{t + k}`.
#'
#' @param x A `dynamiteformula` object.
#' @param project A `logical` value. If `TRUE`, deterministic responses are
#'   projected out of the DAG.
#' @param covariates A `logical` value. If `TRUE`, predictors that are not
#'   responses of the model are also included as vertices.
#' @param expand A multiplier applied to the highest lag order of the model
#'   to determine how many time points are drawn on each side of `t`.
#' @return A `list` with elements `A` (an integer adjacency matrix with the
#'   vertex labels as dimnames), `edgelist` (a `data.frame` with `character`
#'   columns `from` and `to`) and `layout` (a `data.frame` with columns
#'   `var`, `x` and `y`).
#' @noRd
get_dag <- function(x, project = FALSE, covariates = FALSE, expand = 1L) {
  # Label of the vertex of `var` at time offset `s` relative to `t`
  node_label <- function(var, s) {
    if (length(var) == 0L) {
      # paste0() would otherwise turn zero-length input into a bogus label
      return(character(0L))
    }
    s <- as.integer(s)
    suffix <- ifelse(
      s == 0L,
      "_{t}",
      paste0("_{t ", ifelse(s < 0L, "-", "+"), " ", abs(s), "}")
    )
    paste0(var, suffix)
  }
  resp <- get_responses(x)
  contemp_dep <- ifelse_(
    covariates,
    get_nonlag_terms(x),
    lapply(get_nonlag_terms(x), function(y) y[y %in% resp])
  )
  lag_dep <- as.data.frame(get_lag_orders(x))
  if (nrow(lag_dep) == 0L) {
    # Models without lags: use a typed empty table so that the code below
    # can rely on the columns being present
    lag_dep <- data.frame(
      var = character(0L),
      order = integer(0L),
      resp = character(0L)
    )
  }
  cg <- attr(x, "channel_groups")
  if (!covariates) {
    lag_dep <- lag_dep[lag_dep$var %in% resp, ]
  }
  if (project) {
    for (i in which_deterministic(x)) {
      det <- resp[i]
      contemp_pa <- setdiff(contemp_dep[[i]], det)
      k <- length(contemp_pa)
      # Self-referential lags of the projected response are dropped, as they
      # cannot be expressed without the response itself
      lag_dep_pa <- lag_dep[lag_dep$resp == det & lag_dep$var != det, ]
      lag_dep_ch <- lag_dep[lag_dep$var == det & lag_dep$resp != det, ]
      pa_var <- c(contemp_pa, lag_dep_pa$var)
      pa_order <- c(rep(0, k), lag_dep_pa$order)
      # Channels that depend on the projected response at the same time
      # point inherit its parents in place of the response itself
      is_child <- vapply(
        contemp_dep,
        function(y) det %in% y,
        logical(1L)
      )
      is_child[i] <- FALSE
      contemp_dep[is_child] <- lapply(
        contemp_dep[is_child],
        function(y) union(setdiff(y, det), contemp_pa)
      )
      contemp_rows <- lapply(resp[is_child], function(r) {
        if (nrow(lag_dep_pa) == 0L) {
          return(NULL)
        }
        data.frame(var = lag_dep_pa$var, order = lag_dep_pa$order, resp = r)
      })
      # Channels that depend on a lag of the projected response inherit its
      # parents with the lag orders shifted accordingly
      lagged_rows <- lapply(seq_len(nrow(lag_dep_ch)), function(j) {
        if (length(pa_var) == 0L) {
          return(NULL)
        }
        data.frame(
          var = pa_var,
          order = pa_order + lag_dep_ch$order[j],
          resp = lag_dep_ch$resp[j]
        )
      })
      lag_dep <- do.call(
        rbind,
        c(
          list(lag_dep[lag_dep$resp != det & lag_dep$var != det, ]),
          contemp_rows,
          lagged_rows
        )
      )
    }
    lag_dep <- unique(lag_dep)
    resp_stoch <- which_stochastic(x)
    contemp_dep <- contemp_dep[resp_stoch]
    resp <- resp[resp_stoch]
    cg <- cg[resp_stoch]
  }
  max_lag <- 0L
  if (nrow(lag_dep) > 0L) {
    max_lag <- as.integer(expand * max(lag_dep$order))
  }
  resp_names <- resp
  all_vars <- c(
    resp,
    unique(setdiff(union(unlist(contemp_dep), lag_dep$var), resp))
  )
  n_vars <- length(all_vars)
  lags <- seq_len(max_lag)
  # Vertex order: all past time points, all future time points, time t
  v <- c(
    node_label(rep(all_vars, times = max_lag), -rep(lags, each = n_vars)),
    node_label(rep(all_vars, times = max_lag), rep(lags, each = n_vars)),
    node_label(all_vars, 0L)
  )
  n <- length(v)
  shifts <- seq(-max_lag, max_lag)
  n_shift <- length(shifts)
  # Lagged edges var_{s - order} -> resp_{s} for every time point s of the
  # window such that the parent vertex is also inside the window
  m <- nrow(lag_dep)
  to_time <- rep(shifts, each = m)
  from_time <- to_time - rep(lag_dep$order, times = n_shift)
  inside <- from_time >= -max_lag
  lag_edges <- data.frame(
    from = node_label(
      rep(lag_dep$var, times = n_shift)[inside],
      from_time[inside]
    ),
    to = node_label(
      rep(lag_dep$resp, times = n_shift)[inside],
      to_time[inside]
    )
  )
  # Contemporaneous edges var_{s} -> resp_{s} for every time point s
  contemp_from <- unlist(contemp_dep, use.names = FALSE)
  contemp_to <- rep(
    resp_names[seq_along(contemp_dep)],
    times = lengths(contemp_dep)
  )
  contemp_time <- rep(shifts, each = length(contemp_from))
  contemp_edges <- data.frame(
    from = node_label(rep(contemp_from, times = n_shift), contemp_time),
    to = node_label(rep(contemp_to, times = n_shift), contemp_time)
  )
  edgelist <- rbind(lag_edges, contemp_edges)
  A <- matrix(
    0L,
    nrow = n,
    ncol = n,
    dimnames = replicate(2L, v, simplify = FALSE)
  )
  if (nrow(edgelist) > 0L) {
    A[cbind(edgelist$from, edgelist$to)] <- 1L
  }
  # NOTE(review): `resp[order(cg)]` only contributes its length here, so the
  # channel group ordering does not affect the layout, and with
  # `covariates = TRUE` the first `p` positions are assigned in the order of
  # `all_vars` (responses first) -- confirm the intended vertical ordering.
  if (covariates) {
    p <- length(all_vars) - length(resp)
    resp <- all_vars
    layout_y <- c(seq_len(p), p + rev(seq_along(resp[order(cg)])))
  } else {
    layout_y <- rev(seq_along(resp[order(cg)]))
  }
  layout <- data.frame(var = v, x = NA, y = NA)
  for (s in shifts) {
    rows <- layout$var %in% node_label(resp, s)
    layout[rows, "x"] <- as.numeric(s)
    layout[rows, "y"] <- layout_y
  }
  list(A = A, edgelist = edgelist, layout = layout)
}

#' Get the Markov Blanket of a Response Variable
#'
#' @param x A `dynamiteformula` object.
#' @param y A `character` string naming the response variable.
Expand Down
12 changes: 6 additions & 6 deletions R/lags.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,23 +124,23 @@ find_lags <- function(x) {
#' @noRd
find_lag_orders <- function(x) {
if (!is.recursive(x)) {
return(list())
return(
data.frame(var = character(0L), order = integer(0L))
)
}
if (is.call(x)) {
if (identical(as.character(x[[1L]]), "lag")) {
if (length(x) == 2L) {
return(list(list(lag = deparse1(x), order = 1L)))
return(data.frame(var = deparse1(x[[2L]]), order = 1L))
} else {
return(list(list(lag = deparse1(x), order = x[[3L]])))
return(data.frame(var = deparse1(x[[2L]]), order = x[[3L]]))
}
} else {
unlist(lapply(x[-1L], find_lag_orders), recursive = FALSE)
rbindlist_(lapply(x[-1L], find_lag_orders))
}
}
}



#' Extract Non-lag Variables from a Language Object
#'
#' @param x A `language` object
Expand Down
Loading

0 comments on commit 9218840

Please sign in to comment.