Skip to content

Commit

Permalink
Support for beta, X and offset in spline dpqr functions removed (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
chjackson committed Nov 29, 2023
1 parent 3b76de9 commit 9291e60
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 26 deletions.
50 changes: 29 additions & 21 deletions R/spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
##' restricted mean survival will be conditioned on survival up to
##' this time.
##'
##' @param beta Vector of covariate effects (deprecated).
##' @param beta Vector of covariate effects. Not supported and ignored since version 2.3, and this argument will be removed in 2.4.
##'
##' @param X Matrix of covariate values (deprecated).
##' @param X Matrix of covariate values. Not supported and ignored since version 2.3, and this argument will be removed in 2.4.
##'
##' @param knots Locations of knots on the axis of log time, supplied in
##' increasing order. Unlike in \code{\link{flexsurvspline}}, these include
Expand Down Expand Up @@ -56,7 +56,7 @@
##' the basis being orthogonal.
##'
##' @param offset An extra constant to add to the linear predictor
##' \eqn{\eta}{eta}.
##' \eqn{\eta}{eta}. Not supported and ignored since version 2.3, and this argument will be removed in 2.4.
##'
##' @param log,log.p Return log density or probability.
##'
Expand Down Expand Up @@ -109,14 +109,13 @@ NULL
## could be generalized to any function with vector of arguments
## TODO more special value handling

dbase.survspline <- function(q, gamma, knots, scale, offset=0, deriv=FALSE, spline="rp"){
dbase.survspline <- function(q, gamma, knots, scale, deriv=FALSE, spline="rp"){
if(!is.matrix(gamma)) gamma <- matrix(gamma, nrow=1)
if(!is.matrix(knots)) knots <- matrix(knots, nrow=1)
else if (spline=="splines2ns") stop("matrix knots not supported with spline=\"splines2ns\"")
lg <- nrow(gamma)
nret <- max(length(q), lg)
q <- rep(q, length.out=nret)
offset <- rep(offset, length.out=nret)

gamma <- matrix(rep(as.numeric(t(gamma)), length.out = ncol(gamma) * nret),
ncol = ncol(gamma), byrow = TRUE)
Expand All @@ -136,10 +135,9 @@ dbase.survspline <- function(q, gamma, knots, scale, offset=0, deriv=FALSE, spli
}
ind <- !is.na(q) & q > 0
q <- q[ind]
offset <- offset[ind]
gamma <- gamma[ind,,drop=FALSE]
knots <- knots[ind,,drop=FALSE]
list(ret=ret, gamma=gamma, q=q, scale=scale, ind=ind, knots=knots, offset=offset)
list(ret=ret, gamma=gamma, q=q, scale=scale, ind=ind, knots=knots)
}

dlink <- function(scale){
Expand All @@ -163,11 +161,12 @@ ldlink <- function(scale){
##' @rdname Survspline
##' @export
dsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0, log=FALSE){
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, offset=offset, spline=spline)
betax_warn(beta, X, offset)
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q, timescale), spline=spline) * gamma) + as.numeric(X %*% beta) + offset # log cumulative hazard/odds
eta <- rowSums(basis(knots, tsfn(q, timescale), spline=spline) * gamma) # log cumulative hazard/odds
eeta <- exp(ldlink(scale)(eta))
ret[ind][eeta==0] <- 0
ret[ind][is.nan(eeta)] <- NaN
Expand Down Expand Up @@ -213,22 +212,22 @@ Slink <- function(scale){
##' @rdname Survspline
##' @export
psurvspline <- function(q, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0, lower.tail=TRUE, log.p=FALSE){
d <- dbase.survspline(q=q, gamma=gamma, knots=knots, scale=scale, offset=offset, spline=spline)
betax_warn(beta, X, offset)
d <- dbase.survspline(q=q, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
ret[ind][q==0] <- 0
ret[ind][q==Inf] <- 1
finite <- q>0 & q<Inf
ind <- ind[finite]
q <- q[finite]
offset <- offset[finite]
gamma <- gamma[finite,,drop=FALSE]
knots <- knots[finite,,drop=FALSE]
}
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) +
as.numeric(X %*% beta) + offset
as.numeric(X %*% beta)
surv <- Slink(scale)(eta)
ret[ind] <- as.numeric(1 - surv)
}
Expand All @@ -240,17 +239,19 @@ psurvspline <- function(q, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
##' @rdname Survspline
##' @export
qsurvspline <- function(p, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0, lower.tail=TRUE, log.p=FALSE){
betax_warn(beta, X, offset)
if (log.p) p <- exp(p)
if (!lower.tail) p <- 1 - p
qgeneric(psurvspline, p=p, matargs=c("gamma","knots"), scalarargs=c("scale","timescale","spline"),
gamma=gamma, beta=beta, X=X, knots=knots, scale=scale, timescale=timescale, spline=spline, offset=offset)
gamma=gamma, knots=knots, scale=scale, timescale=timescale, spline=spline)
}

##' @rdname Survspline
##' @export
rsurvspline <- function(n, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0){
betax_warn(beta, X, offset)
if (length(n) > 1) n <- length(n)
ret <- qsurvspline(p=runif(n), gamma=gamma, beta=beta, X=X, knots=knots, scale=scale, timescale=timescale, spline=spline, offset=offset)
ret <- qsurvspline(p=runif(n), gamma=gamma, knots=knots, scale=scale, timescale=timescale, spline=spline)
ret
}

Expand All @@ -269,11 +270,11 @@ Hlink <- function(scale){
##' @export
Hsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0){
match.arg(scale, c("hazard","odds","normal"))
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, offset=offset, spline=spline)
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta) + offset
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta)
ret[ind] <- as.numeric(Hlink(scale)(eta))
}
ret
Expand All @@ -293,12 +294,13 @@ hlink <- function(scale){
##' @export
hsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0){
## value for x=0? currently zero, should it be limit as x reduces to 0?
betax_warn(beta, X, offset)
match.arg(scale, c("hazard","odds","normal"))
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, offset=offset, spline=spline)
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta) + offset
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta)
eeta <- hlink(scale)(eta)
ret[ind] <- dtsfn(q, timescale) * rowSums(dbasis(knots, tsfn(q, timescale), spline=spline) * gamma) * eeta
ret[ind][ret[ind]<=0] <- 0 # these correspond to invalid decreasing cumulative hazard functions
Expand All @@ -309,22 +311,23 @@ hsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
##' @rdname Survspline
##' @export
rmst_survspline = function(t, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0, start=0){
betax_warn(beta, X, offset)
rmst_generic(psurvspline, t, start=start,
matargs = c("gamma", "knots"), scalarargs = c("scale", "timescale", "spline"),
gamma=gamma, knots=knots,
beta=beta, X=X,
scale=scale, timescale=timescale, spline=spline, offset=offset)
scale=scale, timescale=timescale, spline=spline)
}

##' @rdname Survspline
##' @export
mean_survspline = function(gamma, beta=0, X=0, knots=c(-10,10), scale="hazard", timescale="log", spline="rp", offset=0){
betax_warn(beta, X, offset)
nt <- if (is.matrix(gamma)) nrow(gamma) else 1
rmst_generic(psurvspline, rep(Inf,nt), start=0,
matargs = c("gamma", "knots"),
scalarargs = c("scale", "timescale", "spline"),
gamma=gamma, knots=knots,
beta=beta, X=X, scale=scale, timescale=timescale, spline=spline, offset=offset)
scale=scale, timescale=timescale, spline=spline)
}

##' Natural cubic spline basis
Expand Down Expand Up @@ -897,3 +900,8 @@ flexsurvspline <- function(formula, data, weights, bhazard, rtrunc, subset,
class(ret) <- "flexsurvreg"
ret
}

betax_warn <- function(X, beta, offset){
if (!isTRUE(all.equal(X,0)) || !isTRUE(all.equal(beta,0)) || !isTRUE(all.equal(offset,0)))
warning("`X`, `beta` and `offset` arguments not supported since v2.3. Instead the first element of `gamma` should be modified to include any covariate effects or offsets.")
}
8 changes: 3 additions & 5 deletions tests/testthat/test_spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,9 @@ test_that("flexsurvspline results match stpm in Stata",{
test_that("Expected survival",{
spl <- flexsurvspline(Surv(recyrs, censrec) ~ group, data=bc, k=1)
gamma <- coef(spl)[1:3]
beta <- coef(spl)[4:5]
surv <- function(x,...)psurvspline(q=x, gamma=gamma, beta=beta, knots=spl$knots, scale=spl$scale, lower.tail=FALSE, ...)
expect_equal(integrate(surv, 0, 5, X=c(0,0))$value, 4.341222955052117, tol=1e-04)# For group="good"
expect_equal(integrate(surv, 0, 5, X=c(1,0))$value, 3.664826479659649, tol=1e-04) # For group="medium"
expect_equal(integrate(surv, 0, 5, X=c(0,1))$value, 2.713301623208948, tol=1e-04) # For group="poor"
surv <- function(x,...)psurvspline(q=x, gamma=gamma, knots=spl$knots,
scale=spl$scale, lower.tail=FALSE, ...)
expect_equal(integrate(surv, 0, 5)$value, 4.341222955052117, tol=1e-04) # For group="good"
})

test_that("gamma in d/psurvspline can be matrix or vector",{
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test_splinedist.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,14 @@ test_that("fixed-knot convenience wrappers",{

})

test_that("removing support for beta, X and offset",{
gamma <- c(0.1, 0.2)
d2 <- dsurvspline(1, gamma=gamma)
expect_warning(d2 <- dsurvspline(1, gamma=gamma, beta = 0.1))
expect_warning(d2 <- dsurvspline(1, gamma=gamma, X = 0.1))
expect_warning(d2 <- dsurvspline(1, gamma=gamma, offset = 0.1))
suppressWarnings({
d3 <- dsurvspline(1, gamma=gamma, X = 0.1)
expect_equal(d2, d3)
})
})

0 comments on commit 9291e60

Please sign in to comment.