Skip to content

Commit

Permalink
Fixes for splines and interval censoring (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
chjackson committed Nov 30, 2023
1 parent 9291e60 commit 0dbd823
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
14 changes: 7 additions & 7 deletions R/spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ dbase.survspline <- function(q, gamma, knots, scale, deriv=FALSE, spline="rp"){
if(!is.matrix(gamma)) gamma <- matrix(gamma, nrow=1)
if(!is.matrix(knots)) knots <- matrix(knots, nrow=1)
else if (spline=="splines2ns") stop("matrix knots not supported with spline=\"splines2ns\"")
lg <- nrow(gamma)
nret <- max(length(q), lg)
nret <- max(length(q), nrow(gamma), nrow(knots))
q <- rep(q, length.out=nret)

gamma <- matrix(rep(as.numeric(t(gamma)), length.out = ncol(gamma) * nret),
Expand Down Expand Up @@ -165,7 +164,6 @@ dsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q, timescale), spline=spline) * gamma) # log cumulative hazard/odds
eeta <- exp(ldlink(scale)(eta))
ret[ind][eeta==0] <- 0
Expand Down Expand Up @@ -219,13 +217,12 @@ psurvspline <- function(q, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
ret[ind][q==0] <- 0
ret[ind][q==Inf] <- 1
finite <- q>0 & q<Inf
ind <- ind[finite]
ind[ind][!finite] <- FALSE
q <- q[finite]
gamma <- gamma[finite,,drop=FALSE]
knots <- knots[finite,,drop=FALSE]
}
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) +
as.numeric(X %*% beta)
surv <- Slink(scale)(eta)
Expand Down Expand Up @@ -273,7 +270,6 @@ Hsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta)
ret[ind] <- as.numeric(Hlink(scale)(eta))
}
Expand All @@ -299,7 +295,6 @@ hsurvspline <- function(x, gamma, beta=0, X=0, knots=c(-10,10), scale="hazard",
d <- dbase.survspline(q=x, gamma=gamma, knots=knots, scale=scale, spline=spline)
for (i in seq_along(d)) assign(names(d)[i], d[[i]])
if (any(ind)){
if (length(knots)==0) browser()
eta <- rowSums(basis(knots, tsfn(q,timescale), spline=spline) * gamma) + as.numeric(X %*% beta)
eeta <- hlink(scale)(eta)
ret[ind] <- dtsfn(q, timescale) * rowSums(dbasis(knots, tsfn(q, timescale), spline=spline) * gamma) * eeta
Expand Down Expand Up @@ -812,6 +807,11 @@ flexsurvspline <- function(formula, data, weights, bhazard, rtrunc, subset,
Y <- check.flexsurv.response(model.extract(m, "response"))

dtimes <- Y[,"stop"][Y[,"status"]==1]

intcens <- Y[,"status"]==3
midpoints <- (Y[intcens,"time2"] + Y[intcens,"time1"])/2
if (any(intcens)) dtimes <- sort(c(dtimes, midpoints))

if (is.null(knots)) {
is.wholenumber <-
function(x, tol = .Machine$double.eps^0.5) abs(x - round(x)) < tol
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test_spline.R
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,15 @@ test_that("splines2 orthogonal basis",{
expect_equal(spl_rp$res["groupMedium","est"], spl_ns$res["groupMedium","est"], tol=1e-03)
expect_equal(spl_rp$res["groupPoor","est"], spl_ns$res["groupPoor","est"], tol=1e-03)
})

test_that("interval censored data",{
bc$recyrs1 <- bc$recyrs + 0.001
bci <- bc[bc$censrec==1,]
## knots chosen from interval midpoints
spl1 <- flexsurvspline(Surv(recyrs, recyrs1, type="interval2") ~ 1, data=bci, k=1)
spl <- flexsurvspline(Surv(recyrs, censrec) ~ 1, data=bci, knots=spl1$knots[2],
bknots=spl1$knots[c(1,3)])
expect_equal(spl1$res["gamma0","est"], spl$res["gamma0","est"], tol=1e-02)
spl0 <- flexsurvspline(Surv(recyrs, censrec) ~ 1, data=bci, k=1)
expect_equal(spl0$res["gamma0","est"], spl$res["gamma0","est"], tol=1e-02)
})

0 comments on commit 0dbd823

Please sign in to comment.