Skip to content

Commit

Permalink
fix subset covariates to support out of order covariates. Covariates …
Browse files Browse the repository at this point in the history
…passed to learner should match the order of the covariate params, not the order of the covariates in the task
  • Loading branch information
jeremyrcoyle committed Jan 22, 2024
1 parent 507e0a1 commit d0a4a1d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 24 deletions.
32 changes: 10 additions & 22 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ Lrnr_base <- R6Class(
if (length(delta_idx) > 0) {
delta_missing <- task_covs_missing[delta_idx]
task_covs_missing <- task_covs_missing[-delta_idx]

delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

} else{
cols <- task$column_names
}

# error when task is missing covariates
Expand All @@ -68,29 +75,10 @@ Lrnr_base <- R6Class(
)
}

# subset task covariates to only includes those in learner covariates
covs_subset <- intersect(task_covs, learner_covs)

# return updated task
if (length(delta_idx) == 0) {
# re-order the covariate subset to match order of learner covariates
ordered_covs_subset <- covs_subset[match(covs_subset, learner_covs)]
return(task$next_in_chain(covariates = ordered_covs_subset))
} else {
# incorporate missingness indicators in task covariates subset & sort
covs_subset_delta <- c(covs_subset, delta_missing)
ord_covs <- covs_subset_delta[match(covs_subset_delta, learner_covs)]

# incorporate missingness indicators in task data
delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

return(task$next_in_chain(
covariates = ord_covs,
return(task$next_in_chain(
covariates = learner_covs,
column_names = cols
))
}
))
} else {
return(task)
}
Expand Down
11 changes: 9 additions & 2 deletions tests/testthat/test-subset_covariates.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ full_preds <- glm_fit_pre_subset$predict(task)
training_preds <- glm_fit_pre_subset$predict()
test_that("extra covariates in prediction set get dropped correctly", expect_equal(full_preds, training_preds))


shuffled_subset <- sample(covariate_subset)
task_pre_subset_shuffled <- sl3_Task$new(mtcars, covariates = shuffled_subset, outcome = outcome)
# debugonce(glm_fit_pre_subset$subset_covariates)
shuffled_preds <- glm_fit_pre_subset$predict(task_pre_subset_shuffled)
test_that("covariates out of order prediction set get shuffled correctly", expect_equal(full_preds, shuffled_preds))


task_train <- sl3_Task$new(mtcars, covariates = covariates, outcome = outcome)
task_predict <- sl3_Task$new(mtcars, covariates = covariate_subset, outcome = outcome)
glm_fit <- lrnr_glm$train(task_train)
Expand All @@ -47,11 +55,10 @@ task_missing_data <- suppressWarnings(
sl3_Task$new(missing_data, covariates = covs, outcome = Y)
)

lrnr_glm <- make_learner(Lrnr_glm_fast, name = "test")
lrnr_glm <- make_learner(Lrnr_glm_fast)
glm_fit <- lrnr_glm$train(task_missing_data)

task_complete_data <- sl3_Task$new(mtcars, covariates = covs, outcome = Y)

test_that("missingness indicators in prediction task works", {
expect_vector(glm_fit$predict(task_complete_data))
})

0 comments on commit d0a4a1d

Please sign in to comment.