Skip to content

Commit

Permalink
Return mappings during training
Browse files Browse the repository at this point in the history
  • Loading branch information
lhjohn committed Aug 29, 2024
1 parent 6adfde6 commit d446a1b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 39 deletions.
21 changes: 16 additions & 5 deletions R/Estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ fitEstimator <- function(trainData,
covariateValue = 0,
isNumeric = .data$columnId %in% cvResult$numericalIndex
# get mapping maybe here
)
) %>%
left_join(cvResult$cat1Mapping %>% rename(cat1Idx = index), by = "covariateId") %>%
left_join(cvResult$cat2Mapping %>% rename(cat2Idx = index), by = "covariateId")

comp <- start - Sys.time()
result <- list(
Expand Down Expand Up @@ -298,7 +300,14 @@ predictDeepEstimator <- function(plpModel,
plpModel <- list(model = plpModel)
attr(plpModel, "modelType") <- "binary"
}
if ("plpData" %in% class(data)) {

if (!is.null(plpModel$covariateImportance)) {
# this means that the model finished training since only in the end covariateImportance is added
browser()

# data <- createDataset(mappedData, plpModel = plpModel)

} else if ("plpData" %in% class(data)) {
mappedData <- PatientLevelPrediction::MapIds(data$covariateData,
cohort = cohort,
mapping = plpModel$covariateImportance %>%
Expand Down Expand Up @@ -424,7 +433,8 @@ gridCvDeep <- function(mappedData,
prediction$cohortStartDate <- as.Date(prediction$cohortStartDate,
origin = "1970-01-01")
numericalIndex <- dataset$get_numerical_features()
# get mapping as above
cat1Mapping <- as.data.frame(dataset$get_cat_1_mapping())
cat2Mapping <- as.data.frame(dataset$get_cat_2_mapping())

# save torch code here
if (!dir.exists(file.path(modelLocation))) {
Expand All @@ -437,8 +447,9 @@ gridCvDeep <- function(mappedData,
prediction = prediction,
finalParam = finalParam,
paramGridSearch = paramGridSearch,
numericalIndex = numericalIndex$to_list()
# add mapping here, two columns [covariateId, columnId]
numericalIndex = numericalIndex$to_list(),
cat1Mapping = cat1Mapping,
cat2Mapping = cat2Mapping
)
)
}
Expand Down
10 changes: 5 additions & 5 deletions R/Transformer.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ setCustomEmbeddingTransformer <- function(
setEstimator(
learningRate = "auto",
weightDecay = 1e-4,
batchSize = 512,
epochs = 10,
batchSize = 256,
epochs = 2,
seed = NULL,
device = "cpu"
)
Expand All @@ -77,13 +77,13 @@ setCustomEmbeddingTransformer <- function(
estimatorSettings$embeddingFilePath <- embeddingFilePath
transformerSettings <- setTransformer(
numBlocks = 3,
dimToken = 192,
dimToken = 16,
dimOut = 1,
numHeads = 8,
numHeads = 4,
attDropout = 0.2,
ffnDropout = 0.1,
resDropout = 0.0,
dimHidden = 256,
dimHidden = 32,
estimatorSettings = estimatorSettings,
hyperParamSearch = "random",
randomSample = 1
Expand Down
26 changes: 15 additions & 11 deletions inst/python/Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ def __init__(self, data, labels=None, numerical_features=None,
if cat2_feature_names is None:
cat2_feature_names = []

self.feature_mapping = {}

cat2_feature_names += embed_names

# filter by categorical columns,
Expand All @@ -103,13 +101,13 @@ def __init__(self, data, labels=None, numerical_features=None,
# Now, use 'cat2_ref' as a normal DataFrame and access "columnId"
data_cat_1 = data_cat.filter(
~pl.col("covariateId").is_in(cat2_ref["covariateId"]))
cat_1_mapping = pl.DataFrame({
self.cat_1_mapping = pl.DataFrame({
"covariateId": data_cat_1["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_1["covariateId"].unique()) + 1))
})
cat_1_mapping.write_json(str(desktop_path / "cat1_mapping.json"))
self.cat_1_mapping.write_json(str(desktop_path / "cat1_mapping.json"))

data_cat_1 = data_cat_1.join(cat_1_mapping, on="covariateId", how="left") \
data_cat_1 = data_cat_1.join(self.cat_1_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId"))

cat_tensor = torch.tensor(data_cat_1.to_numpy())
Expand All @@ -130,22 +128,22 @@ def __init__(self, data, labels=None, numerical_features=None,
# process cat_2 features
data_cat_2 = data_cat.filter(
pl.col("covariateId").is_in(cat2_ref))
cat_2_mapping = pl.DataFrame({
self.cat_2_mapping = pl.DataFrame({
"covariateId": data_cat_2["covariateId"].unique(),
"index": pl.Series(range(1, len(data_cat_2["covariateId"].unique()) + 1))
})
cat_2_mapping = cat_2_mapping.lazy()
cat_2_mapping = (
self.cat_2_mapping = self.cat_2_mapping.lazy()
self.cat_2_mapping = (
self.data_ref
.filter(pl.col("covariateId").is_in(data_cat_2["covariateId"].unique()))
.select(pl.col("conceptId"), pl.col("covariateId"))
.join(cat_2_mapping, on="covariateId", how="left")
.join(self.cat_2_mapping, on="covariateId", how="left")
.collect()
)
cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))
self.cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))
# cat_2_mapping.write_json(str(desktop_path / "cat2_mapping.json"))

data_cat_2 = data_cat_2.join(cat_2_mapping, on="covariateId", how="left") \
data_cat_2 = data_cat_2.join(self.cat_2_mapping, on="covariateId", how="left") \
.select(pl.col("rowId"), pl.col("index").alias("covariateId")) # maybe rename this to something else

cat_2_tensor = torch.tensor(data_cat_2.to_numpy())
Expand Down Expand Up @@ -211,6 +209,12 @@ def get_cat_features(self):

def get_cat_2_features(self):
return self.cat_2_features

def get_cat_2_mapping(self):
return self.cat_2_mapping

def get_cat_1_mapping(self):
return self.cat_1_mapping

def __len__(self):
return self.target.size()[0]
Expand Down
32 changes: 14 additions & 18 deletions inst/python/InitStrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,9 @@ class CustomEmbeddingInitStrategy(InitStrategy):
def initialize(self, model, model_parameters, estimator_settings):
file_path = estimator_settings.get("embedding_file_path")

# Ensure `cat_2_features` is added to `model_parameters`
# cat_2_features_default = 20 # Set a default value if you don't have one
print(model_parameters['cat_2_features'])
print(model_parameters['cat_features'])
print(model_parameters['num_features'])

# print(model_parameters['cat_2_features'])
# print(model_parameters['cat_features'])
# print(model_parameters['num_features'])

# Instantiate the model with the provided parameters
model_temp = model(**model_parameters)
Expand All @@ -51,7 +48,7 @@ def initialize(self, model, model_parameters, estimator_settings):
raise KeyError(f"The key '{embedding_key}' does not exist in the state dictionary")

new_embeddings = state_dict[embedding_key].float()
print(f"new_embeddings: {new_embeddings}")
# print(f"new_embeddings: {new_embeddings}")

# Ensure that model_temp.categorical_embedding_2 exists
if not hasattr(model_temp, 'categorical_embedding_2'):
Expand All @@ -60,10 +57,10 @@ def initialize(self, model, model_parameters, estimator_settings):
# # replace weights
# cat2_concept_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_concept_mapping.json"))
cat2_mapping = pl.read_json(os.path.expanduser("~/Desktop/cat2_mapping.json"))
print(f"cat2_mapping: {cat2_mapping}")
# print(f"cat2_mapping: {cat2_mapping}")

concept_df = pl.DataFrame({"conceptId": state['names']}).with_columns(pl.col("conceptId"))
print(f"concept_df: {concept_df}")
# print(f"concept_df: {concept_df}")

# Initialize tensor for mapped embeddings
mapped_embeddings = torch.zeros((cat2_mapping.shape[0] + 1, new_embeddings.shape[1]))
Expand All @@ -75,20 +72,19 @@ def initialize(self, model, model_parameters, estimator_settings):
concept_idx = concept_df["conceptId"].to_list().index(concept_id)
mapped_embeddings[index] = new_embeddings[concept_idx]

print(f"mapped_embeddings: {mapped_embeddings}")
# print(f"mapped_embeddings: {mapped_embeddings}")

# Assign the mapped embeddings to the model
model_temp.categorical_embedding_2.weight = torch.nn.Parameter(mapped_embeddings)
model_temp.categorical_embedding_2.weight.requires_grad = False

print("New Embeddings:")
print(new_embeddings)
print(f"Restored Epoch: {state['epoch']}")
print(f"Restored Mean Rank: {state['mean_rank']}")
print(f"Restored Loss: {state['loss']}")
print(f"Restored Names: {state['names'][:5]}")
print(f"Number of names: {len(state['names'])}")
# print(f"Filtered Embeddings: {filtered_embeddings}")
# print("New Embeddings:")
# print(new_embeddings)
# print(f"Restored Epoch: {state['epoch']}")
# print(f"Restored Mean Rank: {state['mean_rank']}")
# print(f"Restored Loss: {state['loss']}")
# print(f"Restored Names: {state['names'][:5]}")
# print(f"Number of names: {len(state['names'])}")
else:
raise FileNotFoundError(f"File not found or path is incorrect: {file_path}")

Expand Down

0 comments on commit d446a1b

Please sign in to comment.