Skip to content

Commit

Permalink
Account for weights when getting default spline knots from quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
chjackson committed Dec 1, 2023
1 parent 0dbd823 commit 8dd7d86
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
48 changes: 43 additions & 5 deletions R/spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -806,19 +806,22 @@ flexsurvspline <- function(formula, data, weights, bhazard, rtrunc, subset,
m <- eval(temp, parent.frame())
Y <- check.flexsurv.response(model.extract(m, "response"))

dtimes <- Y[,"stop"][Y[,"status"]==1]

deaths <- Y[,"status"]==1
dtimes <- Y[deaths,"stop"]
intcens <- Y[,"status"]==3
midpoints <- (Y[intcens,"time2"] + Y[intcens,"time1"])/2
if (any(intcens)) dtimes <- sort(c(dtimes, midpoints))
midpoints <- (Y[,"time2"] + Y[,"time1"])/2
ktimes <- ifelse(deaths, Y[,"stop"], ifelse(intcens, midpoints, NA))
kinds <- deaths | intcens

if (is.null(knots)) {
is.wholenumber <-
function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
if (is.null(k)) stop("either \"knots\" or \"k\" must be specified")
if (!is.numeric(k)) stop("k must be numeric")
if (!is.wholenumber(k) || (k<0)) stop("number of knots \"k\" must be a non-negative integer")
knots <- quantile(tsfn(dtimes,timescale), seq(0, 1, length.out=k+2)[-c(1,k+2)])
knots <- quantile_weighted(tsfn(ktimes[kinds],timescale),
probs = seq(0, 1, length.out=k+2)[-c(1,k+2)],
weights = model.extract(m, "weights")[kinds])
}
else {
if (!is.numeric(knots)) stop("\"knots\" must be a numeric vector")
Expand Down Expand Up @@ -905,3 +908,38 @@ betax_warn <- function(X, beta, offset){
if (!isTRUE(all.equal(X,0)) || !isTRUE(all.equal(beta,0)) || !isTRUE(all.equal(offset,0)))
warning("`X`, `beta` and `offset` arguments not supported since v2.3. Instead the first element of `gamma` should be modified to include any covariate effects or offsets.")
}

##' Weighted quantile function
##'
##' Approximates weighted quantiles by replication.  The weights are
##' normalised to sum to 1, multiplied by `max_rep*length(x)`, then
##' rounded to an integer.  Each element of `x` is repeated that many
##' times, and quantiles of the expanded vector are obtained using the
##' standard `quantile` function.  Elements whose scaled weight rounds
##' to zero do not contribute to the result.
##'
##' If `weights` is `NULL`, or `x` is empty, this is the same as
##' calling `quantile(x, probs, ...)`.
##'
##' @inheritParams quantile
##'
##' @param weights Vector of non-negative numbers of same length as
##'   `x`, in proportion to the weights of `x`.  Elements can be greater
##'   than 1, since this is normalised to sum to 1 internally.  At least
##'   one weight must be positive.
##'
##' @param max_rep The average number of times that an element of
##'   `x` will be replicated.  Increasing this will give more accuracy
##'   at the cost of bigger memory requirements for longer `x`.  The
##'   default of 100 is chosen to be rough, given the intended purpose
##'   of this function for choosing default knots for a spline to span
##'   `x`.
##'
##' @param ... Further arguments passed on to `quantile`, for example
##'   `names`, `type` or `na.rm`.
##'
##' @return The value returned by `quantile` for the (expanded) data.
##'
##' @noRd
quantile_weighted <- function(x, probs, weights=NULL, max_rep=100, ...){
    if (!is.null(weights) && (length(weights) != length(x)))
        stop("\"weights\" must be the same length as \"x\"")
    ## Nothing to weight: fall back to the ordinary quantile
    if (is.null(weights) || (length(x) == 0))
        return(quantile(x, probs, ...))
    ## Validate here so the user gets a clear message, rather than the
    ## "invalid 'times' argument" error that rep() would give
    if (!is.numeric(weights) || any(!is.finite(weights)) || any(weights < 0))
        stop("\"weights\" must be finite, non-negative numbers")
    if (sum(weights) <= 0)
        stop("at least one of the \"weights\" must be positive")
    ## Number of copies of each element.  These sum to roughly
    ## max_rep*length(x), so on average each element appears max_rep times
    weights_int <- round(length(x) * max_rep * (weights/sum(weights)))
    x_expand <- rep(x, weights_int)
    ## Forward ... so that options such as `names` or `type` are honoured
    quantile(x_expand, probs, ...)
}

### FIXME: expanding should always result in a vector longer than x,
### so the replication factor should somehow be chosen in proportion
### to the length of x, i.e. sum(weights_int) should be > length(x).
6 changes: 3 additions & 3 deletions man/Survspline.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions tests/testthat/test_spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,13 @@ test_that("interval censored data",{
spl0 <- flexsurvspline(Surv(recyrs, censrec) ~ 1, data=bci, k=1)
expect_equal(spl0$res["gamma0","est"], spl$res["gamma0","est"], tol=1e-02)
})

test_that("spline with weights",{
    set.seed(1)
    ## Use the follow-up time itself as the weight, so that later
    ## observations count for more
    bc$w <- bc$recyrs
    fit_weighted <- flexsurvspline(Surv(recyrs, censrec) ~ group,
                                   data=bc, weights=w, k=1)
    fit_plain <- flexsurvspline(Surv(recyrs, censrec) ~ group,
                                data=bc, k=1)
    ## Default knots should differ between the two fits ...
    expect_false(identical(fit_weighted$knots, fit_plain$knots))
    ## ... and, since the knots are chosen to account for the weights,
    ## they should be higher on average in the weighted model
    expect_gt(mean(fit_weighted$knots), mean(fit_plain$knots))
})

0 comments on commit 8dd7d86

Please sign in to comment.