Skip to content

Commit

Permalink
Update Lrnr_base.R
Browse files Browse the repository at this point in the history
  • Loading branch information
Larsvanderlaan authored Aug 12, 2023
1 parent 74db1f3 commit 31993b7
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,25 @@ Lrnr_base <- R6Class(
))
}
},
base_predict = function(task = NULL) {
base_predict = function(task = NULL) {
self$assert_trained()
if (is.null(task)) {
task <- private$.training_task
}

assert_that(is(task, "sl3_Task"))
task <- self$subset_covariates(task)
task <- self$process_formula(task)

predictions <- private$.predict(task)

ncols <- ncol(predictions)
if (!is.null(ncols) && (ncols == 1)) {
predictions <- as.vector(unlist(predictions))
if(inherits(predictions, "packed_predictions")) {
# if packed and data.table, as.vector(predictions) retains list structure.
if(is.data.table(predictions)) predictions <- as.vector(predictions)
return(predictions)
} else if(!is.null(ncols) && (ncols == 1)) {
# otherwise return vector.
predictions <- as.vector(unlist(predictions)
}
return(predictions)
},
Expand Down

0 comments on commit 31993b7

Please sign in to comment.