From 74e98df152e09e2776ea493faedde06af52cd9f1 Mon Sep 17 00:00:00 2001 From: Nima Hejazi Date: Thu, 27 Jan 2022 17:56:11 -0500 Subject: [PATCH 1/2] update sl3_Task imports --- NAMESPACE | 1 + R/sl3_Task.R | 61 +++++++++++++++++++++++++++------------------------- 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 1e97c684..e30b7646 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -144,6 +144,7 @@ importFrom(caret,findLinearCombos) importFrom(data.table,":=") importFrom(data.table,data.table) importFrom(data.table,set) +importFrom(data.table,setDT) importFrom(data.table,setcolorder) importFrom(data.table,setnames) importFrom(data.table,setorderv) diff --git a/R/sl3_Task.R b/R/sl3_Task.R index ea94ae18..9e9cb0ac 100644 --- a/R/sl3_Task.R +++ b/R/sl3_Task.R @@ -7,11 +7,11 @@ #' @docType class #' #' @importFrom R6 R6Class -#' @importFrom assertthat assert_that is.count is.flag +#' @importFrom assertthat assert_that #' @importFrom origami make_folds #' @importFrom uuid UUIDgenerate #' @importFrom digest digest -#' @import data.table +#' @importFrom data.table data.table setcolorder setDT setnames ":=" #' #' @export #' @@ -22,7 +22,6 @@ #' @format \code{\link{R6Class}} object. #' #' @template sl3_Task_extra -# sl3_Task <- R6Class( classname = "sl3_Task", portable = TRUE, @@ -30,12 +29,11 @@ sl3_Task <- R6Class( public = list( initialize = function(data, covariates, outcome = NULL, outcome_type = NULL, outcome_levels = NULL, - id = NULL, weights = NULL, offset = NULL, time = NULL, - nodes = NULL, column_names = NULL, row_index = NULL, - folds = NULL, flag = TRUE, + id = NULL, weights = NULL, offset = NULL, + time = NULL, nodes = NULL, column_names = NULL, + row_index = NULL, folds = NULL, flag = TRUE, drop_missing_outcome = FALSE) { - # generate node list from other arguments if not explicitly specified if (is.null(nodes)) { nodes <- list( @@ -66,7 +64,7 @@ sl3_Task <- R6Class( # verify nodes are contained in column map missing_cols <- setdiff(all_nodes, names(column_names)) - assert_that( + assertthat::assert_that( length(missing_cols) == 0, msg = sprintf( "Couldn't find %s", @@ -78,7 +76,7 @@ sl3_Task <- R6Class( referenced_columns <- column_names[all_nodes] missing_cols <- setdiff(referenced_columns, data_names) - assert_that( + assertthat::assert_that( length(missing_cols) == 0, msg = sprintf( "Data doesn't contain referenced columns %s", @@ -134,8 +132,7 @@ sl3_Task <- R6Class( private$.folds <- folds # assign uuid using digest - private$.uuid <- digest(self$data) - + private$.uuid <- digest::digest(self$data) invisible(self) }, add_interactions = function(interactions, warn_on_existing = TRUE) { @@ -168,27 +165,32 @@ sl3_Task <- R6Class( # check if interaction terms numeric int_numeric <- sapply(int, function(i) is.numeric(self$X[[i]])) if (all(int_numeric)) { - d_int <- data.table(self$X[, prod.DT(.SD), .SD = int]) - setnames(d_int, paste0(int, collapse = "_")) + d_int <- data.table::data.table(self$X[, prod.DT(.SD), .SD = int]) + data.table::setnames(d_int, paste0(int, collapse = "_")) return(d_int) } else { # match interaction terms to X - Xmatch <- lapply(int, function(i) grep(i, colnames(self$X), value = T)) + Xmatch <- lapply(int, function(i) { + grep(i, colnames(self$X), value = TRUE) + }) + #browser() Xint <- as.list(as.data.frame(t(expand.grid(Xmatch)))) - d_Xint <- lapply(Xint, function(Xint) self$X[, prod.DT(.SD), .SD = Xint]) - setDT(d_Xint) - setnames(d_Xint, sapply(Xint, paste0, collapse = "_")) + d_Xint <- lapply(Xint, function(Xint) { + self$X[, prod.DT(.SD), .SD = Xint] + }) + data.table::setDT(d_Xint) + data.table::setnames(d_Xint, sapply(Xint, paste0, collapse = "_")) no_Xint <- rowSums(d_Xint) == 0 # happens when we omit 1 factor level if (any(int_numeric)) { d_Xint$other <- rep(0, nrow(d_Xint)) d_Xint[no_Xint, "other"] <- 1 if (any(int_numeric)) { - # we actually need to take the product if we have a numeric covariate - d_Xint[no_Xint, "other"] <- prod.DT(data.table( + # need to take the product if we have a numeric covariate + d_Xint[no_Xint, "other"] <- prod.DT(data.table::data.table( rep(1, sum(no_Xint)), - self$X[no_Xint, names(which(int_numeric)), with = F] + self$X[no_Xint, names(which(int_numeric)), with = FALSE] )) } other_name <- paste0("other.", paste0(int, collapse = "_")) @@ -199,10 +201,13 @@ sl3_Task <- R6Class( }) interaction_names <- unlist(lapply(interaction_data, colnames)) - interaction_data <- data.table(do.call(cbind, interaction_data)) - setnames(interaction_data, interaction_names) + interaction_data <- data.table::data.table( + do.call(cbind, interaction_data) + ) + data.table::setnames(interaction_data, interaction_names) - interaction_cols <- self$add_columns(interaction_data, column_uuid = NULL) + interaction_cols <- self$add_columns(interaction_data, + column_uuid = NULL) new_covariates <- c(self$nodes$covariates, interaction_names) return(self$next_in_chain( covariates = new_covariates, @@ -273,7 +278,7 @@ sl3_Task <- R6Class( # verify nodes are contained in dataset missing_cols <- setdiff(all_nodes, names(column_names)) - assert_that( + assertthat::assert_that( length(missing_cols) == 0, msg = sprintf( "Couldn't find %s", @@ -326,7 +331,7 @@ sl3_Task <- R6Class( new_folds <- NULL } else { if (must_reindex) { - stop("subset indicies have copies, this requires dropping folds.") + stop("subset indices have copies, this requires dropping folds.") } new_folds <- subset_folds(private$.folds, row_index) } @@ -399,7 +404,7 @@ sl3_Task <- R6Class( return(offset) }, print = function() { - cat(sprintf("A sl3 Task with %d obs and these nodes:\n", self$nrow)) + cat(sprintf("An sl3 Task with %d obs and these nodes:\n", self$nrow)) print(self$nodes) }, revere_fold_task = function(fold_number) { @@ -441,9 +446,8 @@ sl3_Task <- R6Class( X_dt[, intercept := 1] # make intercept first column - setcolorder(X_dt, c(old_ncol + 1, seq_len(old_ncol))) + data.table::setcolorder(X_dt, c(old_ncol + 1, seq_len(old_ncol))) } - return(X_dt) }, Y = function() { @@ -539,5 +543,4 @@ sl3_Task <- R6Class( #' @rdname sl3_Task #' #' @export -# make_sl3_Task <- sl3_Task$new From 4d6a71e3bb8a824ff355ba40b18198bab7646c2a Mon Sep 17 00:00:00 2001 From: Nima Hejazi Date: Thu, 27 Jan 2022 18:11:28 -0500 Subject: [PATCH 2/2] bump R requirement and switch as.data.frame to as.data.table per #375 --- DESCRIPTION | 2 +- NAMESPACE | 1 + R/sl3_Task.R | 12 ++++++------ man/sl3_Task.Rd | 7 ++++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 49cdcca6..7a675fa1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,7 +29,7 @@ Maintainer: Jeremy Coyle Description: A modern implementation of the Super Learner prediction algorithm, coupled with a general-purpose framework for composing arbitrary pipelines for machine learning tasks. -Depends: R (>= 2.14.0) +Depends: R (>= 3.1.0) Imports: data.table, assertthat, diff --git a/NAMESPACE b/NAMESPACE index e30b7646..e7d567d1 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -142,6 +142,7 @@ importFrom(assertthat,is.count) importFrom(assertthat,is.flag) importFrom(caret,findLinearCombos) importFrom(data.table,":=") +importFrom(data.table,as.data.table) importFrom(data.table,data.table) importFrom(data.table,set) importFrom(data.table,setDT) diff --git a/R/sl3_Task.R b/R/sl3_Task.R index 9e9cb0ac..79b078a0 100644 --- a/R/sl3_Task.R +++ b/R/sl3_Task.R @@ -1,8 +1,9 @@ #' Define a Machine Learning Task #' -#' An increasingly less thin wrapper around a \code{data.table} containing the -#' data. Contains metadata about the particular machine learning problem, -#' including which variables are to be used as covariates and outcomes. +#' An increasingly thick wrapper around a \code{\link[data.table]{data.table}} +#' containing the data for a prediction task. This contains metadata about the +#' particular machine learning problem, including which variables are to be +#' used as covariates and outcomes. #' #' @docType class #' @@ -11,7 +12,7 @@ #' @importFrom origami make_folds #' @importFrom uuid UUIDgenerate #' @importFrom digest digest -#' @importFrom data.table data.table setcolorder setDT setnames ":=" +#' @importFrom data.table as.data.table data.table setcolorder setDT setnames ":=" #' #' @export #' @@ -173,8 +174,7 @@ sl3_Task <- R6Class( Xmatch <- lapply(int, function(i) { grep(i, colnames(self$X), value = TRUE) }) - #browser() - Xint <- as.list(as.data.frame(t(expand.grid(Xmatch)))) + Xint <- as.list(data.table::as.data.table(t(expand.grid(Xmatch)))) d_Xint <- lapply(Xint, function(Xint) { self$X[, prod.DT(.SD), .SD = Xint] diff --git a/man/sl3_Task.Rd b/man/sl3_Task.Rd index c5513f4a..8ed73776 100644 --- a/man/sl3_Task.Rd +++ b/man/sl3_Task.Rd @@ -19,9 +19,10 @@ Constructor below.} \code{sl3_Task} object } \description{ -An increasingly less thin wrapper around a \code{data.table} containing the -data. Contains metadata about the particular machine learning problem, -including which variables are to be used as covariates and outcomes. +An increasingly thick wrapper around a \code{\link[data.table]{data.table}} +containing the data for a prediction task. This contains metadata about the +particular machine learning problem, including which variables are to be +used as covariates and outcomes. } \section{Constructor}{