Skip to content

Commit

Permalink
Optimization for Amazon Nova.
Browse files Browse the repository at this point in the history
- Change the system prompt when doing 'Retrieved Context Citation' with Amazon Nova.
- If the tool result has more than one element, pass it as a single text content item formatted as a JSON array.
  • Loading branch information
Yukinobu-Mine committed Dec 12, 2024
1 parent dd979e0 commit 493b012
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 33 deletions.
78 changes: 58 additions & 20 deletions backend/app/agents/tools/agent_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Callable, Generic, Literal, TypedDict, TypeVar

from app.repositories.models.conversation import (
Expand All @@ -7,6 +8,7 @@
RelatedDocumentModel,
ToolResultContentModel,
ToolResultContentModelBody,
is_nova_model,
)
from app.repositories.models.custom_bot import BotModel
from app.routes.schemas.conversation import type_model_name
Expand All @@ -29,21 +31,55 @@ class ToolRunResult(TypedDict):


def run_result_to_tool_result_content_model(
    run_result: ToolRunResult,
    model: type_model_name,
    display_citation: bool,
) -> ToolResultContentModel:
    """Convert a tool run result into a `ToolResultContentModel`.

    Args:
        run_result: The tool invocation outcome (tool_use_id, related
            documents, and status).
        model: The model the conversation is using; Nova models get
            special handling (see below).
        display_citation: Forwarded to each related document when
            building its tool-result representation.

    Returns:
        A `ToolResultContentModel` wrapping the converted contents.

    Note:
        Amazon Nova handles multiple tool-result content elements poorly,
        so when the target is a Nova model and there is more than one
        element, all JSON/text contents are merged into a single text
        content formatted as a JSON array.
    """
    result_contents = [
        related_document.to_tool_result_model(
            display_citation=display_citation,
        )
        for related_document in run_result["related_documents"]
    ]
    if is_nova_model(model=model) and len(result_contents) > 1:
        # Flatten JSON contents (as their dict payload) and text contents
        # (as their string) into one JSON-array string; other content
        # kinds (e.g. images/documents) are dropped, matching prior behavior.
        merged: list = []
        for result_content in result_contents:
            if isinstance(result_content, JsonToolResultModel):
                merged.append(result_content.json_)
            elif isinstance(result_content, TextToolResultModel):
                merged.append(result_content.text)
        result_contents = [TextToolResultModel(text=json.dumps(merged))]

    return ToolResultContentModel(
        content_type="toolResult",
        body=ToolResultContentModelBody(
            tool_use_id=run_result["tool_use_id"],
            content=result_contents,
            status=run_result["status"],
        ),
    )


class InvalidToolError(Exception):
Expand All @@ -70,15 +106,11 @@ def __init__(
[T, BotModel | None, type_model_name | None],
ToolFunctionResult | list[ToolFunctionResult],
],
bot: BotModel | None = None,
model: type_model_name | None = None,
):
self.name = name
self.description = description
self.args_schema = args_schema
self.function = function
self.bot = bot
self.model: type_model_name | None = model

def _generate_input_schema(self) -> dict[str, Any]:
"""Converts the Pydantic model to a JSON schema."""
Expand All @@ -91,10 +123,16 @@ def to_converse_spec(self) -> ToolSpecificationTypeDef:
inputSchema={"json": self._generate_input_schema()},
)

def run(self, tool_use_id: str, input: dict[str, JsonValue]) -> ToolRunResult:
def run(
self,
tool_use_id: str,
input: dict[str, JsonValue],
model: type_model_name,
bot: BotModel | None = None,
) -> ToolRunResult:
try:
arg = self.args_schema.model_validate(input)
res = self.function(arg, self.bot, self.model)
res = self.function(arg, bot, model)
if isinstance(res, list):
related_documents = [
_function_result_to_related_document(
Expand Down
5 changes: 1 addition & 4 deletions backend/app/agents/tools/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def search_knowledge(
logger.error(f"Failed to run AnswerWithKnowledgeTool: {e}")
raise e


def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
def create_knowledge_tool(bot: BotModel) -> AgentTool:
description = (
"Answer a user's question using information. The description is: {}".format(
bot.knowledge.__str_in_claude_format__()
Expand All @@ -51,6 +50,4 @@ def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
description=description,
args_schema=KnowledgeToolInput,
function=search_knowledge,
bot=bot,
model=model,
)
40 changes: 39 additions & 1 deletion backend/app/prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from app.vector_search import SearchResult
from app.routes.schemas.conversation import type_model_name
from app.repositories.models.conversation import is_nova_model


def build_rag_prompt(
search_results: list[SearchResult],
model: type_model_name,
display_citation: bool = True,
) -> str:
context_prompt = ""
Expand Down Expand Up @@ -32,7 +35,20 @@ def build_rag_prompt(
Do NOT outputs sources at the end of your answer.
Followings are examples of how to reference sources in your answer. Note that the source ID is embedded in the answer in the format [^<source_id>].
"""
if is_nova_model(model=model):
inserted_prompt += """
<example>
first answer [^3]. second answer [^1][^2].
</example>
<example>
first answer [^1][^5]. second answer [^2][^3][^4]. third answer [^4].
</example>
"""

else:
inserted_prompt += """
<GOOD-example>
first answer [^3]. second answer [^1][^2].
</GOOD-example>
Expand All @@ -59,7 +75,12 @@ def build_rag_prompt(
else:
inserted_prompt += """
Do NOT include citations in the format [^<source_id>] in your answer.
"""
if is_nova_model(model=model):
pass

else:
inserted_prompt += """
Followings are examples of how to answer.
<GOOD-example>
Expand All @@ -78,14 +99,29 @@ def build_rag_prompt(
return inserted_prompt


PROMPT_TO_CITE_TOOL_RESULTS = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
def get_prompt_to_cite_tool_results(model: type_model_name) -> str:
inserted_prompt = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
If the tool results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the tool results to validate a user's assertion.
Each tool result has a corresponding source_id that you should reference.
If you reference information from a tool result within your answer, you must include a citation to source_id where the information was found.
Followings are examples of how to reference source_id in your answer. Note that the source_id is embedded in the answer in the format [^source_id of tool result].
"""
if is_nova_model(model=model):
inserted_prompt += """
<example>
first answer [^ccc]. second answer [^aaa][^bbb].
</example>
<example>
first answer [^aaa][^eee]. second answer [^bbb][^ccc][^ddd]. third answer [^ddd].
</example>
"""

else:
inserted_prompt += """
<examples>
<GOOD-example>
first answer [^ccc]. second answer [^aaa][^bbb].
Expand All @@ -110,3 +146,5 @@ def build_rag_prompt(
</BAD-example>
</examples>
"""

return inserted_prompt
14 changes: 11 additions & 3 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.agents.tools.knowledge import create_knowledge_tool
from app.agents.utils import get_tool_by_name
from app.bedrock import call_converse_api, compose_args_for_converse_api
from app.prompt import PROMPT_TO_CITE_TOOL_RESULTS, build_rag_prompt
from app.prompt import build_rag_prompt, get_prompt_to_cite_tool_results
from app.repositories.conversation import (
RecordNotFoundError,
find_conversation_by_id,
Expand Down Expand Up @@ -260,11 +260,15 @@ def chat(
if bot.is_agent_enabled():
if bot.has_knowledge():
# Add knowledge tool
knowledge_tool = create_knowledge_tool(bot, chat_input.message.model)
knowledge_tool = create_knowledge_tool(bot=bot)
tools[knowledge_tool.name] = knowledge_tool

if display_citation:
instructions.append(PROMPT_TO_CITE_TOOL_RESULTS)
instructions.append(
get_prompt_to_cite_tool_results(
model=chat_input.message.model,
)
)

elif bot.has_knowledge():
# Fetch most related documents from vector store
Expand Down Expand Up @@ -306,6 +310,7 @@ def chat(
instructions.append(
build_rag_prompt(
search_results=search_results,
model=chat_input.message.model,
display_citation=display_citation,
)
)
Expand Down Expand Up @@ -432,6 +437,8 @@ def chat(
run_result = tool.run(
tool_use_id=content.body.tool_use_id,
input=content.body.input,
model=chat_input.message.model,
bot=bot,
)
run_results.append(run_result)

Expand All @@ -446,6 +453,7 @@ def chat(
content=[
run_result_to_tool_result_content_model(
run_result=result,
model=chat_input.message.model,
display_citation=display_citation,
)
for result in run_results
Expand Down
6 changes: 5 additions & 1 deletion backend/tests/test_agent/test_tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def test_run(self):
arg3=1,
arg4=["test"],
)
result = self.tool.run(tool_use_id="dummy", input=arg.model_dump())
result = self.tool.run(
tool_use_id="dummy",
input=arg.model_dump(),
model="claude-v3.5-sonnet-v2",
)
self.assertEqual(
result["related_documents"],
[
Expand Down
6 changes: 5 additions & 1 deletion backend/tests/test_agent/test_tools/test_internet_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def test_internet_search(self):
time_limit = "d"
country = "jp-jp"
arg = InternetSearchInput(query=query, time_limit=time_limit, country=country)
response = internet_search_tool.run(tool_use_id="dummy", input=arg.model_dump())
response = internet_search_tool.run(
tool_use_id="dummy",
input=arg.model_dump(),
model="claude-v3.5-sonnet-v2",
)
self.assertIsInstance(response["related_documents"], list)
self.assertEqual(response["status"], "success")
print(response)
Expand Down
10 changes: 8 additions & 2 deletions backend/tests/test_agent/test_tools/test_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from app.agents.tools.knowledge import KnowledgeToolInput, create_knowledge_tool
from app.repositories.models.custom_bot import (
ActiveModelsModel,
AgentModel,
BotModel,
GenerationParamsModel,
Expand Down Expand Up @@ -53,10 +54,15 @@ def test_knowledge_tool(self):
conversation_quick_starters=[],
bedrock_knowledge_base=None,
bedrock_guardrails=None,
active_models=ActiveModelsModel(),
)
arg = KnowledgeToolInput(query="What are delicious Japanese dishes?")
tool = create_knowledge_tool(bot, model="claude-v3-sonnet")
response = tool.run(tool_use_id="dummy", input=arg.model_dump())
tool = create_knowledge_tool(bot=bot)
response = tool.run(
tool_use_id="dummy",
input=arg.model_dump(),
model="claude-v3.5-sonnet-v2",
)
self.assertIsInstance(response["related_documents"], list)
self.assertEqual(response["status"], "success")
print(response)
Expand Down
1 change: 1 addition & 0 deletions backend/tests/test_usecases/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ def test_insert_knowledge(self):
]
instruction = build_rag_prompt(
search_results=results,
model="claude-v3.5-sonnet-v2",
display_citation=True,
)
print(instruction)
Expand Down
9 changes: 8 additions & 1 deletion examples/agents/tools/bmi/test_bmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@

class TestBmiTool(unittest.TestCase):
    def test_bmi(self):
        """BMI tool runs with a sample height/weight and yields a string result."""
        result = bmi_tool.run(
            tool_use_id="dummy",
            input={
                "height": 170,
                "weight": 70,
            },
            model="claude-v3.5-sonnet-v2",
        )
        print(result)
        # assertIsInstance is the idiomatic check; type(x) == T rejects subclasses.
        self.assertIsInstance(result, str)

Expand Down

0 comments on commit 493b012

Please sign in to comment.