Skip to content

Commit

Permalink
If the model is Amazon Nova, combine multiple system prompts into one text.
Browse files Browse the repository at this point in the history
  • Loading branch information
Yukinobu-Mine committed Dec 12, 2024
1 parent 061663e commit 8084a30
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
34 changes: 22 additions & 12 deletions backend/app/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from app.repositories.models.conversation import (
SimpleMessageModel,
ContentModel,
is_nova_model,
)
from app.repositories.models.custom_bot import GenerationParamsModel
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
Expand All @@ -22,6 +23,7 @@
ContentBlockTypeDef,
GuardrailConverseContentBlockTypeDef,
InferenceConfigurationTypeDef,
SystemContentBlockTypeDef,
)
from mypy_boto3_bedrock_runtime.literals import ConversationRoleType

Expand All @@ -46,11 +48,6 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
return role in ["user", "assistant"]


def _is_nova_model(model: type_model_name) -> bool:
    """Return True when *model* belongs to the Amazon Nova family."""
    nova_models = ("amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro")
    return model in nova_models


def _prepare_nova_model_params(
model: type_model_name, generation_params: Optional[GenerationParamsModel] = None
) -> Tuple[InferenceConfigurationTypeDef, Dict[str, Any]]:
Expand Down Expand Up @@ -131,14 +128,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
]

# Prepare model-specific parameters
if _is_nova_model(model):
if is_nova_model(model):
# Special handling for Nova models
inference_config, additional_model_request_fields = _prepare_nova_model_params(
model, generation_params
)
system_prompts: list[SystemContentBlockTypeDef] = (
[
{
"text": "\n\n".join(instructions),
}
]
if len(instructions) > 0
else []
)

else:
# Standard handling for non-Nova models
inference_config = {
inference_config: InferenceConfigurationTypeDef = {
"maxTokens": (
generation_params.max_tokens
if generation_params
Expand Down Expand Up @@ -167,17 +174,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
else DEFAULT_GENERATION_CONFIG["top_k"]
)
}
system_prompts: list[SystemContentBlockTypeDef] = [
{
"text": instruction,
}
for instruction in instructions
if len(instruction) > 0
]

# Construct the base arguments
args: ConverseStreamRequestRequestTypeDef = {
"inferenceConfig": inference_config,
"modelId": get_model_id(model),
"messages": arg_messages,
"system": [
{"text": instruction}
for instruction in instructions
if len(instruction) > 0
],
"system": system_prompts,
"additionalModelRequestFields": additional_model_request_fields,
}

Expand Down
5 changes: 5 additions & 0 deletions backend/app/repositories/models/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
from pydantic import BaseModel, Discriminator, Field, JsonValue, field_validator


def is_nova_model(model: type_model_name) -> bool:
    """Check whether the given model name is one of Amazon's Nova models."""
    return model in {"amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"}


class TextContentModel(BaseModel):
content_type: Literal["text"]
body: str = Field(
Expand Down

0 comments on commit 8084a30

Please sign in to comment.