Skip to content

Commit

Permalink
Merge pull request #61 from tlverse/devel
Browse files Browse the repository at this point in the history
  • Loading branch information
nhejazi authored Jun 24, 2020
2 parents c670f3d + d05ca80 commit 24fc839
Show file tree
Hide file tree
Showing 62 changed files with 1,713 additions and 1,235 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
^Makefile$
^LICENSE$
^sandbox$
^paper$
^docs$
^_pkgdown\.yml$
^CRAN-RELEASE$
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
branches:
only:
- master
- devel

env:
global:
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: hal9001
Title: The Scalable Highly Adaptive Lasso
Version: 0.2.5
Version: 0.2.6
Authors@R: c(
person("Jeremy", "Coyle", email = "jeremyrcoyle@gmail.com",
role = c("aut", "cre"),
Expand Down Expand Up @@ -43,7 +43,7 @@ Imports:
utils,
methods,
assertthat,
origami (>= 0.8.1),
origami (>= 1.0.3),
glmnet
Suggests:
testthat,
Expand All @@ -62,4 +62,4 @@ LinkingTo:
Rcpp,
RcppEigen
VignetteBuilder: knitr
RoxygenNote: 7.0.2
RoxygenNote: 7.1.0
61 changes: 24 additions & 37 deletions R/hal.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@
#' @param id a vector of ID values, used to generate cross-validation folds for
#' cross-validated selection of the regularization parameter lambda.
#' @param offset a vector of offset values, used in fitting.
#' @param screen_basis If \code{TRUE}, use a screening procedure to reduce the
#' number of basis functions fitted.
#' @param screen_lambda If \code{TRUE}, use a screening procedure to reduce the
#' number of lambda values evaluated.
#' @param ... Other arguments passed to \code{\link[glmnet]{cv.glmnet}}. Please
#' consult its documentation for a full list of options.
#' @param yolo A \code{logical} indicating whether to print one of a curated
Expand Down Expand Up @@ -112,23 +108,37 @@ fit_hal <- function(X,
id = NULL,
offset = NULL,
cv_select = TRUE,
screen_basis = FALSE,
screen_lambda = FALSE,
...,
yolo = TRUE) {

# check arguments and catch function call
call <- match.call(expand.dots = TRUE)
fit_type <- match.arg(fit_type)
family <- match.arg(family)

# catch dot arguments to stop misuse of glmnet's `lambda.min.ratio`
dot_args <- list(...)
assertthat::assert_that(!("lambda.min.ratio" %in% names(dot_args) &
family == "binomial"),
msg = paste(
"`glmnet` silently ignores",
"`lambda.min.ratio` when",
"`family = 'binomial'`."
)
)

# NOTE: NOT supporting binomial outcomes with lassi method currently
if (fit_type == "lassi" && family == "binomial") {
stop("For binary outcomes, please set argument 'fit_type' to 'glmnet'.")
}
if (fit_type == "lassi" && family == "cox") {
stop("For Cox models, please set argument 'fit_type' to 'glmnet'.")
}
assertthat::assert_that(!(fit_type == "lassi" && family == "binomial"),
msg = paste(
"For binary outcomes, please set",
"argument 'fit_type' to 'glmnet'."
)
)
assertthat::assert_that(!(fit_type == "lassi" && family == "cox"),
msg = paste(
"For Cox models, please set argument",
"'fit_type' to 'glmnet'."
)
)

# cast X to matrix -- and don't start the timer until after
if (!is.matrix(X)) {
Expand All @@ -154,21 +164,7 @@ fit_hal <- function(X,

# make design matrix for HAL
if (is.null(basis_list)) {
if (screen_basis) {
# NOTE: foldid is never missing since created above if not supplied
good_basis <- hal_screen_basis(
x = X,
y = Y,
family = family,
offset = offset,
foldid = foldid,
max_degree = max_degree
)
basis_lists <- lapply(good_basis, basis_list_cols, X)
basis_list <- unlist(basis_lists, recursive = FALSE)
} else {
basis_list <- enumerate_basis(X, max_degree)
}
basis_list <- enumerate_basis(X, max_degree)
}

# generate a vector of col lists corresponding to the bases generated
Expand Down Expand Up @@ -240,15 +236,6 @@ fit_hal <- function(X,
coefs <- hal_lasso$betas_mat[, "lambda_1se"]
}
} else if (fit_type == "glmnet") {
if ((screen_lambda) && (length(lambda) != 1)) {
# reduce the set of lambdas to fit
lambda <- hal_screen_lambda(x_basis, Y,
family = family,
lambda = lambda,
foldid = foldid,
offset = offset
)
}
# just use the standard implementation available in glmnet
if (!cv_select) {
hal_lasso <- glmnet::glmnet(
Expand Down
22 changes: 13 additions & 9 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
#' \code{hal9001}.
#'
#' @param object An object of class \code{hal9001}, containing the results of
#' fitting the Highly Adaptive Lasso, as produced by a call to \code{fit_hal}.
#' fitting the Highly Adaptive Lasso, as produced by \code{\link{fit_hal}}.
#' @param offset A vector of offsets. Must be provided if provided at training.
#' @param lambda A single lambda value or a vector of lambdas to use for
#' prediction. If \code{NULL}, a value of lambda will be selected based on
#' cross-validation, using \code{\link[glmnet]{cv.glmnet}}.
#' @param ... Additional arguments passed to \code{predict} as necessary.
#' @param new_data A \code{matrix} or \code{data.frame} containing new data
#' (observations NOT used in fitting the \code{hal9001} object passed in via
Expand All @@ -25,11 +22,18 @@
#'
#' @export
#'
#' @return A \code{numeric} vector of predictions from a fitted \code{hal9001}
#' object.
#' @note This prediction method does not function similarly to the equivalent
#' method from \pkg{glmnet}. In particular, this procedure will NOT return a
#' subset of lambdas originally specified in calling \code{\link{fit_hal}}
#' nor result in re-fitting. Instead, it will return predictions for all of
#' the lambdas specified in the call to \code{\link{fit_hal}} that constructs
#' \code{object}, when \code{cv_select = FALSE}. When \code{cv_select = TRUE},
#' predictions will only be returned for the value of lambda selected by
#' cross-validation.
#'
#' @return A \code{numeric} vector of predictions from a \code{hal9001} object.
predict.hal9001 <- function(object,
offset = NULL,
lambda = NULL,
...,
new_data,
new_X_unpenalized = NULL) {
Expand Down Expand Up @@ -89,8 +93,8 @@ predict.hal9001 <- function(object,
) + object$coefs[1])
}
} else {
# Note: there is no intercept in the Cox mode (its built into the baseline
# hazard, and like it, would cancel in the partial likelihood.)
# Note: there is no intercept in the Cox model (built into the baseline
# hazard and would cancel in the partial likelihood).
# message(paste("The Cox Model is not commonly used for prediction,",
# "proceed with caution."))
if (ncol(object$coefs) > 1) {
Expand Down
Loading

0 comments on commit 24fc839

Please sign in to comment.