apphistory.py (forked from ThomasJay/RAG)
# FastAPI service: upload PDFs, index them in a persistent Chroma store, and
# answer questions with a local Ollama (Llama 3) model, with conversation history.
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.document_loaders import PDFPlumberLoader
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains.history_aware_retriever import create_history_aware_retriever
import uvicorn
import os

app = FastAPI()

# In-memory conversation history shared across /ask_pdf requests.
chat_history = []

# Directory where the Chroma vector store is persisted.
folder_path = "db"

cached_llm = Ollama(model="llama3")
embedding = FastEmbedEmbeddings()

# Splits loaded documents into overlapping chunks before indexing.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024, chunk_overlap=80, length_function=len, is_separator_regex=False
)

# Prompt used by the answer-generation (stuff-documents) chain.
raw_prompt = PromptTemplate.from_template(
    """
    <s>[INST] You are a technical assistant good at searching documents. If you do not have an answer from the provided information, say so. [/INST] </s>
    [INST] {input}
           Context: {context}
           Answer:
    [/INST]
    """
)


class Query(BaseModel):
    query: str


@app.post("/ai")
async def ai_post(query: Query):
    """Answer a query with the LLM directly, with no document retrieval."""
    print("Post /ai called")
    print(f"query: {query.query}")

    response = cached_llm.invoke(query.query)
    print(response)

    response_answer = {"answer": response}
    return response_answer
@app.post("/ask_pdf")
async def ask_pdf_post(query: Query):
print("Post /ask_pdf called")
print(f"query: {query.query}")
print("Loading vector store")
vector_store = Chroma(persist_directory=folder_path, embedding_function=embedding)
print("Creating chain")
retriever = vector_store.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": 20,
"score_threshold": 0.1,
},
)
retriever_prompt = ChatPromptTemplate.from_messages(
[
MessagesPlaceholder(variable_name="chat_history"),
("human", "{input}"),
(
"human",
"Given the above conversation, generate a search query to lookup in order to get information relevant to the conversation",
),
]
)
history_aware_retriever = create_history_aware_retriever(
llm=cached_llm, retriever=retriever, prompt=retriever_prompt
)
document_chain = create_stuff_documents_chain(cached_llm, raw_prompt)
retrieval_chain = create_retrieval_chain(
history_aware_retriever,
document_chain,
)
result = retrieval_chain.invoke({"input": query.query, "chat_history": chat_history})
print(result["answer"])
chat_history.append(HumanMessage(content=query.query))
chat_history.append(AIMessage(content=result["answer"]))
sources = []
for doc in result["context"]:
sources.append(
{"source": doc.metadata["source"], "page_content": doc.page_content}
)
response_answer = {"answer": result["answer"], "sources": sources}
return response_answer
@app.post("/pdf")
async def pdf_post(file: UploadFile = File(...)):
file_name = file.filename
save_file = f"pdf/{file_name}"
with open(save_file, "wb") as buffer:
content = await file.read()
buffer.write(content)
print(f"filename: {file_name}")
loader = PDFPlumberLoader(save_file)
docs = loader.load_and_split()
print(f"docs len={len(docs)}")
chunks = text_splitter.split_documents(docs)
print(f"chunks len={len(chunks)}")
vector_store = Chroma.from_documents(
documents=chunks, embedding=embedding, persist_directory=folder_path
)
vector_store.persist()
response = {
"status": "Successfully Uploaded",
"filename": file_name,
"doc_len": len(docs),
"chunks": len(chunks),
}
return response


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
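
# ---------------------------------------------------------------------------
# Example client usage (a minimal sketch, not part of the original file).
# It assumes the service is running locally on port 8000, as configured in the
# uvicorn.run call above; "example.pdf" and the query strings are placeholders.
#
#   import requests
#
#   # Index a PDF so it can be queried.
#   with open("example.pdf", "rb") as f:
#       print(requests.post("http://localhost:8000/pdf", files={"file": f}).json())
#
#   # Ask a question against the indexed PDFs; chat history is kept server-side.
#   print(requests.post(
#       "http://localhost:8000/ask_pdf",
#       json={"query": "What is this document about?"},
#   ).json())
#
#   # Plain LLM call with no document retrieval.
#   print(requests.post("http://localhost:8000/ai", json={"query": "Hello"}).json())
# ---------------------------------------------------------------------------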