Skip to content

Commit

Permalink
Evaluate model with custom embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Aug 30, 2024
1 parent d446a1b commit 2856604
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 29 deletions.
18 changes: 17 additions & 1 deletion R/Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,26 @@ createDataset <- function(data, labels, plpModel = NULL) {
r_to_py(labels$outcomeCount),
numericalIndex)
} else {
cat_1_mapping <- plpModel$covariateImportance %>%
dplyr::select(covariateId, cat1Idx) %>%
dplyr::rename(index = cat1Idx) %>%
dplyr::filter(!is.na(index)) %>%
as.data.frame() %>%
r_to_py()

cat_2_mapping <- plpModel$covariateImportance %>%
dplyr::select(covariateId, cat2Idx) %>%
dplyr::rename(index = cat2Idx) %>%
dplyr::filter(!is.na(index)) %>%
as.data.frame() %>%
r_to_py()

numericalFeatures <-
r_to_py(as.array(which(plpModel$covariateImportance$isNumeric)))
data <- dataset(r_to_py(normalizePath(attributes(data)$path)),
numerical_features = numericalFeatures
numerical_features = numericalFeatures,
in_cat_2_mapping = cat_2_mapping,
in_cat_1_mapping = cat_1_mapping
)
}

Expand Down
13 changes: 8 additions & 5 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -303,20 +303,22 @@ predictDeepEstimator <- function(plpModel,

if (!is.null(plpModel$covariateImportance)) {
# this means that the model finished training since only in the end covariateImportance is added
browser()

# data <- createDataset(mappedData, plpModel = plpModel)
mappedData <- PatientLevelPrediction::MapIds(data$covariateData,
cohort = cohort,
mapping = plpModel$covariateImportance %>%
dplyr::select("columnId", "covariateId")
)
data <- createDataset(mappedData, plpModel = plpModel)

} else if ("plpData" %in% class(data)) {
mappedData <- PatientLevelPrediction::MapIds(data$covariateData,
cohort = cohort,
mapping = plpModel$covariateImportance %>%
dplyr::select("columnId", "covariateId")
# check this if it is correclty passing the mapped data rather than creating a new mapping
)
data <- createDataset(mappedData, plpModel = plpModel)
}

# get predictions
prediction <- cohort
if (is.character(plpModel$model)) {
Expand All @@ -336,6 +338,7 @@ predictDeepEstimator <- function(plpModel,
snakeCaseToCamelCaseNames(model$estimator_settings))
estimator$model$load_state_dict(model$model_state_dict)
prediction$value <- estimator$predict_proba(data)
browser()
} else {
prediction$value <- plpModel$model$predict_proba(data)
}
Expand Down
57 changes: 35 additions & 22 deletions inst/python/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class Data(Dataset):
def __init__(self, data, labels=None, numerical_features=None,
cat2_feature_names=None):
in_cat_1_mapping=None, in_cat_2_mapping=None):
desktop_path = Path.home() / "Desktop"

desktop_path = Path.home() / "Desktop"
Expand Down Expand Up @@ -75,9 +75,7 @@ def __init__(self, data, labels=None, numerical_features=None,
else:
self.target = torch.zeros(size=(observations,))

if cat2_feature_names is None:
cat2_feature_names = []

cat2_feature_names = []
cat2_feature_names += embed_names

# filter by categorical columns,
Expand All @@ -101,11 +99,19 @@ def __init__(self, data, labels=None, numerical_features=None,
# Now, use 'cat2_ref' as a normal DataFrame and access "columnId"
data_cat_1 = data_cat.filter(
~pl.col("covariateId").is_in(cat2_ref["covariateId"]))
self.cat_1_mapping = pl.DataFrame({
"covariateId": data_cat_1["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_1["covariateId"].unique()) + 1))
})
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping.json"))

self.cat_1_mapping = None
if in_cat_1_mapping is None:
self.cat_1_mapping = pl.DataFrame({
"covariateId": data_cat_1["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_1["covariateId"].unique()) + 1))
})
# self.cat_1_mapping = pl.DataFrame(self.cat_1_mapping)
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping_train.json"))
else:
self.cat_1_mapping = pl.DataFrame(in_cat_1_mapping).with_columns(pl.col('index').cast(pl.Int64), pl.col('covariateId').cast(pl.Float64))
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping_test.json"))


data_cat_1 = data_cat_1.join(self.cat_1_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId"))
Expand All @@ -128,19 +134,26 @@ def __init__(self, data, labels=None, numerical_features=None,
# process cat_2 features
data_cat_2 = data_cat.filter(
pl.col("covariateId").is_in(cat2_ref))
self.cat_2_mapping = pl.DataFrame({
"covariateId": data_cat_2["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
})
self.cat_2_mapping = self.cat_2_mapping.lazy()
self.cat_2_mapping = (
self.data_ref
.filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
.select(pl.col("conceptId"), pl.col("covariateId"))
.join(self.cat_2_mapping, on="covariateId", how="left")
.collect()
)
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))

self.cat_2_mapping = None
if in_cat_2_mapping is None:
self.cat_2_mapping = pl.DataFrame({
"covariateId": data_cat_2["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
})
self.cat_2_mapping = self.cat_2_mapping.lazy()
self.cat_2_mapping = (
self.data_ref
.filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
.select(pl.col("conceptId"), pl.col("covariateId"))
.join(self.cat_2_mapping, on="covariateId", how="left")
.collect()
)
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping_train.json"))
else:
self.cat_2_mapping = pl.DataFrame(in_cat_2_mapping).with_columns(pl.col('index').cast(pl.Int64), pl.col('covariateId').cast(pl.Float64))
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping_test.json"))

# cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))

data_cat_2 = data_cat_2.join(self.cat_2_mapping, on="covariateId", how="left") \
Expand Down
2 changes: 1 addition & 1 deletion inst/python/InitStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def initialize(self, model, model_parameters, estimator_settings):

# # replace weights
# cat2_concept_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_concept_mapping.json"))
cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping.json"))
cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping_train.json"))
# print(f"cat2_mapping: {cat2_mapping}")

concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId"))
Expand Down

0 comments on commit 2856604

Please sign in to comment.