From d2f167353acfd2bf93fd57127e29dad0f8a936b4 Mon Sep 17 00:00:00 2001 From: Jouni Helske Date: Sat, 25 May 2024 15:17:04 +0300 Subject: [PATCH] fix declaration order, flip correlations, run-extended --- R/as_data_frame.R | 12 +++++--- R/stanblocks.R | 53 +++++++++++++++++++++++++--------- R/stanblocks_families.R | 11 ++++--- tests/testthat/test-extended.R | 46 +++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 24 deletions(-) diff --git a/R/as_data_frame.R b/R/as_data_frame.R index 7086f58..f6ca3e7 100644 --- a/R/as_data_frame.R +++ b/R/as_data_frame.R @@ -25,9 +25,10 @@ #' `rstan::extract(fit$stanfit, pars = "corr_matrix_nu")` if necessary. #' * `sigma_lambda`\cr Standard deviations of the latent factor loadings #' `lambda`. -#' * `corr_psi`\cr Pairwise correlations of the latent factors. -#' Samples of the full correlation matrix can be extracted manually as -#' `rstan::extract(fit$stanfit, pars = "corr_matrix_psi")` if necessary. +#' * `corr_psi`\cr Pairwise correlations of the noise terms of the latent +#' factors. Samples of the full correlation matrix can be extracted +#' manually as `rstan::extract(fit$stanfit, pars = "corr_matrix_psi")` if +#' necessary. #' * `sigma`\cr Standard deviations of gaussian responses. #' * `corr`\cr Pairwise correlations of multivariate gaussian responses. #' * `phi`\cr Describes various distributional parameters, such as: @@ -37,7 +38,10 @@ #' - Degrees of freedom of the Student t-distribution. #' * `omega`\cr Spline coefficients of the regression coefficients `delta`. #' * `omega_alpha`\cr Spline coefficients of time-varying `alpha`. -#' * `omega_psi`\cr Spline coefficients of the latent factors `psi`. +#' * `omega_psi`\cr Spline coefficients of the latent factors `psi`. Note that +#' in case of `nonzero_lambda = FALSE`, the mean of these is used to flip the +#' sign of `psi` to avoid multimodality due to sign-switching, but +#' `omega_psi` variables are not modified. 
#' #' @export #' @family output diff --git a/R/stanblocks.R b/R/stanblocks.R index 6344c07..a3b7c99 100644 --- a/R/stanblocks.R +++ b/R/stanblocks.R @@ -524,8 +524,8 @@ create_transformed_parameters <- function(idt, backend, paste_rows( "transformed parameters {", random_text, - lfactor_text, declarations, + lfactor_text, statements, "}", .parse = FALSE @@ -889,19 +889,44 @@ create_generated_quantities <- function(idt, backend, P <- mvars$lfactor_def$P if (P > 0L && mvars$lfactor_def$correlated) { # evaluate number of corrs to avoid Stan warning about integer division - gen_psi <- paste_rows( - paste0( - "matrix[P, P] corr_matrix_psi = ", - "multiply_lower_tri_self_transpose(L_lf);" - ), - "vector[{(P * (P - 1L)) %/% 2L}] corr_psi;", - "for (k in 1:P) {{", - "for (j in 1:(k - 1)) {{", - "corr_psi[choose(k - 1, 2) + j] = corr_matrix_psi[j, k];", - "}}", - "}}", - .indent = idt(c(1, 1, 1, 2, 3, 2, 1)) - ) + if (any(!mvars$lfactor_def$nonzero_lambda)) { + signs <- paste0( + ifelse( + !mvars$lfactor_def$nonzero_lambda, + paste0("sign_omega_", psis), + "1" + ), + collapse = ", " + ) + gen_psi <- paste_rows( + paste0( + "matrix[P, P] corr_matrix_psi = ", + "multiply_lower_tri_self_transpose(L_lf);" + ), + "row_vector[P] signs = [{signs}];", + "vector[{(P * (P - 1L)) %/% 2L}] corr_psi;", + "for (k in 1:P) {{", + "for (j in 1:(k - 1)) {{", + "corr_psi[choose(k - 1, 2) + j] = signs[j] * signs[k] * corr_matrix_psi[j, k];", + "}}", + "}}", + .indent = idt(c(1, 1, 1, 1, 2, 3, 2, 1)) + ) + } else { + gen_psi <- paste_rows( + paste0( + "matrix[P, P] corr_matrix_psi = ", + "multiply_lower_tri_self_transpose(L_lf);" + ), + "vector[{(P * (P - 1L)) %/% 2L}] corr_psi;", + "for (k in 1:P) {{", + "for (j in 1:(k - 1)) {{", + "corr_psi[choose(k - 1, 2) + j] = corr_matrix_psi[j, k];", + "}}", + "}}", + .indent = idt(c(1, 1, 1, 2, 3, 2, 1)) + ) + } } n_cg <- n_unique(cg) generated_quantities_text <- character(n_cg) diff --git a/R/stanblocks_families.R b/R/stanblocks_families.R index 
24b42a0..83a4f13 100644 --- a/R/stanblocks_families.R +++ b/R/stanblocks_families.R @@ -1710,12 +1710,11 @@ transformed_parameters_lines_default <- function(y, idt, noncentered, ) if (has_lfactor && !nonzero_lambda) { state_psi <- paste_rows( - "{{", - "int s = mean(omega_psi_{y}) < 0 ? -1 : 1;", - "psi_{y} = s * (omega_psi_{y} * Bs)';", - "lambda_{y} = -s * lambda_{y};", - "}}", - .indent = idt(c(1, 2, 2, 2, 1)), + "// try to avoid sign-switching by adjusting psi and lambda", + "real sign_omega_{y} = mean(omega_psi_{y}) < 0 ? -1 : 1;", + "psi_{y} = sign_omega_{y} * (omega_psi_{y} * Bs)';", + "lambda_{y} = -sign_omega_{y} * lambda_{y};", + .indent = idt(1), .parse = FALSE ) } else { diff --git a/tests/testthat/test-extended.R b/tests/testthat/test-extended.R index 1643161..4dfdded 100644 --- a/tests/testthat/test-extended.R +++ b/tests/testthat/test-extended.R @@ -837,4 +837,50 @@ test_that("latent factor models are identifiable", { expect_true(all(sumr6$ess_tail > 500)) expect_equal(summary(fit6, type="psi")$mean, -sim$psi, tolerance = 0.5) + + # Test bivariate case with nonzero_lambda + set.seed(123) + N <- 20 + T_ <- 100 + D <- 50 + x <- y <- matrix(0, N, T_) + psi <- matrix(NA, 2, T_) + lambda_y <- rnorm(N) + lambda_y <- lambda_y - mean(lambda_y) + lambda_x <- rnorm(N) + lambda_x <- 0.1 + lambda_x - mean(lambda_x) + L <- t(chol(matrix(c(1, 0.6, 0.6, 1), 2, 2))) + B <- t(splines::bs(seq_len(T_), df = D, intercept = TRUE)) + omega <- matrix(NA, 2, D) + omega[, 1] <- L %*% rnorm(2) + for(i in 2:D) { + omega[, i] <- omega[, i - 1] + L %*% rnorm(2) + } + psi[1, ] <- omega[1, ] %*% B + psi[2, ] <- omega[2, ] %*% B + for(t in 1:T_) { + y[, t] <- rnorm(N, lambda_y * psi[1, t]) + x[, t] <- rnorm(N, lambda_x * psi[2, t]) + } + d <- data.frame( + y = c(y), x = c(x), + time = rep(seq_len(T_), each = N), + id = rep(seq_len(N), times = T_) + ) + dformula <- obs(y ~ 1, family = "gaussian") + + obs(x ~ 1, family = "gaussian") + + lfactor(nonzero_lambda = c(FALSE, TRUE)) + 
splines(50) + fit <- dynamite( + dformula, + data = d, time = "time", group = "id", + backend = "cmdstanr", stanc_options = list("O1"), + iter_sampling = 5000, iter_warmup = 5000, + parallel_chains = 4, refresh = 0, seed = 1 + ) + sumr <- as_draws(fit) |> + posterior::summarise_draws() + + expect_true(all(sumr$rhat < 1.1)) + expect_true(all(sumr$ess_bulk > 500)) + expect_true(all(sumr$ess_tail > 500)) })