-
Notifications
You must be signed in to change notification settings - Fork 0
/
phistruct.py
101 lines (74 loc) · 2.9 KB
/
phistruct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
=================================================================================
This script is for running PHIStruct. It takes a directory of PDB files as input
and outputs the predicted host genus for each protein. It also displays the
prediction score (class probability) for each host genus recognized by PHIStruct.
@author Mark Edward M. Gonzales
=================================================================================
"""
import argparse
import os
import joblib
import torch
from transformers import EsmTokenizer
from SaProt.model.esm.base import EsmBaseModel
from SaProt.utils.foldseek_util import get_struc_seq
# Module-level setup shared by the functions below: choose a device, build the
# SaProt model and its tokenizer from the same checkpoint directory.

# Use the first CUDA GPU when torch reports one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Keyword arguments forwarded verbatim to EsmBaseModel; "config_path" is also
# reused as the location the tokenizer is loaded from.
config = {
    "task": "base",
    "config_path": "SaProt/SaProt_650M_AF2",
    "load_pretrained": True,
}

model = EsmBaseModel(**config).to(device)
# Presumably switches the network to inference mode (torch ``Module.eval``
# semantics) -- confirm against EsmBaseModel.
model.eval()

tokenizer = EsmTokenizer.from_pretrained(config["config_path"])
# Adapted from https://github.com/westlake-repl/SaProt?tab=readme-ov-file#convert-protein-structure-into-structure-aware-sequence
def encode(pdb_path, chain="A", foldseek_path="SaProt/bin/foldseek"):
    """Convert a structure file into SaProt's combined (structure-aware) sequence.

    Args:
        pdb_path: Path to the PDB file to encode.
        chain: Identifier of the chain to extract. Defaults to ``"A"``, the
            chain this script has always used.
        foldseek_path: Path to the Foldseek binary handed to
            ``get_struc_seq``. Defaults to the copy under ``SaProt/bin``.

    Returns:
        The third element of the tuple that ``get_struc_seq`` reports for the
        requested chain (called ``combined_seq`` in the SaProt README).

    Raises:
        KeyError: If ``get_struc_seq`` returns no entry for ``chain``.
    """
    parsed = get_struc_seq(foldseek_path, pdb_path, [chain])

    try:
        entry = parsed[chain]
    except KeyError:
        # Same exception type as a bare lookup, but says which file and chain
        # were involved so a failing input is easy to identify in a batch run.
        raise KeyError(
            f"Chain {chain!r} not found in structure file: {pdb_path}"
        ) from None

    _, _, combined_seq = entry
    return combined_seq
# Adapted from https://github.com/westlake-repl/SaProt/issues/14
def embed(seq):
    """Embed a sequence with the module-level SaProt model.

    Args:
        seq: Sequence string to tokenize (in this script, the output of
            ``encode``).

    Returns:
        The first item returned by ``model.get_hidden_states(...,
        reduction="mean")``, converted to a plain Python list.
        NOTE(review): presumably a single mean-pooled vector for the
        sequence -- confirm against EsmBaseModel.
    """
    encoded = tokenizer(seq, return_tensors="pt")

    # Move every tokenizer output onto the same device as the model.
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        hidden = model.get_hidden_states(on_device, reduction="mean")

    first = hidden[0]
    return first.tolist()
def predict(embedding, clf):
    """Score one embedding against every class known to the classifier.

    Args:
        embedding: Feature vector for a single protein.
        clf: Object exposing ``predict_proba`` and ``classes_``
            (scikit-learn classifier interface).

    Returns:
        A list of ``(class_name, probability)`` tuples ordered from the
        highest to the lowest probability; ties keep the order of
        ``clf.classes_``.
    """
    # The classifier is given a one-item batch, so only row 0 is read.
    proba = clf.predict_proba([embedding])

    ranked = [(label, proba[0][pos]) for pos, label in enumerate(clf.classes_)]
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked
def write_results(id, scores, output_dir):
    """Write one protein's prediction scores to ``<output_dir>/<id>.csv``.

    Args:
        id: Name identifying the protein. A trailing ``".pdb"`` is removed
            before the name is used for the CSV file. (The name shadows the
            builtin ``id``; it is kept so existing keyword callers still work.)
        scores: Iterable of entries; the first two items of each entry are
            written as one ``first,second`` line.
        output_dir: Directory that receives the CSV file. It is created,
            including missing parents, when it does not exist yet.
    """
    # exist_ok=True avoids the check-then-create race of testing
    # os.path.exists() before calling os.makedirs().
    os.makedirs(output_dir, exist_ok=True)

    if id.endswith(".pdb"):
        id = id[: -len(".pdb")]

    csv_path = os.path.join(output_dir, f"{id}.csv")
    # An explicit encoding keeps the output independent of the platform locale.
    with open(csv_path, "w", encoding="utf-8") as f:
        f.writelines(f"{entry[0]},{entry[1]}\n" for entry in scores)
if __name__ == "__main__":
    # Command-line entry point: encode, embed and classify every file in the
    # input directory, writing one CSV of class scores per protein.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input",
        required=True,
        help="Path to the directory storing the PDB files describing the structures of the receptor-binding proteins",
    )
    parser.add_argument(
        "--model",
        required=True,
        help="Path to the trained model (recognized format: joblib or compressed joblib, framework: scikit-learn)",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Path to the directory to which the results of running PHIStruct will be written",
    )
    args = parser.parse_args()

    # NOTE: joblib files are pickle-based, so loading one can execute arbitrary
    # code. Only pass models that come from a trusted source.
    clf = joblib.load(args.model)

    # Sorted so that runs process the proteins in a reproducible order.
    for protein in sorted(os.listdir(args.input)):
        pdb_path = os.path.join(args.input, protein)

        # os.listdir also returns sub-directories; only regular files can be
        # structure files, so anything else is skipped.
        if not os.path.isfile(pdb_path):
            continue

        scores = predict(embed(encode(pdb_path)), clf)
        write_results(protein, scores, args.output)