Skip to content

Commit

Permalink
pydantic validation added and test modified
Browse files Browse the repository at this point in the history
  • Loading branch information
spike-spiegel-21 committed Sep 28, 2024
1 parent f3fa21b commit fd29e1c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 23 deletions.
14 changes: 13 additions & 1 deletion mem0/configs/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Literal, Optional

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -72,3 +72,15 @@ class AzureConfig(BaseModel):
azure_deployment: str = Field(description="The name of the Azure deployment.", default=None)
azure_endpoint: str = Field(description="The endpoint URL for the Azure service.", default=None)
api_version: str = Field(description="The version of the Azure API being used.", default=None)


class CustomCategories(BaseModel):
"""
Custom categories for memory.
Args:
categories (list): The list of custom categories.
filter (str): Filter to control the category behaviour.
"""
categories: List[Dict[str, str]] = Field(...,description="List of categories with key-value pairs as strings")
filter: Optional[Literal['restrict', 'omit', 'extend']] = Field('extend', description="Optional filter to control the category display behavior")
25 changes: 13 additions & 12 deletions mem0/memory/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
import uuid
import warnings
from datetime import datetime
from typing import Any, Dict, List, Optional, Literal
from typing import Any, Dict

import pytz
from pydantic import ValidationError

from mem0.configs.base import MemoryConfig, MemoryItem
from mem0.configs.base import CustomCategories, MemoryConfig, MemoryItem
from mem0.configs.prompts import get_update_memory_messages
from mem0.memory.base import MemoryBase
from mem0.memory.setup import setup_config
from mem0.memory.storage import SQLiteManager
from mem0.memory.telemetry import capture_event
from mem0.memory.utils import get_custom_category_fact_retrieval_messages, get_fact_retrieval_messages, parse_messages
from mem0.memory.utils import (
get_custom_category_fact_retrieval_messages,
get_fact_retrieval_messages,
parse_messages,
)
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory

# Setup user config
Expand Down Expand Up @@ -67,8 +71,7 @@ def add(
metadata=None,
filters=None,
prompt=None,
custom_category: Optional[List[Dict[str, str]]] = None,
custom_category_filter: Optional[Literal['extend', 'restrict', 'omit']] = None
custom_categories=None
):
"""
Create a new memory.
Expand Down Expand Up @@ -102,11 +105,8 @@ def add(
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

if custom_category_filter is not None and custom_category is None:
raise ValueError("custom_category_filter can only be used when custom_category is provided")

with concurrent.futures.ThreadPoolExecutor() as executor:
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, custom_category, custom_category_filter)
future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, custom_categories)
future2 = executor.submit(self._add_to_graph, messages, filters)

concurrent.futures.wait([future1, future2])
Expand All @@ -129,14 +129,15 @@ def add(
)
return {"message": "ok"}

def _add_to_vector_store(self, messages, metadata, filters, custom_category, custom_category_filter):
def _add_to_vector_store(self, messages, metadata, filters, custom_categories):
parsed_messages = parse_messages(messages)

if self.custom_prompt:
system_prompt = self.custom_prompt
user_prompt = f"Input: {parsed_messages}"
elif custom_category:
system_prompt, user_prompt = get_custom_category_fact_retrieval_messages(custom_category, custom_category_filter, parsed_messages)
elif custom_categories:
validated_custom_categories = CustomCategories(**custom_categories)
system_prompt, user_prompt = get_custom_category_fact_retrieval_messages(validated_custom_categories, parsed_messages)
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)

Expand Down
21 changes: 13 additions & 8 deletions mem0/memory/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from mem0.configs.prompts import FACT_RETRIEVAL_PROMPT
from mem0.configs.prompts import EXTEND_FACT_RETRIEVAL_PROMPT, OMIT_FACT_RETRIEVAL_PROMPT, RESTRICT_FACT_RETRIEVAL_PROMPT
from mem0.configs.prompts import (
EXTEND_FACT_RETRIEVAL_PROMPT,
FACT_RETRIEVAL_PROMPT,
OMIT_FACT_RETRIEVAL_PROMPT,
RESTRICT_FACT_RETRIEVAL_PROMPT,
)


def get_fact_retrieval_messages(message):
return FACT_RETRIEVAL_PROMPT, f"Input: {message}"

def get_custom_category_fact_retrieval_messages(custom_category, custom_category_filter, messages):
if custom_category_filter == "omit":
return prepare_input_message(custom_category, OMIT_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"
if custom_category_filter == "restrict":
return prepare_input_message(custom_category, RESTRICT_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"
def get_custom_category_fact_retrieval_messages(custom_categories, messages):
if custom_categories.filter == "omit":
return prepare_input_message(custom_categories.categories, OMIT_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"
if custom_categories.filter == "restrict":
return prepare_input_message(custom_categories.categories, RESTRICT_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"

return prepare_input_message(custom_category, EXTEND_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"
return prepare_input_message(custom_categories.categories, EXTEND_FACT_RETRIEVAL_PROMPT), f"Input: {messages}"


def parse_messages(messages):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ def test_add(memory_instance, version, enable_graph):
memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
memory_instance._add_to_graph = Mock(return_value=[])

result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user")
result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user", custom_categories={"categories": [{"key": "value"}]})

assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
assert "relations" in result
assert result["relations"] == []

memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, None, None
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, {"categories": [{"key": "value"}]}
)

# Remove the conditional assertion for _add_to_graph
Expand Down

0 comments on commit fd29e1c

Please sign in to comment.