-
Notifications
You must be signed in to change notification settings - Fork 3
/
nodes.py
135 lines (109 loc) · 4.2 KB
/
nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from transformers import MarianMTModel, MarianTokenizer
import requests
import random
import json
from hashlib import md5
# MarianMT checkpoint names offered by LoadMarianMTCheckPoint. Each entry is
# joined onto "Helsinki-NLP/" to form the id passed to from_pretrained.
# Names follow the Helsinki-NLP "opus-mt-<src>-<tgt>" convention.
marian_list = [
    # into English
    "opus-mt-zh-en",
    "opus-mt-rn-en",
    "opus-mt-taw-en",
    "opus-mt-az-en",
    "opus-mt-ru-en",
    "opus-mt-ja-en",
    # out of English
    "opus-mt-en-zh",
    "opus-mt-en-ru",
    "opus-mt-en-jap",
    "opus-mt-en-rn",
]

# Source-language codes offered by PromptBaiduFanyiToText; the request always
# targets English ('en').
# NOTE(review): Baidu's language table documents 'jp' for Japanese; 'ja' may be
# rejected by the API. Kept so saved workflows that selected it still validate.
lang_list = [
    'auto', 'zh', 'yue', 'kor', 'th', 'pt', 'el', 'bul', 'fin', 'slo', 'cht', 'wyw',
    'fra', 'ara', 'de', 'nl', 'est', 'cs', 'swe', 'jp', 'spa', 'ru', 'it', 'pl', 'ja',
]
def make_md5(s, encoding='utf-8'):
    """Return the lowercase hexadecimal MD5 digest of *s*.

    Args:
        s: text to hash.
        encoding: codec used to turn *s* into bytes before hashing.
    """
    raw = s.encode(encoding)
    hasher = md5(raw)
    return hasher.hexdigest()
class LoadMarianMTCheckPoint:
    """ComfyUI node that loads a Helsinki-NLP MarianMT model and its tokenizer.

    The chosen checkpoint name is resolved as ``Helsinki-NLP/<checkpoint>``
    through ``from_pretrained``.
    """

    CATEGORY = "kkTranslator"
    FUNCTION = "load_marian_mt"
    RETURN_TYPES = ("MODEL", "TOKENIZER")
    RETURN_NAMES = ("model", "tokenizer")

    @classmethod
    def INPUT_TYPES(cls):
        checkpoint_choice = (marian_list, {"multiline": False, "default": "opus-mt-zh-en"})
        return {"required": {"checkpoint": checkpoint_choice}}

    def load_marian_mt(self, checkpoint):
        """Return ``(model, tokenizer)`` for the named Helsinki-NLP checkpoint."""
        repo_id = 'Helsinki-NLP/' + checkpoint
        marian_tokenizer = MarianTokenizer.from_pretrained(repo_id)
        marian_model = MarianMTModel.from_pretrained(repo_id)
        return marian_model, marian_tokenizer
class PromptTranslateToText:
    """ComfyUI node that translates a prompt with a loaded MarianMT model."""

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", ),
                "tokenizer": ("TOKENIZER", ),
                "prompt_text": ("STRING", {"multiline": True, "default": "你好"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "kkTranslator"

    def run(self, model, tokenizer, prompt_text):
        """Translate *prompt_text* with *model* / *tokenizer*.

        Args:
            model: seq2seq model exposing ``generate`` (presumably the
                MarianMTModel produced by LoadMarianMTCheckPoint).
            tokenizer: the tokenizer matching *model*.
            prompt_text: text to translate.

        Returns:
            A one-element tuple holding the decoded translation. Empty or
            whitespace-only input returns ``("",)`` without running the model.
        """
        # An empty source would make the model generate from nothing but the
        # end-of-sequence token, which yields meaningless text.
        if not prompt_text.strip():
            return ("",)

        inputs = tokenizer(prompt_text, return_tensors="pt", padding=True)
        # Put the encoded inputs on the model's device so a model that has been
        # moved to a GPU does not fail with a device-mismatch error.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = inputs.to(device)

        translated = model.generate(**inputs)
        text = "".join(tokenizer.decode(t, skip_special_tokens=True) for t in translated)
        print(text)
        return (text,)
class PromptBaiduFanyiToText:
    """ComfyUI node that translates a prompt into English via Baidu Fanyi.

    The request is signed with ``MD5(appid + query + salt + secretkey)``; the
    secret key itself is never sent. The target language is always ``'en'``.
    """

    # Seconds to wait for the Baidu endpoint, so a stalled connection cannot
    # hang the node indefinitely.
    REQUEST_TIMEOUT = 10

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "appid": ("STRING", {"multiline": False, "default": ""}),
                "secretkey": ("STRING", {"multiline": False, "default": ""}),
                "from_lang": (lang_list, {"multiline": False, "default": "auto"}),
                "prompt_text": ("STRING", {"multiline": True, "default": "你好"}),
            }
        }

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "kkTranslator"

    @staticmethod
    def _extract_translation(result):
        """Return the translated text from a decoded Baidu JSON response.

        Baidu returns one ``trans_result`` entry per newline-separated
        paragraph of the query, so every entry is joined back with newlines.

        Raises:
            RuntimeError: if the response has no ``trans_result``; Baidu
                reports failures through ``error_code`` / ``error_msg``.
        """
        segments = result.get("trans_result")
        if not segments:
            code = result.get("error_code", "unknown")
            message = result.get("error_msg", "response contained no trans_result")
            raise RuntimeError(f"Baidu Fanyi request failed ({code}): {message}")
        return "\n".join(segment["dst"] for segment in segments)

    def run(self, appid, secretkey, from_lang, prompt_text):
        """Translate *prompt_text* from *from_lang* into English.

        Args:
            appid: Baidu Fanyi application id.
            secretkey: Baidu Fanyi secret key (used only to compute the sign).
            from_lang: source-language code, e.g. ``'auto'`` or ``'zh'``.
            prompt_text: text to translate; may span several lines.

        Returns:
            A one-element tuple with the translation. Empty or
            whitespace-only input returns ``("",)`` without calling the API.

        Raises:
            ValueError: if *appid* or *secretkey* is empty.
            RuntimeError: if Baidu returns an error payload.
            requests.RequestException: on network/HTTP failure or timeout.
        """
        if not appid or not secretkey:
            raise ValueError("Please input your appid and secretkey")
        if not prompt_text.strip():
            return ("",)

        url = 'http://api.fanyi.baidu.com' + '/api/trans/vip/translate'
        salt = random.randint(32768, 65536)
        # The sign covers the raw (un-encoded) query text.
        sign = make_md5(appid + prompt_text + str(salt) + secretkey)
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        payload = {'appid': appid, 'q': prompt_text, 'from': from_lang, 'to': 'en', 'salt': salt, 'sign': sign}

        # Send the form fields in the request body (matching the Content-Type
        # header) rather than the URL, so long prompts don't hit URL limits.
        response = requests.post(url, data=payload, headers=headers, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return (self._extract_translation(response.json()),)
# Registry of exported nodes: (node key, implementing class, display title).
# The keys identify nodes in saved workflows, so they are spelled out
# explicitly rather than derived from the class names.
_NODE_REGISTRY = (
    ("PromptTranslateToText", PromptTranslateToText, "Prompt Translate to Text"),
    ("LoadMarianMTCheckPoint", LoadMarianMTCheckPoint, "Load MarianMT CheckPoint"),
    ("PromptBaiduFanyiToText", PromptBaiduFanyiToText, "Prompt Baidu Fanyi to Text"),
)

# Node key -> class, as read by ComfyUI.
NODE_CLASS_MAPPINGS = {key: node_cls for key, node_cls, _title in _NODE_REGISTRY}

# Node key -> human-readable title shown in the ComfyUI menu.
NODE_DISPLAY_NAME_MAPPINGS = {key: title for key, _node_cls, title in _NODE_REGISTRY}
if __name__ == "__main__":
    # Manual smoke test for the Baidu node. Credentials come from environment
    # variables so real keys never have to be written into the source file.
    import os

    demo_appid = os.environ.get("BAIDU_FANYI_APPID", "")
    demo_secretkey = os.environ.get("BAIDU_FANYI_SECRETKEY", "")
    if not (demo_appid and demo_secretkey):
        print("Set BAIDU_FANYI_APPID and BAIDU_FANYI_SECRETKEY to run this demo.")
    else:
        fanyi = PromptBaiduFanyiToText()
        (translated,) = fanyi.run(demo_appid, demo_secretkey, "zh", "你好")
        print(translated)