Skip to content

Commit

Permalink
Complete analysis of age-district interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Aug 19, 2024
1 parent 7ef5a39 commit fd14c39
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 29 deletions.
3 changes: 3 additions & 0 deletions make/4_run_aaa_fit_multi-sexbehav-sae.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,6 @@ sapply(na.omit(recent_ids), function(id) orderly::orderly_commit(id))
sapply(na.omit(recent_ids), function(id) orderly::orderly_push_archive(name = report, id = id))

run_commit_push("process_multi-sexbehav-sae")

#' PhD revisions
#' Run the age-space interaction report with the `iso3` parameter set to "MWI".
#' `parameters` is a named list, so the value is bound with `=`; writing `==`
#' would evaluate a comparison and pass an unnamed logical instead.
orderly::orderly_run(
  "aaa_fit_multi-sexbehav-sae-age-space-interaction",
  parameters = list(iso3 = "MWI")
)
97 changes: 95 additions & 2 deletions src/aaa_fit_multi-sexbehav-sae-age-space-interaction/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,100 @@ multinomial_model <- function(formula, model_name, S = 1000) {

message(paste0("Completed fitting ", model_name, "."))

message("Skipping post-processing")
message("Begin post-processing")

return(list(fit = fit))
#' Full R-INLA samples
full_samples <- inla.posterior.sample(n = S, result = fit)

#' Just the latent field
eta_samples <- lapply(full_samples, "[", "latent")

#' For some reason "latent" is comprised of more than only the latent field
eta_samples <- lapply(eta_samples, function(eta_sample) {
data.frame(eta_sample) %>%
tibble::rownames_to_column() %>%
rename(eta = 2) %>%
filter(substr(rowname, 1, 10) == "Predictor:") %>%
select(-rowname)
})

#' Into a matrix with a row for each observation and a column for each sample
eta_samples_matrix <- matrix(unlist(eta_samples), ncol = S)
eta_samples_df <- data.frame(eta_samples_matrix)

samples <- eta_samples_df %>%
mutate(
#' To split by
obs_idx = df$obs_idx,
#' To sample predictive
#' When n_eff_kish is missing there is no survey for that observation,
#' so the posterior predictive is meaningless. Setting to zero may save
#' some computation, but probably better to filter out entirely.
n_eff_kish_new = floor(ifelse(is.na(df$n_eff_kish), 0, df$n_eff_kish))
) %>%
split(.$obs_idx) %>%
mclapply(function(x) {
n_eff_kish_new <- x$n_eff_kish_new
#' Remove the obs_idx and n_eff_kish_new columns
x_samples <- x[1:(length(x) - 2)]
#' Normalise each column (to avoid overflow of softmax)
x_samples <- apply(x_samples, MARGIN = 2, FUN = function(x) x - max(x))
#' Exponentiate (can be done outside apply)
#' WARNING: That these are samples from lambda posterior isn't true! Come back to this
lambda_samples <- exp(x_samples)
#' Calculate samples from posterior of probabilities
prob_samples <- apply(lambda_samples, MARGIN = 2, FUN = function(x) x / sum(x))
#' Calculate predictive samples (including sampling variability)
prob_predictive_samples <- apply(prob_samples, MARGIN = 2, FUN = function(x) {
stats::rmultinom(n = 1, size = n_eff_kish_new, prob = x) / n_eff_kish_new
})
#' Return list, allowing extraction of each set of samples
list(
lambda = data.frame(lambda_samples),
prob = data.frame(prob_samples),
prob_predictive = data.frame(prob_predictive_samples))
})

lambda_samples_df <- bind_rows(lapply(samples, "[[", "lambda"))
prob_samples_df <- bind_rows(lapply(samples, "[[", "prob"))
prob_predictive_samples_df <- bind_rows(lapply(samples, "[[", "prob_predictive"))

#' Helper functions
row_summary <- function(df, ...) unname(apply(df, MARGIN = 1, ...))
median <- function(x) quantile(x, 0.5, na.rm = TRUE)
lower <- function(x) quantile(x, 0.025, na.rm = TRUE)
upper <- function(x) quantile(x, 0.975, na.rm = TRUE)

#' Quantile of the observation within posterior predictive
prob_predictive_quantile <- prob_predictive_samples_df %>%
mutate(estimate = df$estimate) %>%
apply(MARGIN = 1, function(x) {
estimate <- x[S + 1]
samples <- x[1:S]
if(all(is.na(samples))) return(NA)
else ecdf(samples)(estimate)
})

#' Calculate mean, median, lower and upper for each set of samples
df <- df %>%
mutate(
lambda_mean = row_summary(lambda_samples_df, mean),
lambda_median = row_summary(lambda_samples_df, median),
lambda_lower = row_summary(lambda_samples_df, lower),
lambda_upper = row_summary(lambda_samples_df, upper),
prob_mean = row_summary(prob_samples_df, mean),
prob_median = row_summary(prob_samples_df, median),
prob_lower = row_summary(prob_samples_df, lower),
prob_upper = row_summary(prob_samples_df, upper),
prob_predictive_mean = row_summary(prob_predictive_samples_df, mean),
prob_predictive_median = row_summary(prob_predictive_samples_df, median),
prob_predictive_lower = row_summary(prob_predictive_samples_df, lower),
prob_predictive_upper = row_summary(prob_predictive_samples_df, upper),
prob_predictive_quantile = prob_predictive_quantile,
model = model_name
)

message("Completed post-processing")

return(list(df = df, fit = fit))
}
15 changes: 4 additions & 11 deletions src/aaa_fit_multi-sexbehav-sae-age-space-interaction/orderly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ artefacts:
- multi-sexbehav-sae.csv
- multi-sexbehav-sae.pdf
- data:
description: Stacked proportion barplots
description: Age district effect comparison plots
filenames:
- stacked-proportions.pdf
- age-district-plot-nosex12m.png
- age-district-plot-sexcohab.png
- age-district-plot-sexnonregplus.png
- data:
description: Model selection information criteria for multinomial models
filenames:
Expand All @@ -26,15 +28,6 @@ artefacts:
description: Random effect variance parameter posterior means
filenames:
- variance-proportions.csv
- data:
description: Sample size recovery diagnostic
filenames:
- sample-size-recovery.pdf
- data:
description: Coverage posterior predictive checks
filenames:
- coverage-histograms.pdf
- coverage-ecdf-diff.pdf

parameters:
iso3:
Expand Down
142 changes: 126 additions & 16 deletions src/aaa_fit_multi-sexbehav-sae-age-space-interaction/script.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,34 +191,27 @@ formula4 <- update(formula_baseline,

#' Model 4 extended (Chris thesis revision suggestion)
#' * age x space x category random effects (IID)
# formula4_extended <- update(formula4,
# . ~ . + f(area_idx, model = "besag", graph = adjM, scale.model = TRUE, group = cat_idx,
# control.group = list(model = "iid"), constr = TRUE, hyper = multi.utils::tau_pc(x = 0.001, u = 2.5, alpha = 0.01)) +
# )
formula4_extended <- update(formula4,
. ~ . + f(area_idx_copy, model = "iid", group = age_idx, replicate = cat_idx,
constr = TRUE, hyper = multi.utils::tau_pc(x = 0.001, u = 2.5, alpha = 0.01))
)

#' Fit the models

#' Number of Monte Carlo samples
S <- 1000

formulas <- list(formula4)
models <- list("Model 4")

#' tryCatch version for safety
try_multinomial_model <- function(...) {
return(tryCatch(multinomial_model(...), error = function(e) {
message("Error!")
return(NULL)
}))
}
formulas <- list(formula4, formula4_extended)
models <- list("Model 4", "Model 4 extended")

res <- purrr::pmap(
list(formula = formulas, model_name = models, S = S),
try_multinomial_model
multinomial_model
)

#' Extract the fitted models
res_fit <- lapply(res, "[[", 1)
res_df <- lapply(res, "[[", 1) %>% bind_rows()
res_fit <- lapply(res, "[[", 2)

#' Add columns for local DIC, WAIC, CPO
ic <- lapply(res_fit, function(fit) {
Expand Down Expand Up @@ -277,3 +270,120 @@ variance_df <- map(res_fit, function(fit)
)

write_csv(variance_df, "variance-proportions.csv", na = "")

#' Artefact: Smoothed district indicator estimates for multinomial models
#' Suffix the columns so that raw estimates and model (smoothed) estimates
#' can be told apart
res_df <- rename(
  res_df,
  estimate_raw = estimate,
  ci_lower_raw = ci_lower,
  ci_upper_raw = ci_upper,
  estimate_smoothed = prob_mean,
  median_smoothed = prob_median,
  ci_lower_smoothed = prob_lower,
  ci_upper_smoothed = prob_upper
)

#' Add the iso3 value as a column placed before `indicator`
res_df <- mutate(res_df, iso3 = iso3, .before = indicator)

#' Move the model label so it sits directly before the smoothed estimate
res_df <- relocate(res_df, model, .before = estimate_smoothed)

write_csv(res_df, "multi-sexbehav-sae.csv", na = "")

#' Create plotting data
#' Drop the rows whose area_id equals iso3
res_plot <- filter(res_df, area_id != iso3)

#' Stack the estimate_* columns: the part before the underscore becomes the
#' value column (`estimate`), the part after it becomes `source`
res_plot <- pivot_longer(
  res_plot,
  cols = c(starts_with("estimate")),
  names_to = c(".value", "source"),
  names_pattern = "(.*)\\_(.*)"
)

#' Join onto `areas` by area_id and use this to make it an sf again
res_plot <- left_join(res_plot, select(areas, area_id), by = "area_id")
res_plot <- st_as_sf(res_plot)

#' Artefact: Choropleths
#' One map per indicator-model combination, written to a single PDF. Each map
#' is faceted by age group (rows) and by survey and estimate source (columns).
pdf("multi-sexbehav-sae.pdf", height = 8.25, width = 11.75)

res_plot %>%
  split(~indicator + model) %>%
  lapply(function(x) {
    x %>%
      mutate(
        age_group = fct_recode(age_group,
          "15-19" = "Y015_019",
          "20-24" = "Y020_024",
          "25-29" = "Y025_029"
        ),
        source = fct_relevel(source, "raw", "smoothed") %>%
          fct_recode("Survey raw" = "raw", "Smoothed" = "smoothed")
      ) %>%
      ggplot(aes(fill = estimate)) +
      geom_sf(size = 0.1, colour = scales::alpha("grey", 0.25)) +
      scale_fill_viridis_c(option = "C", labels = label_percent()) +
      facet_grid(age_group ~ survey_id + source) +
      theme_minimal() +
      labs(
        title = paste0(substr(x$survey_id[1], 1, 3), ": ", x$indicator[1], " (", x$model[1], ")")
      ) +
      theme(
        axis.text = element_blank(),
        axis.ticks = element_blank(),
        panel.grid = element_blank(),
        strip.text = element_text(face = "bold"),
        plot.title = element_text(face = "bold"),
        legend.position = "bottom",
        legend.key.width = unit(4, "lines")
      )
  }) %>%
  #' Print each plot explicitly: relying on top-level auto-printing of the
  #' list leaves the PDF empty when the script is sourced without echo
  purrr::walk(print)

dev.off()

#' Artefact: Special revisions choropleths
#' Keep only the rows with the largest `year` (characters 4-7 of survey_id,
#' parsed as a number), then build one map per indicator comparing the survey
#' estimate with the smoothed estimates from "Model 4" (Base) and
#' "Model 4 extended" (Extended).
age_district_plot <- res_plot %>%
  mutate(year = as.numeric(substr(survey_id, 4, 7))) %>%
  filter(year == max(year)) %>%
  mutate(
    age_group = fct_recode(age_group,
      "15-19" = "Y015_019",
      "20-24" = "Y020_024",
      "25-29" = "Y025_029"
    ),
    indicator_plot = fct_recode(indicator,
      "Not sexually active" = "nosex12m",
      "One cohabiting partner" = "sexcohab",
      "Non-regular or multiple partners(s) +" = "sexnonregplus"
    ),
    source = fct_relevel(source, "raw", "smoothed") %>%
      fct_recode("Survey raw" = "raw", "Smoothed" = "smoothed"),
    #' Column label for the facets: raw survey, or which model was smoothed
    source_extended = case_when(
      source == "Survey raw" ~ "Survey raw",
      source == "Smoothed" & model == "Model 4" ~ "Base",
      source == "Smoothed" & model == "Model 4 extended" ~ "Extended"
    ),
    source_extended = fct_relevel(source_extended, "Survey raw", "Base", "Extended")
  ) %>%
  split(~ indicator) %>%
  lapply(function(x) {
    x %>%
      ggplot(aes(fill = estimate)) +
      geom_sf(size = 0.1, colour = scales::alpha("grey", 0.25)) +
      scale_fill_viridis_c(option = "C", labels = label_percent()) +
      facet_grid(age_group ~ source_extended) +
      #' No title is set; the indicator and survey are given in the caption
      labs(
        caption = paste0(
          "Indicator: ", x$indicator_plot[1], "; Survey: ", x$survey_id[1], "\n",
          "Extended includes age-district interaction effects"
        ),
        fill = "Estimate"
      ) +
      theme_minimal() +
      theme(
        axis.text = element_blank(),
        axis.ticks = element_blank(),
        panel.grid = element_blank()
      )
  })

#' Save one PNG per indicator, named after the (lower-cased) indicator.
#' iwalk() is used because only the side effect of writing files is wanted.
iwalk(
  age_district_plot,
  \(x, name) ggsave(
    paste0("age-district-plot-", tolower(name), ".png"),
    plot = x,
    height = 7.5,
    width = 6.25
  )
)

0 comments on commit fd14c39

Please sign in to comment.