Skip to content

Commit

Permalink
Pass embedding dimensions to model
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Aug 27, 2024
1 parent af5a3df commit 6adfde6
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ fitEstimator <- function(trainData,
included = incs,
covariateValue = 0,
isNumeric = .data$columnId %in% cvResult$numericalIndex
# get mapping maybe here
)

comp <- start - Sys.time()
Expand Down Expand Up @@ -268,6 +269,7 @@ fitEstimator <- function(trainData,
hyperParamSearch = hyperSummary
),
covariateImportance = covariateRef
# also return mapping as part of covariateRef above, not necessary to do separately
)

class(result) <- "plpModel"
Expand Down Expand Up @@ -301,6 +303,7 @@ predictDeepEstimator <- function(plpModel,
cohort = cohort,
mapping = plpModel$covariateImportance %>%
dplyr::select("columnId", "covariateId")
# check whether this is correctly passing the mapped data rather than creating a new mapping
)
data <- createDataset(mappedData, plpModel = plpModel)
}
Expand Down Expand Up @@ -421,6 +424,7 @@ gridCvDeep <- function(mappedData,
prediction$cohortStartDate <- as.Date(prediction$cohortStartDate,
origin = "1970-01-01")
numericalIndex <- dataset$get_numerical_features()
# get mapping as above

# save torch code here
if (!dir.exists(file.path(modelLocation))) {
Expand All @@ -434,6 +438,7 @@ gridCvDeep <- function(mappedData,
finalParam = finalParam,
paramGridSearch = paramGridSearch,
numericalIndex = numericalIndex$to_list()
# add mapping here, two columns [covariateId, columnId]
)
)
}
Expand Down Expand Up @@ -577,8 +582,9 @@ doCrossValidationImpl <- function(dataset,
fillEstimatorSettings(modelSettings$estimatorSettings,
fitParams,
parameters)
currentModelParams$catFeatures <- dataset$get_cat_features()$max()
currentModelParams$catFeatures <- dataset$get_cat_features()$len()
currentModelParams$numFeatures <- dataset$get_numerical_features()$len()
currentModelParams$cat2Features <- dataset$get_cat_2_features()$len()
if (currentEstimatorSettings$findLR) {
lr <- getLR(currentModelParams, currentEstimatorSettings, dataset)
ParallelLogger::logInfo(paste0("Auto learning rate selected as: ", lr))
Expand Down Expand Up @@ -659,7 +665,8 @@ trainFinalModel <- function(dataset, finalParam, modelSettings, labels) {

fitParams <- names(finalParam)[grepl("^estimator", names(finalParam))]

modelParams$catFeatures <- dataset$get_cat_features()$max()
modelParams$catFeatures <- dataset$get_cat_features()$len()
modelParams$cat2Features <- dataset$get_cat_2_features()$len()
modelParams$numFeatures <- dataset$get_numerical_features()$len()
modelParams$modelType <- modelSettings$modelType

Expand Down

0 comments on commit 6adfde6

Please sign in to comment.