Skip to content

Commit

Permalink
Merge pull request #378 from tlverse/fix-stringsAsFactors
Browse files Browse the repository at this point in the history
  • Loading branch information
nhejazi authored Jan 28, 2022
2 parents 2bc71d8 + 4d6a71e commit 6544257
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 37 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Maintainer: Jeremy Coyle <jeremyrcoyle@gmail.com>
Description: A modern implementation of the Super Learner prediction algorithm,
coupled with a general-purpose framework for composing arbitrary pipelines
for machine learning tasks.
Depends: R (>= 2.14.0)
Depends: R (>= 3.1.0)
Imports:
data.table,
assertthat,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,10 @@ importFrom(assertthat,is.count)
importFrom(assertthat,is.flag)
importFrom(caret,findLinearCombos)
importFrom(data.table,":=")
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,set)
importFrom(data.table,setDT)
importFrom(data.table,setcolorder)
importFrom(data.table,setnames)
importFrom(data.table,setorderv)
Expand Down
69 changes: 36 additions & 33 deletions R/sl3_Task.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
#' Define a Machine Learning Task
#'
#' An increasingly less thin wrapper around a \code{data.table} containing the
#' data. Contains metadata about the particular machine learning problem,
#' including which variables are to be used as covariates and outcomes.
#' An increasingly thick wrapper around a \code{\link[data.table]{data.table}}
#' containing the data for a prediction task. This contains metadata about the
#' particular machine learning problem, including which variables are to be
#' used as covariates and outcomes.
#'
#' @docType class
#'
#' @importFrom R6 R6Class
#' @importFrom assertthat assert_that is.count is.flag
#' @importFrom assertthat assert_that
#' @importFrom origami make_folds
#' @importFrom uuid UUIDgenerate
#' @importFrom digest digest
#' @import data.table
#' @importFrom data.table as.data.table data.table setcolorder setDT setnames ":="
#'
#' @export
#'
Expand All @@ -22,20 +23,18 @@
#' @format \code{\link{R6Class}} object.
#'
#' @template sl3_Task_extra
#
sl3_Task <- R6Class(
classname = "sl3_Task",
portable = TRUE,
class = TRUE,
public = list(
initialize = function(data, covariates, outcome = NULL,
outcome_type = NULL, outcome_levels = NULL,
id = NULL, weights = NULL, offset = NULL, time = NULL,
nodes = NULL, column_names = NULL, row_index = NULL,
folds = NULL, flag = TRUE,
id = NULL, weights = NULL, offset = NULL,
time = NULL, nodes = NULL, column_names = NULL,
row_index = NULL, folds = NULL, flag = TRUE,
drop_missing_outcome = FALSE) {


# generate node list from other arguments if not explicitly specified
if (is.null(nodes)) {
nodes <- list(
Expand Down Expand Up @@ -66,7 +65,7 @@ sl3_Task <- R6Class(
# verify nodes are contained in column map
missing_cols <- setdiff(all_nodes, names(column_names))

assert_that(
assertthat::assert_that(
length(missing_cols) == 0,
msg = sprintf(
"Couldn't find %s",
Expand All @@ -78,7 +77,7 @@ sl3_Task <- R6Class(
referenced_columns <- column_names[all_nodes]

missing_cols <- setdiff(referenced_columns, data_names)
assert_that(
assertthat::assert_that(
length(missing_cols) == 0,
msg = sprintf(
"Data doesn't contain referenced columns %s",
Expand Down Expand Up @@ -134,8 +133,7 @@ sl3_Task <- R6Class(
private$.folds <- folds

# assign uuid using digest
private$.uuid <- digest(self$data)

private$.uuid <- digest::digest(self$data)
invisible(self)
},
add_interactions = function(interactions, warn_on_existing = TRUE) {
Expand Down Expand Up @@ -168,27 +166,31 @@ sl3_Task <- R6Class(
# check if interaction terms numeric
int_numeric <- sapply(int, function(i) is.numeric(self$X[[i]]))
if (all(int_numeric)) {
d_int <- data.table(self$X[, prod.DT(.SD), .SD = int])
setnames(d_int, paste0(int, collapse = "_"))
d_int <- data.table::data.table(self$X[, prod.DT(.SD), .SD = int])
data.table::setnames(d_int, paste0(int, collapse = "_"))
return(d_int)
} else {
# match interaction terms to X
Xmatch <- lapply(int, function(i) grep(i, colnames(self$X), value = T))
Xint <- as.list(as.data.frame(t(expand.grid(Xmatch))))
Xmatch <- lapply(int, function(i) {
grep(i, colnames(self$X), value = TRUE)
})
Xint <- as.list(data.table::as.data.table(t(expand.grid(Xmatch))))

d_Xint <- lapply(Xint, function(Xint) self$X[, prod.DT(.SD), .SD = Xint])
setDT(d_Xint)
setnames(d_Xint, sapply(Xint, paste0, collapse = "_"))
d_Xint <- lapply(Xint, function(Xint) {
self$X[, prod.DT(.SD), .SD = Xint]
})
data.table::setDT(d_Xint)
data.table::setnames(d_Xint, sapply(Xint, paste0, collapse = "_"))

no_Xint <- rowSums(d_Xint) == 0 # happens when we omit 1 factor level
if (any(int_numeric)) {
d_Xint$other <- rep(0, nrow(d_Xint))
d_Xint[no_Xint, "other"] <- 1
if (any(int_numeric)) {
# we actually need to take the product if we have a numeric covariate
d_Xint[no_Xint, "other"] <- prod.DT(data.table(
# need to take the product if we have a numeric covariate
d_Xint[no_Xint, "other"] <- prod.DT(data.table::data.table(
rep(1, sum(no_Xint)),
self$X[no_Xint, names(which(int_numeric)), with = F]
self$X[no_Xint, names(which(int_numeric)), with = FALSE]
))
}
other_name <- paste0("other.", paste0(int, collapse = "_"))
Expand All @@ -199,10 +201,13 @@ sl3_Task <- R6Class(
})

interaction_names <- unlist(lapply(interaction_data, colnames))
interaction_data <- data.table(do.call(cbind, interaction_data))
setnames(interaction_data, interaction_names)
interaction_data <- data.table::data.table(
do.call(cbind, interaction_data)
)
data.table::setnames(interaction_data, interaction_names)

interaction_cols <- self$add_columns(interaction_data, column_uuid = NULL)
interaction_cols <- self$add_columns(interaction_data,
column_uuid = NULL)
new_covariates <- c(self$nodes$covariates, interaction_names)
return(self$next_in_chain(
covariates = new_covariates,
Expand Down Expand Up @@ -273,7 +278,7 @@ sl3_Task <- R6Class(
# verify nodes are contained in dataset
missing_cols <- setdiff(all_nodes, names(column_names))

assert_that(
assertthat::assert_that(
length(missing_cols) == 0,
msg = sprintf(
"Couldn't find %s",
Expand Down Expand Up @@ -326,7 +331,7 @@ sl3_Task <- R6Class(
new_folds <- NULL
} else {
if (must_reindex) {
stop("subset indicies have copies, this requires dropping folds.")
stop("subset indices have copies, this requires dropping folds.")
}
new_folds <- subset_folds(private$.folds, row_index)
}
Expand Down Expand Up @@ -399,7 +404,7 @@ sl3_Task <- R6Class(
return(offset)
},
print = function() {
cat(sprintf("A sl3 Task with %d obs and these nodes:\n", self$nrow))
cat(sprintf("An sl3 Task with %d obs and these nodes:\n", self$nrow))
print(self$nodes)
},
revere_fold_task = function(fold_number) {
Expand Down Expand Up @@ -441,9 +446,8 @@ sl3_Task <- R6Class(
X_dt[, intercept := 1]

# make intercept first column
setcolorder(X_dt, c(old_ncol + 1, seq_len(old_ncol)))
data.table::setcolorder(X_dt, c(old_ncol + 1, seq_len(old_ncol)))
}

return(X_dt)
},
Y = function() {
Expand Down Expand Up @@ -539,5 +543,4 @@ sl3_Task <- R6Class(
#' @rdname sl3_Task
#'
#' @export
#
make_sl3_Task <- sl3_Task$new
7 changes: 4 additions & 3 deletions man/sl3_Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 6544257

Please sign in to comment.