from fastapi import FastAPI, File, UploadFile, HTTPException
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings
from langchain_community.document_loaders import PDFPlumberLoader
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain.prompts import PromptTemplate
import uvicorn
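
# Runtime assumptions (not enforced by the code below): an Ollama server reachable
# at its default local address with the "llama3" model available, and a `pdf/`
# directory alongside this file for uploaded documents. The Chroma index is
# persisted under `db/` (see folder_path).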
app = FastAPI()
folder_path = "db"
cached_llm = Ollama(model="llama3")
embedding = FastEmbedEmbeddings()

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1024, chunk_overlap=80, length_function=len, is_separator_regex=False
)

raw_prompt = PromptTemplate.from_template(
    """
    <s>[INST] You are a technical assistant good at searching documents. If you do not have an answer from the provided information say so. [/INST] </s>
    [INST] {input}
    Context: {context}
    Answer:
    [/INST]
    """
)


class Query(BaseModel):
    query: str
@app.post("/ai")
async def ai_post(query: Query):
print("Post /ai called")
print(f"query: {query.query}")
response = cached_llm.invoke(query.query)
print(response)
response_answer = {"answer": response}
return response_answer
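
# Example request (assuming the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/ai \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What is retrieval-augmented generation?"}'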
@app.post("/ask_pdf")
async def ask_pdf_post(query: Query):
print("Post /ask_pdf called")
print(f"query: {query.query}")
print("Loading vector store")
vector_store = Chroma(persist_directory=folder_path, embedding_function=embedding)
print("Creating chain")
retriever = vector_store.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": 20,
"score_threshold": 0.1,
},
)
document_chain = create_stuff_documents_chain(cached_llm, raw_prompt)
chain = create_retrieval_chain(retriever, document_chain)
result = chain.invoke({"input": query.query})
print(result)
sources = []
for doc in result["context"]:
sources.append(
{"source": doc.metadata["source"], "page_content": doc.page_content}
)
response_answer = {"answer": result["answer"], "sources": sources}
return response_answer
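
# Example request (assuming at least one PDF has already been indexed via /pdf):
#   curl -X POST http://localhost:8000/ask_pdf \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Summarize the uploaded document."}'
# The response contains "answer" plus "sources" (chunk text and originating file).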
@app.post("/pdf")
async def pdf_post(file: UploadFile = File(...)):
file_name = file.filename
save_file = f"pdf/{file_name}"
with open(save_file, "wb") as buffer:
content = await file.read()
buffer.write(content)
print(f"filename: {file_name}")
loader = PDFPlumberLoader(save_file)
docs = loader.load_and_split()
print(f"docs len={len(docs)}")
chunks = text_splitter.split_documents(docs)
print(f"chunks len={len(chunks)}")
vector_store = Chroma.from_documents(
documents=chunks, embedding=embedding, persist_directory=folder_path
)
vector_store.persist()
response = {
"status": "Successfully Uploaded",
"filename": file_name,
"doc_len": len(docs),
"chunks": len(chunks),
}
return response
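
# Example upload (multipart form field "file"; "example.pdf" is a placeholder path):
#   curl -X POST http://localhost:8000/pdf -F "file=@example.pdf"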


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
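
# To start the API: `python app.py`, or equivalently:
#   uvicorn app:app --host 0.0.0.0 --port 8000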