diff --git a/backend/app/bedrock.py b/backend/app/bedrock.py
index cdfbb5f6..11689101 100644
--- a/backend/app/bedrock.py
+++ b/backend/app/bedrock.py
@@ -9,6 +9,7 @@
 from app.repositories.models.conversation import (
     SimpleMessageModel,
     ContentModel,
+    is_nova_model,
 )
 from app.repositories.models.custom_bot import GenerationParamsModel
 from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
@@ -22,6 +23,7 @@
     ContentBlockTypeDef,
     GuardrailConverseContentBlockTypeDef,
     InferenceConfigurationTypeDef,
+    SystemContentBlockTypeDef,
 )
 from mypy_boto3_bedrock_runtime.literals import ConversationRoleType
 
@@ -46,11 +48,6 @@ def _is_conversation_role(role: str) -> TypeGuard[ConversationRoleType]:
     return role in ["user", "assistant"]
 
 
-def _is_nova_model(model: type_model_name) -> bool:
-    """Check if the model is an Amazon Nova model"""
-    return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]
-
-
 def _prepare_nova_model_params(
     model: type_model_name, generation_params: Optional[GenerationParamsModel] = None
 ) -> Tuple[InferenceConfigurationTypeDef, Dict[str, Any]]:
@@ -131,14 +128,24 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
     ]
 
     # Prepare model-specific parameters
-    if _is_nova_model(model):
+    if is_nova_model(model):
         # Special handling for Nova models
         inference_config, additional_model_request_fields = _prepare_nova_model_params(
             model, generation_params
         )
+        system_prompts: list[SystemContentBlockTypeDef] = (
+            [
+                {
+                    "text": "\n\n".join(instructions),
+                }
+            ]
+            if len(instructions) > 0
+            else []
+        )
+
     else:
         # Standard handling for non-Nova models
-        inference_config = {
+        inference_config: InferenceConfigurationTypeDef = {
             "maxTokens": (
                 generation_params.max_tokens
                 if generation_params
@@ -167,17 +174,20 @@ def process_content(c: ContentModel, role: str) -> list[ContentBlockTypeDef]:
                 else DEFAULT_GENERATION_CONFIG["top_k"]
             )
         }
+        system_prompts: list[SystemContentBlockTypeDef] = [
+            {
+                "text": instruction,
+            }
+            for instruction in instructions
+            if len(instruction) > 0
+        ]
+
 
     # Construct the base arguments
     args: ConverseStreamRequestRequestTypeDef = {
         "inferenceConfig": inference_config,
         "modelId": get_model_id(model),
         "messages": arg_messages,
-        "system": [
-            {"text": instruction}
-            for instruction in instructions
-            if len(instruction) > 0
-        ],
+        "system": system_prompts,
         "additionalModelRequestFields": additional_model_request_fields,
     }
 
diff --git a/backend/app/repositories/models/conversation.py b/backend/app/repositories/models/conversation.py
index 5627f7df..7ebfc2fa 100644
--- a/backend/app/repositories/models/conversation.py
+++ b/backend/app/repositories/models/conversation.py
@@ -37,6 +37,11 @@
 from pydantic import BaseModel, Discriminator, Field, JsonValue, field_validator
 
 
+def is_nova_model(model: type_model_name) -> bool:
+    """Check if the model is an Amazon Nova model"""
+    return model in ["amazon-nova-pro", "amazon-nova-lite", "amazon-nova-micro"]
+
+
 class TextContentModel(BaseModel):
     content_type: Literal["text"]
     body: str = Field(