## rescaled exponential splitting
## The survival object 'y' is rescaled so that
##   a. the overall death rate is 1.0
##   b. the death rate within small intervals of time is 1
## The first makes printouts easier to read, since the rates in subnodes are
##  numbers like "1.23" (a 23% higher event rate than the root node) instead
##  of "0.0014" (which requires looking back at the root node rate and
##  dividing).  The second makes the data appear to be Poisson, and causes
##  the early splits, at least, to be equivalent to the local full
##  likelihood method of LeBlanc and Crowley.
##
rpart.exp <- function(y, offset, parms, wt)
{
    if (!inherits(y, "Surv"))
        stop("Response must be a 'survival' object - use the 'Surv()' function")
    ny <- ncol(y)
    n <- nrow(y)
    status <- y[, ny]
    if (any(y[, 1L] <= 0)) stop("Observation time must be > 0")
    if (all(status == 0)) stop("No deaths in data set")
    time <- y[, ny - 1L]

    ## Make my table of time intervals.  The first goes from 0 to the first
    ##  death, the next from death 2 to death 3, ..., and the last from the
    ##  next-to-last death to the max time in the data set.
    ## We also need to avoid a pathological condition in some data sets,
    ##  where two death times differ by a trivial amount, e.g. 10^-16,
    ##  perhaps due to roundoff error in creating the input data.
    ##  Amalgamate such intervals.  This turns out to be hard to do in S,
    ##  but easy in C.
    dtimes <- sort(unique(time[status == 1])) # unique death times
    temp <- .Call(C_rpartexp2, as.double(dtimes),
                  as.double(.Machine$double.eps))
    dtimes <- dtimes[temp == 1]

    ## For the sake of speed, restrict the number of intervals to be < 1000.
    ##  (Actually, anything > 100 is probably overkill for the actual task
    ##  at hand, which is to approximately scale to exponential.)
    if (length(dtimes) > 1000) dtimes <- quantile(dtimes, 0:1000/1000)

    ## The last interval goes to the max time in the data set
    itable <- c(0, dtimes[-length(dtimes)], max(time)) # set of intervals

    drate2 <- function(n, ny, y, wt, itable) {
        ## An alternative to the drate1 function.
        ## Why?  The pyears2 routine changed in 6/2001, with the inclusion
        ##  of case weights; we need the newer version.  If you have the
        ##  older version of the survival library, the above will crash S.
        ## How to tell -- list the pyears function, and see whether its
        ##  call to pyears2 has weights in the argument list.
        time <- y[, ny - 1L]
        status <- y[, ny]
        ilength <- diff(itable)  # lengths of the intervals
        ngrp <- length(ilength)  # number of intervals

        ## The code below is as opaque as any I've written, all in the
        ##  service of "no for loops".
        ## First, 'index' gives the time interval (as defined by itable) in
        ##  which the end of each observation's follow-up (time) lies, and
        ##  'itime' is the amount of time spent in that last interval,
        ##  which is of course somewhat less than ilength.
        index <- unclass(cut(time, itable, include.lowest = TRUE))
        itime <- time - itable[index]
        if (ny == 3L) {
            ## There are both a start time and a stop time: compute the
            ##  amount of time NOT spent in the interval that the start
            ##  time lies in.
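            ## An illustrative example (hypothetical numbers, not from the
            ##  source): with itable = c(0, 2, 5, 10), an observation with
            ##  (start, stop) = (3, 7) gets index2 = 2 and itime2 = 3 - 2 = 1.
            ##  All of interval 1 (width 2) plus 1 unit of interval 2 precede
            ##  the start; removing those 3 units from the person-years
            ##  totals below leaves the 7 - 3 = 4 units actually at risk.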
            stime <- y[, 1L]  # start time for each observation
            index2 <- unclass(cut(stime, itable, include.lowest = TRUE))
            itime2 <- stime - itable[index2]
        }

        ## Compute the amount of person-years in each of the intervals.
        ##  This is: (width of interval) * (number of "time" elements that
        ##                                  end in an interval farther right)
        ##         + (ending times in this interval)
        ## By construction, I know that there is at least 1 obs in each of
        ##  the intervals, so tab1 is of a determined length.
        tab1 <- table(index)
        temp <- rev(cumsum(rev(tab1)))  # cumsum, counting from the right
        pyears <- ilength * c(temp[-1L], 0) + tapply(itime, index, sum)
        if (ny == 3L) {
            ## Subtract off the time before "start".  Unlike 'index',
            ##  'index2' need not touch every interval, so force the full
            ##  length of ngrp throughout.
            tab2 <- table(factor(index2, levels = 1:ngrp))
            temp <- rev(cumsum(rev(tab2)))  # again counting from the right
            itime2sum <- tapply(itime2, factor(index2, levels = 1:ngrp), sum)
            itime2sum[is.na(itime2sum)] <- 0  # intervals with no start times
            ## py2[k] = pre-entry time that falls in interval k: the full
            ##  width for each obs starting in a later interval, plus the
            ##  partial widths of the starts within interval k
            py2 <- ilength * c(temp[-1L], 0) + itime2sum
            pyears <- pyears - py2
        }
        deaths <- tapply(status, index, sum)
        rate <- deaths/pyears  # hazard rate in each interval
        rate
    }

    ## Now, compute the "new y" for each observation: a stretching of the
    ##  time axis.  The cumulative hazard over each interval is
    ##  rate * length(interval), and is the basis of the rescaling.
    rate <- drate2(n, ny, y, wt, itable)
    cumhaz <- cumsum(c(0, rate * diff(itable)))
    newy <- approx(itable, cumhaz, time)$y
    if (ny == 3L) newy <- newy - approx(itable, cumhaz, y[, 1L])$y
    if (length(offset) == n) newy <- newy * exp(offset)

    if (missing(parms)) parms <- c(shrink = 1L, method = 1L)
    else {
        parms <- as.list(parms)
        if (is.null(names(parms)))
            stop("You must input a named list for parms")
        parmsNames <- c("method", "shrink")
        indx <- pmatch(names(parms), parmsNames, 0L)
        if (any(indx == 0L))
            stop(gettextf("'parms' component not matched: %s",
                          names(parms)[indx == 0L]), domain = NA)
        else names(parms) <- parmsNames[indx]

        if (is.null(parms$method)) method <- 1L
        else method <- pmatch(parms$method, c("deviance", "sqrt"))
        if (is.na(method)) stop("Invalid error method for Poisson")

        if (is.null(parms$shrink)) shrink <- 2L - method
        else shrink <- parms$shrink
        if (!is.numeric(shrink) || shrink < 0L)
            stop("Invalid shrinkage value")
        parms <- c(shrink = shrink, method = method)
    }

    ## 'status' (y[, ny]) is the event indicator for both ny == 2 and
    ##  ny == 3, so use it rather than a fixed column number.
    list(y = cbind(newy, status), parms = parms, numresp = 2L, numy = 2L,
         summary = function(yval, dev, wt, ylevel, digits) {
             paste0(" events=", formatg(yval[, 2L]),
                    ", estimated rate=", formatg(yval[, 1L], digits),
                    ", mean deviance=", formatg(dev/wt, digits))
         },
         text = function(yval, dev, wt, ylevel, digits, n, use.n) {
             if (use.n) paste0(formatg(yval[, 1L], digits), "\n",
                               formatg(yval[, 2L]), "/", n)
             else formatg(yval[, 1L], digits)
         })
}
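
## ---------------------------------------------------------------------
## A minimal usage sketch (assumes the 'rpart' and 'survival' packages;
##  'stagec' is a data set shipped with rpart).  rpart() dispatches to
##  this initialization function when method = "exp", so the rescaling
##  above is what makes the printed per-node rates relative to a root
##  rate of 1.0:
##
##   library(rpart)
##   library(survival)
##   fit <- rpart(Surv(pgtime, pgstat) ~ age + grade,
##                data = stagec, method = "exp")
##   print(fit)  # node rates print as multiples of the root rate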