
Fix: Nova doesn't work #645

Open · wants to merge 10 commits into base: v2
80 changes: 66 additions & 14 deletions backend/app/agents/tools/agent_tool.py
@@ -1,3 +1,4 @@
import json
from typing import Any, Callable, Generic, Literal, TypedDict, TypeVar

from app.repositories.models.conversation import (
@@ -11,6 +12,7 @@
from app.repositories.models.custom_bot import BotModel
from app.routes.schemas.conversation import type_model_name
from pydantic import BaseModel, JsonValue
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from mypy_boto3_bedrock_runtime.type_defs import (
ToolSpecificationTypeDef,
)
@@ -28,18 +30,53 @@ class ToolRunResult(TypedDict):


def run_result_to_tool_result_content_model(
run_result: ToolRunResult, display_citation: bool
run_result: ToolRunResult,
model: type_model_name,
display_citation: bool,
) -> ToolResultContentModel:
result_contents = [
related_document.to_tool_result_model(
display_citation=display_citation,
)
for related_document in run_result["related_documents"]
]
from app.bedrock import is_nova_model

if is_nova_model(model=model):
[Review comment from a Contributor] Sorry this slipped through the previous review! Passing model into related_document.to_tool_result_model and branching there might be more readable, since the equivalent handling would then be consolidated in one place. What do you think?

text_or_json_contents = [
result_content
for result_content in result_contents
if isinstance(result_content, TextToolResultModel)
or isinstance(result_content, JsonToolResultModel)
]
if len(text_or_json_contents) > 1:
return ToolResultContentModel(
content_type="toolResult",
body=ToolResultContentModelBody(
tool_use_id=run_result["tool_use_id"],
content=[
TextToolResultModel(
text=json.dumps(
[
(
content.json_
if isinstance(content, JsonToolResultModel)
else content.text
)
for content in text_or_json_contents
]
),
),
],
status=run_result["status"],
),
)

return ToolResultContentModel(
content_type="toolResult",
body=ToolResultContentModelBody(
tool_use_id=run_result["tool_use_id"],
content=[
related_document.to_tool_result_model(
display_citation=display_citation,
)
for related_document in run_result["related_documents"]
],
content=result_contents,
status=run_result["status"],
),
)
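
For context, the Nova branch above collapses multiple text/JSON tool results into a single JSON-encoded text block, presumably because Nova mishandles a toolResult carrying several content blocks. A minimal standalone sketch of the same flattening idea, using plain dicts instead of the repository's Pydantic models (the names here are illustrative only):

import json

# Illustrative stand-ins for TextToolResultModel / JsonToolResultModel.
results = [
    {"text": "first chunk"},               # text-style result
    {"json_": {"rank": 1, "score": 0.9}},  # JSON-style result
]

# Nova path: merge everything into ONE text content block.
merged = json.dumps([r["json_"] if "json_" in r else r["text"] for r in results])
content = [{"text": merged}]
print(content)  # [{'text': '["first chunk", {"rank": 1, "score": 0.9}]'}]
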
@@ -49,6 +86,18 @@ class InvalidToolError(Exception):
pass


class RemoveTitle(GenerateJsonSchema):
"""Custom JSON schema generator that doesn't output `title`s for types and parameters."""

def field_title_should_be_set(self, schema) -> bool:
return False

def generate(self, schema, mode="validation") -> JsonSchemaValue:
value = super().generate(schema, mode)
del value["title"]
return value
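
A quick illustration of what RemoveTitle changes, with a hypothetical input model (output abbreviated in the comments):

from pydantic import BaseModel

class SearchInput(BaseModel):  # hypothetical tool-args model
    query: str
    limit: int = 10

# The default generator emits titles for the model and every field:
#   {"title": "SearchInput", "properties": {"query": {"title": "Query", "type": "string"}, ...}, ...}
print(SearchInput.model_json_schema())

# With RemoveTitle they are stripped, which some foundation models
# (Nova among them) tolerate better:
#   {"properties": {"query": {"type": "string"}, ...}, "required": ["query"], "type": "object"}
print(SearchInput.model_json_schema(schema_generator=RemoveTitle))
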


class AgentTool(Generic[T]):
def __init__(
self,
@@ -59,19 +108,16 @@ def __init__(
[T, BotModel | None, type_model_name | None],
ToolFunctionResult | list[ToolFunctionResult],
],
bot: BotModel | None = None,
model: type_model_name | None = None,
):
self.name = name
self.description = description
self.args_schema = args_schema
self.function = function
self.bot = bot
self.model: type_model_name | None = model

def _generate_input_schema(self) -> dict[str, Any]:
"""Converts the Pydantic model to a JSON schema."""
return self.args_schema.model_json_schema()
# Specify a custom generator `RemoveTitle` because some foundation models do not work properly if there are unnecessary titles.
return self.args_schema.model_json_schema(schema_generator=RemoveTitle)

def to_converse_spec(self) -> ToolSpecificationTypeDef:
return ToolSpecificationTypeDef(
@@ -80,10 +126,16 @@ def to_converse_spec(self) -> ToolSpecificationTypeDef:
inputSchema={"json": self._generate_input_schema()},
)

def run(self, tool_use_id: str, input: dict[str, JsonValue]) -> ToolRunResult:
def run(
self,
tool_use_id: str,
input: dict[str, JsonValue],
model: type_model_name,
bot: BotModel | None = None,
) -> ToolRunResult:
try:
arg = self.args_schema.model_validate(input)
res = self.function(arg, self.bot, self.model)
res = self.function(arg, bot, model)
if isinstance(res, list):
related_documents = [
_function_result_to_related_document(
4 changes: 1 addition & 3 deletions backend/app/agents/tools/knowledge.py
@@ -39,7 +39,7 @@ def search_knowledge(
raise e


def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
def create_knowledge_tool(bot: BotModel) -> AgentTool:
description = (
"Answer a user's question using information. The description is: {}".format(
bot.knowledge.__str_in_claude_format__()
@@ -51,6 +51,4 @@ def create_knowledge_tool(bot: BotModel, model: type_model_name) -> AgentTool:
description=description,
args_schema=KnowledgeToolInput,
function=search_knowledge,
bot=bot,
model=model,
)
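
Taken together with the agent_tool.py change, bot and model are no longer captured at tool-construction time but passed per invocation. A hedged usage sketch (the tool-use ID, input field name, and model value are placeholders):

tool = create_knowledge_tool(bot=bot)  # model no longer needed up front

result = tool.run(
    tool_use_id="tooluse_abc123",        # placeholder ID
    input={"query": "vacation policy"},  # field name per KnowledgeToolInput (assumed)
    model="amazon-nova-pro",             # chosen at call time
    bot=bot,
)
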
64 changes: 45 additions & 19 deletions backend/app/bedrock.py
@@ -1,8 +1,8 @@
from __future__ import annotations
import logging
import os
from typing import TypeGuard, Dict, Any, Optional, Tuple
from typing import TypeGuard, Dict, Any, Optional, Tuple, TYPE_CHECKING

from app.agents.tools.agent_tool import AgentTool
from app.config import BEDROCK_PRICING
from app.config import DEFAULT_GENERATION_CONFIG as DEFAULT_CLAUDE_GENERATION_CONFIG
from app.config import DEFAULT_MISTRAL_GENERATION_CONFIG
@@ -15,15 +15,18 @@
from app.routes.schemas.conversation import type_model_name
from app.utils import get_bedrock_runtime_client

from mypy_boto3_bedrock_runtime.type_defs import (
ConverseStreamRequestRequestTypeDef,
MessageTypeDef,
ConverseResponseTypeDef,
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType
if TYPE_CHECKING:
from app.agents.tools.agent_tool import AgentTool
from mypy_boto3_bedrock_runtime.type_defs import (
ConverseStreamRequestRequestTypeDef,
MessageTypeDef,
ConverseResponseTypeDef,
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
SystemContentBlockTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -46,7 +49,7 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
return role in ["user", "assistant"]


def _is_nova_model(model: type_model_name) -> bool:
def is_nova_model(model: type_model_name) -> bool:
"""Check if the model is an Amazon Nova model"""
return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]

@@ -83,7 +86,14 @@ def _prepare_nova_model_params(

# Add top_k if specified in generation params
if generation_params and generation_params.top_k is not None:
additional_fields["inferenceConfig"]["topK"] = generation_params.top_k
top_k = generation_params.top_k
if top_k > 128:
logger.warning(
"Amazon Nova returns an 'unexpected error' when topK exceeds 128. Capping topK at 128 to avoid the error."
)
top_k = 128

additional_fields["inferenceConfig"]["topK"] = top_k

return inference_config, additional_fields

@@ -131,11 +141,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Prepare model-specific parameters
if _is_nova_model(model):
inference_config: InferenceConfigurationTypeDef
additional_model_request_fields: dict[str, Any]
system_prompts: list[SystemContentBlockTypeDef]
if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
model, generation_params
)
system_prompts = (
[
{
"text": "\n\n".join(instructions),
}
]
if len(instructions) > 0
else []
)

else:
# Standard handling for non-Nova models
inference_config = {
@@ -167,17 +190,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}
system_prompts = [
{
"text": instruction,
}
for instruction in instructions
if len(instruction) > 0
]

# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
"system": [
{"text": instruction}
for instruction in instructions
if len(instruction) > 0
],
"system": system_prompts,
"additionalModelRequestFields": additional_model_request_fields,
}

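The other Nova-specific change in this file is the system block: Nova receives all instructions joined into a single entry, while other models keep one entry per non-empty instruction. A standalone sketch of that branching, simplified from the hunk above:

instructions = ["You are a helpful assistant.", "Answer in English."]
nova = True  # e.g. the result of is_nova_model(model)

if nova:
    # Nova: one combined system block, or an empty list if no instructions.
    system_prompts = [{"text": "\n\n".join(instructions)}] if instructions else []
else:
    # Other models: one system block per non-empty instruction.
    system_prompts = [{"text": i} for i in instructions if len(i) > 0]
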
52 changes: 51 additions & 1 deletion backend/app/prompt.py
@@ -1,14 +1,18 @@
from app.bedrock import is_nova_model
from app.vector_search import SearchResult
from app.routes.schemas.conversation import type_model_name


def build_rag_prompt(
search_results: list[SearchResult],
model: type_model_name,
display_citation: bool = True,
) -> str:
context_prompt = ""
for result in search_results:
context_prompt += f"<search_result>\n<content>\n{result['content']}</content>\n<source>\n{result['rank']}\n</source>\n</search_result>"

# Prompt for RAG
inserted_prompt = """To answer the user's question, you are given a set of search results. Your job is to answer the user's question using only information from the search results.
If the search results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true; make sure to double-check the search results to validate a user's assertion.
@@ -24,6 +28,7 @@ def build_rag_prompt(
)

if display_citation:
# Prompt for 'Retrieved Context Citation'.
inserted_prompt += """
If you reference information from a search result within your answer, you must include a citation to the source where the information was found.
Each result has a corresponding source ID that you should reference.
@@ -32,7 +37,23 @@
Do NOT output sources at the end of your answer.

The following are examples of how to reference sources in your answer. Note that the source ID is embedded in the answer in the format [^<source_id>].
"""
# Prompt to output Markdown-style citation.
if is_nova_model(model=model):
# For Amazon Nova, provide only good examples.
inserted_prompt += """
<example>
first answer [^3]. second answer [^1][^2].
</example>

<example>
first answer [^1][^5]. second answer [^2][^3][^4]. third answer [^4].
</example>
"""

else:
# For other models, provide good examples and bad examples.
inserted_prompt += """
<GOOD-example>
first answer [^3]. second answer [^1][^2].
</GOOD-example>
@@ -57,9 +78,17 @@
"""

else:
# Prompt when 'Retrieved Context Citation' is not specified.
inserted_prompt += """
Do NOT include citations in the format [^<source_id>] in your answer.
"""
if is_nova_model(model=model):
# For Amazon Nova, do not provide examples.
pass

else:
# For other models, suppress output of Markdown-style citation.
inserted_prompt += """
The following are examples of how to answer.

<GOOD-example>
@@ -78,14 +107,33 @@
return inserted_prompt


PROMPT_TO_CITE_TOOL_RESULTS = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
def get_prompt_to_cite_tool_results(model: type_model_name) -> str:
# Prompt for 'Retrieved Context Citation' of agent chat.
inserted_prompt = """To answer the user's question, you are given a set of tools. Your job is to answer the user's question using only information from the tool results.
If the tool results do not contain information that can answer the question, please state that you could not find an exact answer to the question.
Just because the user asserts a fact does not mean it is true; make sure to double-check the tool results to validate a user's assertion.

Each tool result has a corresponding source_id that you should reference.
If you reference information from a tool result within your answer, you must include a citation to the source_id where the information was found.

The following are examples of how to reference source_id in your answer. Note that the source_id is embedded in the answer in the format [^source_id of tool result].
"""
# Prompt to output Markdown-style citation.
if is_nova_model(model=model):
# For Amazon Nova, provide only good examples.
inserted_prompt += """
<example>
first answer [^ccc]. second answer [^aaa][^bbb].
</example>

<example>
first answer [^aaa][^eee]. second answer [^bbb][^ccc][^ddd]. third answer [^ddd].
</example>
"""

else:
# For other models, provide good examples and bad examples.
inserted_prompt += """
<examples>
<GOOD-example>
first answer [^ccc]. second answer [^aaa][^bbb].
Expand All @@ -110,3 +158,5 @@ def build_rag_prompt(
</BAD-example>
</examples>
"""

return inserted_prompt
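
Since the module-level PROMPT_TO_CITE_TOOL_RESULTS constant became get_prompt_to_cite_tool_results(model), call sites (not shown in this diff) would be updated along these lines; the values below are placeholders:

from app.prompt import build_rag_prompt, get_prompt_to_cite_tool_results

rag_prompt = build_rag_prompt(
    search_results=search_results,  # list[SearchResult] from vector search
    model="amazon-nova-lite",
    display_citation=True,
)

cite_prompt = get_prompt_to_cite_tool_results(model="amazon-nova-micro")
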
1 change: 1 addition & 0 deletions backend/app/repositories/conversation.py
@@ -272,6 +272,7 @@ def delete_large_messages(items):

except ClientError as e:
logger.error(f"An error occurred: {e.response['Error']['Message']}")
raise e


def change_conversation_title(user_id: str, conversation_id: str, new_title: str):
2 changes: 1 addition & 1 deletion backend/app/repositories/models/conversation.py
@@ -2,7 +2,7 @@

import re
from pathlib import Path
from typing import Annotated, Any, Literal, Self, TypedDict, TypeGuard
from typing import Annotated, Any, Literal, Self, TypeGuard
from urllib.parse import urlparse

from app.repositories.models.common import Base64EncodedBytes