console.py
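
# Interactive console: load an article from a URL or a local PDF/TXT/DOCX file,
# index its text fragments as embeddings in the configured storage, then answer
# queries by retrieving the most similar fragments and asking the AI.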
import os

import xxhash

from ai import AI
from config import Config
from contents import *  # extract_text_from_* helpers and web_crawler_newspaper
from storage import Storage


def console(cfg: Config):
    try:
        while True:
            if not _console(cfg):
                return
    except KeyboardInterrupt:
        print("exit")


def _console(cfg: Config) -> bool:
    """Run one console session; return True to start over, False to quit."""
    contents, lang, identify = _get_contents()
    print("The article has been retrieved, and the number of text fragments is:", len(contents))
    for content in contents:
        print('\t', content)
    ai = AI(cfg)
    storage = Storage.create_storage(cfg)
    print("=====================================")
    if storage.been_indexed(identify):
        print("The article has already been indexed, so there is no need to index it again.")
        print("=====================================")
    else:
        # 1. Generate an embedding for each paragraph of the article.
        embeddings, tokens = ai.create_embeddings(contents)
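        # The rate below assumes OpenAI's text-embedding-ada-002 pricing of
        # $0.0004 per 1K tokens; adjust it if the embedding model or price differs.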
print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
f"costing ${tokens / 1000 * 0.0004}")
storage.add_all(embeddings, identify)
print("The embeddings have been saved.")
print("=====================================")
    while True:
        query = input("Please enter your query (/help to view commands): ").strip()
        if query.startswith("/"):
            if query == "/quit":
                return False
            elif query == "/reset":
                print("=====================================")
                return True
            elif query == "/summary":
                # Generate an embedding-based summary, using a SIF-weighted average
                # or a plain average of the embeddings depending on the language.
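                # SIF ("smooth inverse frequency", Arora et al. 2017) reweights word
                # vectors by a / (a + p(w)) before averaging; the plain mean is used
                # for the languages listed below, presumably because their
                # tokenization makes word-frequency estimates unreliable.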
                ai.generate_summary(storage.get_all_embeddings(identify), num_candidates=100,
                                    use_sif=lang not in ['zh', 'ja', 'ko', 'hi', 'ar', 'fa'])
elif query == "/reindex":
# 重新索引,会清空数据库
# Re-index, which will clear the database.
storage.clear(identify)
embeddings, tokens = ai.create_embeddings(contents)
print(f"Embeddings have been created with {len(embeddings)} embeddings, using {tokens} tokens, "
f"costing ${tokens / 1000 * 0.0004}")
storage.add_all(embeddings, identify)
print("The embeddings have been saved.")
elif query == "/help":
print("Enter /summary to generate an embedding-based summary.")
print("Enter /reindex to re-index the article.")
print("Enter /reset to start over.")
print("Enter /quit to exit.")
print("Enter any other content for a query.")
else:
print("Invalid command.")
print("Enter /summary to generate an embedding-based summary.")
print("Enter /reindex to re-index the article.")
print("Enter /reset to start over.")
print("Enter /quit to exit.")
print("Enter any other content for a query.")
print("=====================================")
continue
        else:
            # 1. Generate keywords from the query.
            print("Generate keywords.")
            keywords = ai.get_keywords(query)
            # 2. Generate an embedding for the extracted keywords.
            _, embedding = ai.create_embedding(keywords)
            # 3. Find the most similar fragments in the database.
            texts = storage.get_texts(embedding, identify)
            print("Related fragments found (first 5):")
            for text in texts[:5]:
                print('\t', text)
            # 4. Hand the relevant fragments to the AI, which answers the question
            #    based on them.
            ai.completion(query, texts)
        print("=====================================")


def _get_contents() -> tuple[list[str], str, str]:
    """Prompt for an article URL or a local file path and return (contents, lang, identify)."""
    while True:
        try:
            url = input("Please enter the link to the article or the file path of the PDF/TXT/DOCX document: ").strip()
            if os.path.exists(url):
                if url.endswith('.pdf'):
                    contents, lang = extract_text_from_pdf(url)
                elif url.endswith('.txt'):
                    contents, lang = extract_text_from_txt(url)
                elif url.endswith('.docx'):
                    contents, lang = extract_text_from_docx(url)
                else:
                    print("Unsupported file format.")
                    continue
            else:
                contents, lang = web_crawler_newspaper(url)
            if not contents:
                print("Unable to retrieve the content of the article. Please enter the link to the article or "
                      "the file path of the PDF/TXT/DOCX document again.")
                continue
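            # The xxh3_128 hash of the joined fragments serves as a stable fingerprint
            # for the document, so storage.been_indexed can skip re-indexing it.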
            return contents, lang, xxhash.xxh3_128_hexdigest('\n'.join(contents))
        except Exception as e:
            print("Error:", e)