From 31993b70e3461f404e1c461238e190174553095c Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 18:11:07 -0700 Subject: [PATCH] Update Lrnr_base.R --- R/Lrnr_base.R | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 04e28f9e..9eb2df0b 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -177,21 +177,25 @@ Lrnr_base <- R6Class( )) } }, - base_predict = function(task = NULL) { + base_predict = function(task = NULL) { self$assert_trained() if (is.null(task)) { task <- private$.training_task } - + assert_that(is(task, "sl3_Task")) task <- self$subset_covariates(task) task <- self$process_formula(task) - + predictions <- private$.predict(task) - ncols <- ncol(predictions) - if (!is.null(ncols) && (ncols == 1)) { - predictions <- as.vector(unlist(predictions)) + if(inherits(predictions, "packed_predictions")) { + # if packed and data.table, as.vector(predictions) retains list structure. + if(is.data.table(predictions)) predictions <- as.vector(predictions) + return(predictions) + } else if(!is.null(ncols) && (ncols == 1)) { + # otherwise return vector. + predictions <- as.vector(unlist(predictions) } return(predictions) },