Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/celltype_annotation_automl' into…
Browse files Browse the repository at this point in the history
… celltype_annotation_automl
  • Loading branch information
xingzhongyu committed Oct 2, 2024
2 parents 26c73ba + ab0b5df commit 77a2761
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 137 deletions.
2 changes: 1 addition & 1 deletion dance/modules/multi_modality/joint_embedding/scmogcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def fit(self, g_mod1, g_mod2, train_size, cell_type, batch_label, phase_score):
Bipartite expression feature graph for modality 1.
g_mod2 : dgl.DGLGraph
Bipartite expression feature graph for modality 2.
train_size : int
train_size : int or array_like
Number of training samples (int), or the indices of the training samples (array_like).
labels : torch.Tensor
Labels for training samples.
Expand Down
5 changes: 3 additions & 2 deletions dance/modules/multi_modality/joint_embedding/scmvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def _inference(self, X1=None, X2=None):

if X1 is not None:
if self.log_variational:
X1_ = torch.log(X1_ + 1)
X1_ = torch.log(torch.clamp(X1_, min=1e-7) + 1)

mean_l, logvar_l, library = self.X1_encoder_l(X1_)

Expand All @@ -380,7 +380,8 @@ def _inference(self, X1=None, X2=None):

if self.Type == 'ZINB':
if self.log_variational:
X2_ = torch.log(X2_ + 1)
# X2_ = torch.log(X2_ + 1)
X2_ = torch.log(torch.clamp(X2_, min=1e-7) + 1)
mean_l2, logvar_l2, library2 = self.X2_encoder_l(X2_)

means, logvar = self._encode_modalities(X1_, X2_)
Expand Down
154 changes: 96 additions & 58 deletions examples/tuning/joint_embedding_scmogcn/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import gc
import os
import pprint
import sys
Expand All @@ -19,8 +20,10 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-t", "--subtask", default="openproblems_bmmc_cite_phase2",
choices=["GSE140203_BRAIN_atac2gex", "openproblems_bmmc_cite_phase2", "openproblems_bmmc_multiome_phase2"])
"-t", "--subtask", default="openproblems_bmmc_cite_phase2", choices=[
"GSE140203_BRAIN_atac2gex", "openproblems_bmmc_cite_phase2", "openproblems_bmmc_multiome_phase2",
"GSE140203_SKIN_atac2gex", "openproblems_2022_multi_atac2gex"
])
parser.add_argument("-d", "--data_folder", default="./data/joint_embedding")
parser.add_argument("-pre", "--pretrained_folder", default="./data/joint_embedding/pretrained")
parser.add_argument("-csv", "--csv_path", default="decoupled_lsi.csv")
Expand Down Expand Up @@ -55,66 +58,101 @@
logger.info(f"\n files is saved in {file_root_path}")
pipeline_planer = PipelinePlaner.from_config_file(f"{file_root_path}/{args.tune_mode}_tuning_config.yaml")
os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"] = "2000"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["WANDB_AGENT_DISABLE_FLAPPING"] = "True"

def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer):
wandb.init(settings=wandb.Settings(start_method='thread'))
set_seed(args.seed)
dataset = JointEmbeddingNIPSDataset(args.subtask, root=args.data_folder, preprocess=args.preprocess)
data = dataset.load_data()
# Prepare preprocessing pipeline and apply it to data
kwargs = {tune_mode: dict(wandb.config)}
preprocessing_pipeline = pipeline_planer.generate(**kwargs)
print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
preprocessing_pipeline(data)
if args.preprocess != "aux":
cell_type_labels = data.data['test_sol'].obs["cell_type"].to_numpy()
cell_type_labels_unique = list(np.unique(cell_type_labels))
c_labels = np.array([cell_type_labels_unique.index(item) for item in cell_type_labels])
data.data['mod1'].obsm["cell_type"] = c_labels
data.data["mod1"].obsm["S_scores"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["G2M_scores"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["batch_label"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["phase_labels"] = np.zeros(data.data['mod1'].shape[0])

train_size = len(data.get_split_idx("train"))

data = CellFeatureBipartiteGraph(cell_feature_channel="feature.cell", mod="mod1")(data)
data = CellFeatureBipartiteGraph(cell_feature_channel="feature.cell", mod="mod2")(data)
# data.set_config(
# feature_mod=["mod1", "mod2"],
# label_mod=["mod1", "mod1", "mod1", "mod1", "mod1"],
# feature_channel=["X_pca", "X_pca"],
# label_channel=["cell_type", "batch_label", "phase_labels", "S_scores", "G2M_scores"],
# )
(x_mod1, x_mod2), (cell_type, batch_label, phase_label, S_score, G2M_score) = data.get_data(return_type="torch")
phase_score = torch.cat([S_score[:, None], G2M_score[:, None]], 1)
test_id = np.arange(x_mod1.shape[0])
labels = cell_type.numpy()
adata_sol = data.data['test_sol'] # [data._split_idx_dict['test']]
model = ScMoGCNWrapper(args, num_celL_types=int(cell_type.max() + 1), num_batches=int(batch_label.max() + 1),
num_phases=phase_score.shape[1], num_features=x_mod1.shape[1] + x_mod2.shape[1])
model.fit(
g_mod1=data.data["mod1"].uns["g"],
g_mod2=data.data["mod2"].uns["g"],
train_size=train_size,
cell_type=cell_type,
batch_label=batch_label,
phase_score=phase_score,
)

embeds = model.predict(test_id).cpu().numpy()
score = model.score(test_id, labels, metric="clustering")
# score.update(model.score(test_id, labels, adata_sol=adata_sol, metric="openproblems"))
score.update({
'subtask': args.subtask,
'method': 'scmogcn',
})

score["ARI"] = score["dance_ari"]
del score["dance_ari"]
wandb.log(score)
wandb.finish()
torch.cuda.empty_cache()
wandb_config = wandb.config
if "run_kwargs" in pipeline_planer.config:
if any(d == dict(wandb.config["run_kwargs"]) for d in pipeline_planer.config.run_kwargs):
wandb_config = wandb_config["run_kwargs"]
else:
wandb.log({"skip": 1})
wandb.finish()
return
try:
dataset = JointEmbeddingNIPSDataset(args.subtask, root=args.data_folder, preprocess=args.preprocess)
data = dataset.load_data()
# Prepare preprocessing pipeline and apply it to data
kwargs = {tune_mode: dict(wandb_config)}
preprocessing_pipeline = pipeline_planer.generate(**kwargs)
print(f"Pipeline config:\n{preprocessing_pipeline.to_yaml()}")
preprocessing_pipeline(data)
# train_idx=list(set(data.mod["meta1"].obs_names) & set(data.mod["mod1"].obs_names))
train_name = [item for item in data.mod["mod1"].obs_names if item in data.mod["meta1"].obs_names]
train_idx = [data.mod["mod1"].obs_names.get_loc(name) for name in train_name]
test_idx = list({i for i in range(data.mod["mod1"].shape[0])}.difference(set(train_idx)))

# train_size=data.mod["meta1"].shape[0]
# test_size=data.mod["mod1"].shape[0]-train_size
data.set_split_idx("train", train_idx)
data.set_split_idx("test", test_idx)
if args.preprocess != "aux":
cell_type_labels = data.data['test_sol'].obs["cell_type"].to_numpy()
cell_type_labels_unique = list(np.unique(cell_type_labels))
c_labels = np.array([cell_type_labels_unique.index(item) for item in cell_type_labels])
data.data['mod1'].obsm["cell_type"] = c_labels
data.data["mod1"].obsm["S_scores"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["G2M_scores"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["batch_label"] = np.zeros(data.data['mod1'].shape[0])
data.data["mod1"].obsm["phase_labels"] = np.zeros(data.data['mod1'].shape[0])

# train_size = len(data.get_split_idx("train"))
# In principle, meta1 should contain everything in the first half of mod1; the order may have been shuffled along the way.
data = CellFeatureBipartiteGraph(cell_feature_channel="feature.cell", mod="mod1")(data)
data = CellFeatureBipartiteGraph(cell_feature_channel="feature.cell", mod="mod2")(data)
# data.set_config(
# feature_mod=["mod1", "mod2"],
# label_mod=["mod1", "mod1", "mod1", "mod1", "mod1"],
# feature_channel=["X_pca", "X_pca"],
# label_channel=["cell_type", "batch_label", "phase_labels", "S_scores", "G2M_scores"],
# )
(x_mod1, x_mod2), (cell_type, batch_label, phase_label, S_score,
G2M_score) = data.get_data(return_type="torch")
phase_score = torch.cat([S_score[:, None], G2M_score[:, None]], 1)
test_id = np.arange(x_mod1.shape[0])
labels = cell_type.numpy()
adata_sol = data.data['test_sol'] # [data._split_idx_dict['test']]
model = ScMoGCNWrapper(args, num_celL_types=int(cell_type.max() + 1),
num_batches=int(batch_label.max() + 1), num_phases=phase_score.shape[1],
num_features=x_mod1.shape[1] + x_mod2.shape[1])
model.fit(
g_mod1=data.data["mod1"].uns["g"],
g_mod2=data.data["mod2"].uns["g"],
train_size=train_idx,
cell_type=cell_type,
batch_label=batch_label,
phase_score=phase_score,
)

embeds = model.predict(test_id).cpu().numpy()
score = model.score(test_id, labels, metric="clustering")
# score.update(model.score(test_id, labels, adata_sol=adata_sol, metric="openproblems"))
score.update({
'subtask': args.subtask,
'method': 'scmogcn',
})

score["ARI"] = score["dance_ari"]
del score["dance_ari"]
wandb.log(score)
wandb.finish()
finally:
# del data,model,adata_sol,adata,embeds,emb1, emb2,total_loader,total,test_loader,test,train_loader,train,Nfeature2,Nfeature1
# del x_train, y_train, x_train_raw, y_train_raw, x_train_size,y_train_size,train_labels,x_test, y_test, x_test_raw, y_test_raw, x_test_size,y_test_size, test_labels
# del labels,le,dataset,score
# variables_to_delete=["data","model","adata_sol","adata","embeds","emb1", "emb2","total_loader","total,test_loader","test,train_loader","train","Nfeature2","Nfeature1","x_train", "y_train", "x_train_raw", "y_train_raw", "x_train_size","y_train_size","train_labels","x_test", "y_test"," x_test_raw", y_test_raw, x_test_size,y_test_size, test_labels,labels,le,dataset,score]
locals_keys = list(locals().keys())
for var in locals_keys:
try:
exec(f"del {var}")
logger.info(f"Deleted '{var}'")
except NameError:
logger.info(f"Variable '{var}' does not exist, continuing...")
torch.cuda.empty_cache()
gc.collect()

entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
evaluate_pipeline, sweep_id=args.sweep_id, count=args.count) #Score can be recorded for each epoch
Expand Down
Loading

0 comments on commit 77a2761

Please sign in to comment.