Skip to content

Commit

Permalink
feat: Add metadata filter
Browse files Browse the repository at this point in the history
  • Loading branch information
rlinlen committed Dec 11, 2024
1 parent 062a1c9 commit 7e2ffa3
Show file tree
Hide file tree
Showing 10 changed files with 845 additions and 7 deletions.
1 change: 1 addition & 0 deletions backend/app/repositories/models/custom_bot_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ class BedrockKnowledgeBaseModel(BaseModel):
web_crawling_filters: WebCrawlingFiltersModel = WebCrawlingFiltersModel(
exclude_patterns=[], include_patterns=[]
)
kb_metadata_filter: dict | None = None
2 changes: 2 additions & 0 deletions backend/app/routes/schemas/bot_kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class BedrockKnowledgeBaseInput(BaseSchema):
web_crawling_filters: WebCrawlingFilters = WebCrawlingFilters(
exclude_patterns=[], include_patterns=[]
)
kb_metadata_filter: dict | None


class BedrockKnowledgeBaseOutput(BaseSchema):
Expand All @@ -119,3 +120,4 @@ class BedrockKnowledgeBaseOutput(BaseSchema):
web_crawling_filters: WebCrawlingFilters = WebCrawlingFilters(
exclude_patterns=[], include_patterns=[]
)
kb_metadata_filter: dict | None
33 changes: 27 additions & 6 deletions backend/app/vector_search.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from decimal import Decimal
import json
import logging
from typing import TypedDict
from urllib.parse import urlparse
Expand All @@ -20,7 +22,6 @@
logger = logging.getLogger(__name__)
agent_client = get_bedrock_agent_client()


class SearchResult(TypedDict):
bot_id: str
content: str
Expand Down Expand Up @@ -70,16 +71,36 @@ def _bedrock_knowledge_base_search(bot: BotModel, query: str) -> list[SearchResu

limit = bot.bedrock_knowledge_base.search_params.max_results
knowledge_base_id = bot.bedrock_knowledge_base.knowledge_base_id

kb_metadata_filter = bot.bedrock_knowledge_base.kb_metadata_filter

# bedrock doesn't take decimals
def convert_decimals(obj):
if isinstance(obj, list):
return [convert_decimals(i) for i in obj]
elif isinstance(obj, dict):
return {k: convert_decimals(v) for k, v in obj.items()}
elif isinstance(obj, Decimal):
# Convert to int if it's a whole number
if obj.as_tuple().exponent >= 0:
return int(obj)
# Convert to float if it has decimal places
return float(obj)
return obj

vector_search_configuration = {
"numberOfResults": limit,
"overrideSearchType": search_type
}
if kb_metadata_filter:
converted_kb_metadata_filter = convert_decimals(kb_metadata_filter)
vector_search_configuration["filter"] = converted_kb_metadata_filter

try:
response = agent_client.retrieve(
knowledgeBaseId=knowledge_base_id,
retrievalQuery={"text": query},
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": limit,
"overrideSearchType": search_type,
}
"vectorSearchConfiguration": vector_search_configuration
},
)

Expand Down
Loading

0 comments on commit 7e2ffa3

Please sign in to comment.