From 4b785e06e845e48f6390bf87e64414c061323e31 Mon Sep 17 00:00:00 2001 From: Sean Anderson Date: Fri, 20 Sep 2024 13:34:34 -0700 Subject: [PATCH] Fix offset prediction in cross validation #372 --- DESCRIPTION | 2 +- NEWS.md | 3 +++ R/cross-val.R | 10 ++++++---- tests/testthat/test-offset.R | 22 ++++++++++++++++++++++ 4 files changed, 32 insertions(+), 5 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 7cc976e50..c69e2f1d4 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Type: Package Package: sdmTMB Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB' -Version: 0.6.0.9004 +Version: 0.6.0.9005 Authors@R: c( person(c("Sean", "C."), "Anderson", , "sean@seananderson.ca", role = c("aut", "cre"), diff --git a/NEWS.md b/NEWS.md index b3db0f703..18375ff40 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # sdmTMB (development version) +* Fix passing of `offset` argument through in `sdmTMB_cv()`. Before it was being + omitted in the prediction (i.e., set to 0). #372 + * Fig bug in `exponentiate` argument for `tidy()`. Set `conf.int = TRUE` as default. #353 diff --git a/R/cross-val.R b/R/cross-val.R index 06abe9af1..6e871f13c 100644 --- a/R/cross-val.R +++ b/R/cross-val.R @@ -287,10 +287,10 @@ sdmTMB_cv <- function( cli_abort("`weights` cannot be specified within sdmTMB_cv().") } if ("offset" %in% names(dot_args)) { - .offset <- eval(dot_args$offset) - if (parallel && !is.character(.offset) && !is.null(.offset)) { - cli_abort("We recommend using a character value for 'offset' (indicating the column name) when applying parallel cross validation.") + if (!is.character(dot_args$offset)) { + cli_abort("Please use a character value for 'offset' (indicating the column name) for cross validation.") } + .offset <- eval(dot_args$offset) } else { .offset <- NULL } @@ -369,7 +369,9 @@ sdmTMB_cv <- function( # FIXME: only use TMB report() below to be faster! # predict for withheld data: - predicted <- predict(object, newdata = cv_data, type = "response") + predicted <- predict(object, newdata = cv_data, type = "response", + offset = if (!is.null(.offset)) cv_data[[.offset]] else rep(0, nrow(cv_data))) + cv_data$cv_predicted <- predicted$est response <- get_response(object$formula[[1]]) withheld_y <- predicted[[response]] diff --git a/tests/testthat/test-offset.R b/tests/testthat/test-offset.R index aa6a636f2..d19420d70 100644 --- a/tests/testthat/test-offset.R +++ b/tests/testthat/test-offset.R @@ -138,6 +138,28 @@ test_that("Offset prediction matches glm()", { # p_glmmTMB <- predict(fit_glmmTMB, newdata = dat) # expect_equal(p$est, unname(p_glmmTMB)) }) + +test_that("offset gets passed through cross validation as expected #372", { + dat <- subset(dogfish, catch_weight > 0) + expect_error( + x <- sdmTMB_cv(catch_weight ~ 1, + data = dat, + family = Gamma("log"), offset = log(dat$area_swept), spatial = "off" + ), + "offset" + ) + set.seed(1) + x <- sdmTMB_cv(catch_weight ~ 1, + data = dat, family = Gamma("log"), + offset = "area_swept", spatial = "off", + mesh = make_mesh(dat, c("X", "Y"), cutoff = 10), k_folds = 2 + ) + y <- x$data[, c("catch_weight", "cv_predicted")] + # plot(y$catch_weight, y$cv_predicted) + # if offset is applied, will have unique values because an intercept-only model: + expect_true(length(unique(y$cv_predicted)) == 684L) +}) + # # # # offset/prediction setting checks: #