Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix: Nova doesn't work #645

Open
wants to merge 10 commits into
base: v2
Choose a base branch
from
76 changes: 62 additions & 14 deletions backend/app/agents/tools/agent_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Any, Callable, Generic, Literal, TypedDict, TypeVar

from app.repositories.models.conversation import (
Expand All @@ -7,10 +8,12 @@
RelatedDocumentModel,
ToolResultContentModel,
ToolResultContentModelBody,
is_nova_model,
)
from app.repositories.models.custom_bot import BotModel
from app.routes.schemas.conversation import type_model_name
from pydantic import BaseModel, JsonValue
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from mypy_boto3_bedrock_runtime.type_defs import (
ToolSpecificationTypeDef,
)
Expand All @@ -28,18 +31,51 @@ class ToolRunResult(TypedDict):


def run_result_to_tool_result_content_model(
    run_result: ToolRunResult,
    model: type_model_name,
    display_citation: bool,
) -> ToolResultContentModel:
    """Convert a tool run result into a ``ToolResultContentModel``.

    Args:
        run_result: The tool invocation outcome (tool_use_id, status, and
            the related documents produced by the tool).
        model: The model name the result will be sent to; Nova models get
            special handling (see below).
        display_citation: Forwarded to each related document when building
            its tool-result representation.

    Returns:
        A ``ToolResultContentModel`` wrapping the tool result contents.

    NOTE(review): for Nova models, multiple text/JSON contents are merged
    into a single JSON-encoded text block — presumably because Nova does
    not accept multiple content items per tool result; confirm against the
    Bedrock Converse API behavior for Nova.
    """
    result_contents = [
        related_document.to_tool_result_model(
            display_citation=display_citation,
        )
        for related_document in run_result["related_documents"]
    ]
    if is_nova_model(model=model):
        # Only text and JSON contents can be merged into a single text block.
        text_or_json_contents = [
            result_content
            for result_content in result_contents
            if isinstance(result_content, (TextToolResultModel, JsonToolResultModel))
        ]
        if len(text_or_json_contents) > 1:
            # Collapse all text/JSON contents into one JSON-encoded text item.
            return ToolResultContentModel(
                content_type="toolResult",
                body=ToolResultContentModelBody(
                    tool_use_id=run_result["tool_use_id"],
                    content=[
                        TextToolResultModel(
                            text=json.dumps(
                                [
                                    (
                                        content.json_
                                        if isinstance(content, JsonToolResultModel)
                                        else content.text
                                    )
                                    for content in text_or_json_contents
                                ]
                            ),
                        ),
                    ],
                    status=run_result["status"],
                ),
            )

    return ToolResultContentModel(
        content_type="toolResult",
        body=ToolResultContentModelBody(
            tool_use_id=run_result["tool_use_id"],
            content=result_contents,
            status=run_result["status"],
        ),
    )
Expand All @@ -49,6 +85,16 @@ class InvalidToolError(Exception):
pass


class RemoveTitle(GenerateJsonSchema):
    """JSON-schema generator that suppresses pydantic's auto-generated titles.

    Some models (e.g. Amazon Nova) are sensitive to extra ``title`` keys in
    tool input schemas, so both per-field titles and the top-level title are
    stripped from the generated schema.
    """

    def field_title_should_be_set(self, schema) -> bool:
        # Never emit a "title" entry for individual fields.
        return False

    def generate(self, schema, mode="validation") -> JsonSchemaValue:
        value = super().generate(schema, mode)
        # Drop the top-level title as well; use pop() with a default so a
        # schema that happens to lack a title does not raise KeyError.
        value.pop("title", None)
        return value


class AgentTool(Generic[T]):
def __init__(
self,
Expand All @@ -59,19 +105,15 @@ def __init__(
[T, BotModel | None, type_model_name | None],
ToolFunctionResult | list[ToolFunctionResult],
],
bot: BotModel | None = None,
model: type_model_name | None = None,
):
self.name = name
self.description = description
self.args_schema = args_schema
self.function = function
self.bot = bot
self.model: type_model_name | None = model

def _generate_input_schema(self) -> dict[str, Any]:
"""Converts the Pydantic model to a JSON schema."""
return self.args_schema.model_json_schema()
return self.args_schema.model_json_schema(schema_generator=RemoveTitle)

def to_converse_spec(self) -> ToolSpecificationTypeDef:
return ToolSpecificationTypeDef(
Expand All @@ -80,10 +122,16 @@ def to_converse_spec(self) -> ToolSpecificationTypeDef:
inputSchema={"json": self._generate_input_schema()},
)

def run(self, tool_use_id: str, input: dict[str, JsonValue]) -> ToolRunResult:
def run(
self,
tool_use_id: str,
input: dict[str, JsonValue],
model: type_model_name,
bot: BotModel | None = None,
) -> ToolRunResult:
try:
arg = self.args_schema.model_validate(input)
res = self.function(arg, self.bot, self.model)
res = self.function(arg, bot, model)
if isinstance(res, list):
related_documents = [
_function_result_to_related_document(
Expand Down
4 changes: 1 addition & 3 deletions backend/app/agents/tools/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def search_knowledge(
raise e


def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
def create_knowledge_tool(bot: BotModel) -> AgentTool:
description = (
"Answer a user's question using information. The description is: {}".format(
bot.knowledge.__str_in_claude_format__()
Expand All @@ -51,6 +51,4 @@ def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
description=description,
args_schema=KnowledgeToolInput,
function=search_knowledge,
bot=bot,
model=model,
)
41 changes: 29 additions & 12 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.repositories.models.conversation import (
SimpleMessageModel,
ContentModel,
is_nova_model,
)
from app.repositories.models.custom_bot import GenerationParamsModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
Expand All @@ -22,6 +23,7 @@
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
SystemContentBlockTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType

Expand All @@ -46,11 +48,6 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
return role in ["user", "assistant"]


def _is_nova_model(model: type_model_name) -> bool:
"""Check if the model is an Amazon Nova model"""
return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]


def _prepare_nova_model_params(
model: type_model_name, generation_params: Optional[GenerationParamsModel] = None
) -> Tuple[InferenceConfigurationTypeDef, Dict[str, Any]]:
Expand Down Expand Up @@ -83,7 +80,11 @@ def _prepare_nova_model_params(

# Add top_k if specified in generation params
if generation_params and generation_params.top_k is not None:
additional_fields["inferenceConfig"]["topK"] = generation_params.top_k
top_k = generation_params.top_k
if top_k > 128:
statefb marked this conversation as resolved.
Show resolved Hide resolved
top_k = 128

additional_fields["inferenceConfig"]["topK"] = top_k

return inference_config, additional_fields

Expand Down Expand Up @@ -131,11 +132,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Prepare model-specific parameters
if _is_nova_model(model):
inference_config: InferenceConfigurationTypeDef
additional_model_request_fields: dict[str, Any]
system_prompts: list[SystemContentBlockTypeDef]
if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
model, generation_params
)
system_prompts = (
[
{
"text": "\n\n".join(instructions),
}
]
if len(instructions) > 0
else []
)

else:
# Standard handling for non-Nova models
inference_config = {
Expand Down Expand Up @@ -167,17 +181,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}
system_prompts = [
{
"text": instruction,
}
for instruction in instructions
if len(instruction) > 0
]

# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
"system": [
{"text": instruction}
for instruction in instructions
if len(instruction) > 0
],
"system": system_prompts,
"additionalModelRequestFields": additional_model_request_fields,
}

Expand Down
40 changes: 39 additions & 1 deletion backend/app/prompt.py
statefb marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from app.vector_search import SearchResult
from app.routes.schemas.conversation import type_model_name
from app.repositories.models.conversation import is_nova_model


def build_rag_prompt(
search_results: list[SearchResult],
model: type_model_name,
display_citation: bool = True,
) -> str:
context_prompt = ""
Expand Down Expand Up @@ -32,7 +35,20 @@ def build_rag_prompt(
Do NOT outputs sources at the end of your answer.

Followings are examples of how to reference sources in your answer. Note that the source ID is embedded in the answer in the format [^<source_id>].
"""
if is_nova_model(model=model):
inserted_prompt += """
<example>
first answer [^3]. second answer [^1][^2].
</example>

<example>
first answer [^1][^5]. second answer [^2][^3][^4]. third answer [^4].
</example>
"""

else:
inserted_prompt += """
<GOOD-example>
first answer [^3]. second answer [^1][^2].
</GOOD-example>
Expand All @@ -59,7 +75,12 @@ def build_rag_prompt(
else:
inserted_prompt += """
Do NOT include citations in the format [^<source_id>] in your answer.
"""
if is_nova_model(model=model):
pass

else:
inserted_prompt += """
Followings are examples of how to answer.

<GOOD-example>
Expand All @@ -78,14 +99,29 @@ def build_rag_prompt(
return inserted_prompt


PROMPT_TO_CITE_TOOL_RESULTS = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
def get_prompt_to_cite_tool_results(model: type_model_name) -> str:
inserted_prompt = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
If the tool results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true, make sure to double check the tool results to validate a user's assertion.

Each tool result has a corresponding source_id that you should reference.
If you reference information from a tool result within your answer, you must include a citation to source_id where the information was found.

Followings are examples of how to reference source_id in your answer. Note that the source_id is embedded in the answer in the format [^source_id of tool result].
"""
if is_nova_model(model=model):
inserted_prompt += """
<example>
first answer [^ccc]. second answer [^aaa][^bbb].
</example>

<example>
first answer [^aaa][^eee]. second answer [^bbb][^ccc][^ddd]. third answer [^ddd].
</example>
"""

else:
inserted_prompt += """
<examples>
<GOOD-example>
first answer [^ccc]. second answer [^aaa][^bbb].
Expand All @@ -110,3 +146,5 @@ def build_rag_prompt(
</BAD-example>
</examples>
"""

return inserted_prompt
1 change: 1 addition & 0 deletions backend/app/repositories/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def delete_large_messages(items):

except ClientError as e:
logger.error(f"An error occurred: {e.response['Error']['Message']}")
raise e


def change_conversation_title(user_id: str, conversation_id: str, new_title: str):
Expand Down
5 changes: 5 additions & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
from pydantic import BaseModel, Discriminator, Field, JsonValue, field_validator


def is_nova_model(model: type_model_name) -> bool:
    """Return True when *model* names one of the Amazon Nova variants."""
    return model in ("amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro")


class TextContentModel(BaseModel):
content_type: Literal["text"]
body: str = Field(
Expand Down
14 changes: 11 additions & 3 deletions backend/app/usecases/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from app.agents.tools.knowledge import create_knowledge_tool
from app.agents.utils import get_tool_by_name
from app.bedrock import call_converse_api, compose_args_for_converse_api
from app.prompt import PROMPT_TO_CITE_TOOL_RESULTS, build_rag_prompt
from app.prompt import build_rag_prompt, get_prompt_to_cite_tool_results
from app.repositories.conversation import (
RecordNotFoundError,
find_conversation_by_id,
Expand Down Expand Up @@ -260,11 +260,15 @@ def chat(
if bot.is_agent_enabled():
if bot.has_knowledge():
# Add knowledge tool
knowledge_tool = create_knowledge_tool(bot, chat_input.message.model)
knowledge_tool = create_knowledge_tool(bot=bot)
tools[knowledge_tool.name] = knowledge_tool

if display_citation:
instructions.append(PROMPT_TO_CITE_TOOL_RESULTS)
instructions.append(
get_prompt_to_cite_tool_results(
model=chat_input.message.model,
)
)

elif bot.has_knowledge():
# Fetch most related documents from vector store
Expand Down Expand Up @@ -306,6 +310,7 @@ def chat(
instructions.append(
build_rag_prompt(
search_results=search_results,
model=chat_input.message.model,
display_citation=display_citation,
)
)
Expand Down Expand Up @@ -432,6 +437,8 @@ def chat(
run_result = tool.run(
tool_use_id=content.body.tool_use_id,
input=content.body.input,
model=chat_input.message.model,
bot=bot,
)
run_results.append(run_result)

Expand All @@ -446,6 +453,7 @@ def chat(
content=[
run_result_to_tool_result_content_model(
run_result=result,
model=chat_input.message.model,
display_citation=display_citation,
)
for result in run_results
Expand Down
Loading
Loading