Skip to content

Commit

Permalink
Fix offset prediction in cross validation #372
Browse files Browse the repository at this point in the history
  • Loading branch information
seananderson committed Sep 20, 2024
1 parent cb83a62 commit 4b785e0
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: sdmTMB
Title: Spatial and Spatiotemporal SPDE-Based GLMMs with 'TMB'
Version: 0.6.0.9004
Version: 0.6.0.9005
Authors@R: c(
person(c("Sean", "C."), "Anderson", , "sean@seananderson.ca",
role = c("aut", "cre"),
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# sdmTMB (development version)

* Fix passing of `offset` argument through in `sdmTMB_cv()`. Before it was being
omitted in the prediction (i.e., set to 0). #372

* Fig bug in `exponentiate` argument for `tidy()`. Set `conf.int = TRUE` as
default. #353

Expand Down
10 changes: 6 additions & 4 deletions R/cross-val.R
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ sdmTMB_cv <- function(
cli_abort("`weights` cannot be specified within sdmTMB_cv().")
}
if ("offset" %in% names(dot_args)) {
.offset <- eval(dot_args$offset)
if (parallel && !is.character(.offset) && !is.null(.offset)) {
cli_abort("We recommend using a character value for 'offset' (indicating the column name) when applying parallel cross validation.")
if (!is.character(dot_args$offset)) {
cli_abort("Please use a character value for 'offset' (indicating the column name) for cross validation.")
}
.offset <- eval(dot_args$offset)
} else {
.offset <- NULL
}
Expand Down Expand Up @@ -369,7 +369,9 @@ sdmTMB_cv <- function(

# FIXME: only use TMB report() below to be faster!
# predict for withheld data:
predicted <- predict(object, newdata = cv_data, type = "response")
predicted <- predict(object, newdata = cv_data, type = "response",
offset = if (!is.null(.offset)) cv_data[[.offset]] else rep(0, nrow(cv_data)))

cv_data$cv_predicted <- predicted$est
response <- get_response(object$formula[[1]])
withheld_y <- predicted[[response]]
Expand Down
22 changes: 22 additions & 0 deletions tests/testthat/test-offset.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,28 @@ test_that("Offset prediction matches glm()", {
# p_glmmTMB <- predict(fit_glmmTMB, newdata = dat)
# expect_equal(p$est, unname(p_glmmTMB))
})

test_that("offset gets passed through cross validation as expected #372", {
dat <- subset(dogfish, catch_weight > 0)
expect_error(
x <- sdmTMB_cv(catch_weight ~ 1,
data = dat,
family = Gamma("log"), offset = log(dat$area_swept), spatial = "off"
),
"offset"
)
set.seed(1)
x <- sdmTMB_cv(catch_weight ~ 1,
data = dat, family = Gamma("log"),
offset = "area_swept", spatial = "off",
mesh = make_mesh(dat, c("X", "Y"), cutoff = 10), k_folds = 2
)
y <- x$data[, c("catch_weight", "cv_predicted")]
# plot(y$catch_weight, y$cv_predicted)
# if offset is applied, will have unique values because an intercept-only model:
expect_true(length(unique(y$cv_predicted)) == 684L)
})

# #
# # offset/prediction setting checks:
#
Expand Down

0 comments on commit 4b785e0

Please sign in to comment.