Skip to content

Commit

Permalink
fixed. #16
Browse files: browse the repository at this point in the history
  • Loading branch information
shibing624 committed Aug 22, 2024
1 parent e2833b9 commit 522a450
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 44 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ E2B_API_KEY=""

# for web ui
DATA_DIR="~/.cache/chatpilot/data"
AGENTICA_HOME="~/.cache/chatpilot/data"

# DO NOT TRACK
SCARF_NO_ANALYTICS=true
Expand Down
11 changes: 8 additions & 3 deletions chatpilot/apps/rag_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
Embeddings,
)
from chromadb.utils import embedding_functions
from chromadb.utils.embedding_functions.openai_embedding_function import OpenAIEmbeddingFunction
from chromadb.utils.embedding_functions.text2vec_embedding_function import Text2VecEmbeddingFunction
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction
)
from fastapi import (
FastAPI,
Depends,
Expand Down Expand Up @@ -116,13 +121,13 @@ def __call__(self, input: ChromaDocuments) -> Embeddings:


if "text-embedding" in app.state.RAG_EMBEDDING_MODEL and app.state.OPENAI_API_KEYS and app.state.OPENAI_API_KEYS[0]:
app.state.sentence_transformer_ef = embedding_functions.OpenAIEmbeddingFunction(
app.state.sentence_transformer_ef = OpenAIEmbeddingFunction(
api_key=app.state.OPENAI_API_KEYS[0],
api_base=app.state.OPENAI_API_BASE_URLS[0],
model_name=app.state.RAG_EMBEDDING_MODEL,
)
elif "text2vec" in app.state.RAG_EMBEDDING_MODEL:
app.state.sentence_transformer_ef = embedding_functions.Text2VecEmbeddingFunction(
app.state.sentence_transformer_ef = Text2VecEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL
)
elif "w2v" in app.state.RAG_EMBEDDING_MODEL:
Expand Down Expand Up @@ -237,7 +242,7 @@ async def update_embedding_model(
model_name=app.state.RAG_EMBEDDING_MODEL
)
elif app.state.RAG_EMBEDDING_MODEL:
app.state.sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
app.state.sentence_transformer_ef = SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL
)
else:
Expand Down
6 changes: 4 additions & 2 deletions chatpilot/langchain_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
ENABLE_URL_CRAWLER_TOOL,
ENABLE_RUN_PYTHON_CODE_TOOL,
REACT_RPOMPT,
OPENAI_API_KEY,
OPENAI_API_BASE,
)
from chatpilot.react_parser import ReActParserAndNoTool

Expand All @@ -40,8 +42,8 @@ def __init__(
self,
model_type: str = "openai",
model_name: str = "gpt-3.5-turbo-1106",
model_api_key: str = os.getenv("OPENAI_API_KEY"),
model_api_base: str = os.getenv("OPENAI_API_BASE"),
model_api_key: str = os.getenv("OPENAI_API_KEY") or OPENAI_API_KEY,
model_api_base: str = os.getenv("OPENAI_API_BASE") or OPENAI_API_BASE,
search_name: Optional[str] = "serper",
agent_type: str = "react",
enable_search_tool: Optional[bool] = None,
Expand Down
20 changes: 16 additions & 4 deletions chatpilot/rag_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.embeddings.text2vec import Text2vecEmbeddings
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from loguru import logger

from chatpilot.config import RAG_TEMPLATE
import os
from chatpilot.config import RAG_TEMPLATE,OPENAI_API_KEY,OPENAI_API_BASE


class RagFusion:
def __init__(
self,
documents: List[Document],
model_api_key: str = os.getenv("OPENAI_API_KEY") or OPENAI_API_KEY,
model_api_base: str = os.getenv("OPENAI_API_BASE") or OPENAI_API_BASE,
openai_model: str = "gpt-3.5-turbo-1106",
generate_model: str = "gpt-3.5-turbo-16k",
temperature: float = 0.0,
Expand All @@ -41,19 +45,26 @@ def __init__(
])
# Use the LLM to generate additional queries
self.requery_model = ChatOpenAI(
openai_api_key=model_api_key,
openai_api_base=model_api_base,
temperature=temperature,
model=openai_model,
)
vectorstore = Chroma.from_documents(
documents,
OpenAIEmbeddings()
OpenAIEmbeddings(
openai_api_key=model_api_key,
openai_api_base=model_api_base,
)
)
self.retriever = vectorstore.as_retriever()

self.rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)

# LLM used for the RAG answer-generation step
self.generate_model = ChatOpenAI(
openai_api_key=model_api_key,
openai_api_base=model_api_base,
temperature=temperature,
model=generate_model,
)
Expand Down Expand Up @@ -110,7 +121,8 @@ def run(self, question: str):
}

text_documents = [Document(page_content=doc, metadata={"source": f"{id}"}) for id, doc in all_documents.items()]
print('text_documents:', text_documents)
rag_fusion = RagFusion(text_documents)
question = "气候变化的影响"
question = "气候变化及其经济影响"
result = rag_fusion.run(question)
print(result)
23 changes: 0 additions & 23 deletions examples/chat_agent_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,26 +82,3 @@ def demo5():
print("===")


def demo6():
import asyncio
async def d():
m = LangchainAssistant()

questions = [
"俄罗斯今日新闻top3",
# "人体最大的器官是啥",
# "how many letters in the word 'educabe'?",
# "它是一个真的单词吗?",
]
for i in questions:
print(i)
events = await m.astream_run(i)
async for event in events:
print(event)
print("===")
pass

asyncio.run(d())


demo6()
13 changes: 2 additions & 11 deletions examples/chat_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@
def demo6():
import asyncio
m = LangchainAssistant(
model_type='azure',
model_name="gpt-35-turbo",
model_type='openai',
model_name="gpt-3.5-turbo",
model_api_key=os.getenv("OPENAI_API_KEYS"),
model_api_base=os.getenv("OPENAI_API_BASE_URLS"),
search_name="serper",
agent_type="react",
enable_search_tool=True,
enable_run_python_code_tool=False,
enable_crawler_tool=False,
streaming=True,
)
async def d():
Expand All @@ -51,19 +49,12 @@ async def d():

def demo5():
m = LangchainAssistant(
model_type='azure',
model_name="gpt-35-turbo",
model_api_key=os.getenv("OPENAI_API_KEYS"),
model_api_base=os.getenv("OPENAI_API_BASE_URLS"),
search_name="serper",
agent_type="react",
enable_search_tool=True,
enable_run_python_code_tool=True,
enable_crawler_tool=False,
)
questions = [
# "今天的俄罗斯相关的新闻top3有哪些?",
# "今天北京的天气怎么样?",
"人类最大的器官是?"
]
for i in questions:
Expand Down
4 changes: 3 additions & 1 deletion examples/rag_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
file_url = "https://docs.smith.langchain.com/overview"
loader = WebBaseLoader(file_url)
docs = loader.load()
print(docs)
web_documents = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_overlap=200
chunk_size=2000, chunk_overlap=200
).split_documents(docs)
print('web_documents:', web_documents)
rag_fusion = RagFusion(web_documents)
question = "LangSmith是啥"
result = rag_fusion.run(question)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ jieba>=0.39
loguru
tqdm
pandas
tiktoken
langchain~=0.1.11
langchain-community~=0.0.27
langchain-openai~=0.0.8
Expand Down

0 comments on commit 522a450

Please sign in to comment.