Add support for custom prompt override in memory.add function #1998

Open · wants to merge 3 commits into base: main
17 changes: 17 additions & 0 deletions docs/features/custom-prompts.mdx
@@ -107,3 +107,20 @@ m.add("I like going to hikes", user_id="alice")
}
```
</CodeGroup>


## Customizing Prompts per Memory Addition

In addition to setting a default prompt in the configuration, you can override the prompt for an individual memory addition by passing the `prompt` and `graph_prompt` parameters to `m.add()`. A prompt passed this way takes precedence over the configured `custom_prompt` for that call, letting you tailor specific entries without changing the overall configuration.
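
A per-call prompt is a plain instruction string for the extraction step. The `custom_prompt` value below is illustrative only, not something the library provides:

```python Code
# Illustrative example prompt; adjust the instructions to your domain
custom_prompt = """
Extract only shopping-related facts from the input,
such as order IDs, items, and dates.
"""
```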

For example, to add a memory with a custom prompt:

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", prompt=custom_prompt)
```

You can also use `graph_prompt` to customize the prompt specifically for graph memory entries. As before, the override is a plain instruction string; the wording below is illustrative only:
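
```python Code
# Illustrative example prompt for graph extraction
graph_prompt = """
Focus on entities and relationships relevant to e-commerce,
such as customers, orders, and products.
"""
```

Then pass it to `m.add()`: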

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", graph_prompt=graph_prompt)
```
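
When both parameters are supplied, `prompt` applies only to the vector-store extraction and `graph_prompt` only to the graph extraction; each falls back to its configured default when omitted.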
6 changes: 6 additions & 0 deletions docs/open-source/graph_memory/features.mdx
@@ -33,6 +33,12 @@ config = {
m = Memory.from_config(config_dict=config)
```

You can also **override prompts** for an individual memory addition by passing the `graph_prompt` parameter to `m.add()`; it takes precedence over the configured `custom_prompt` for that call. The wording below is illustrative only:
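
```python Code
# Illustrative example prompt for graph extraction
graph_prompt = """
Focus on entities and relationships relevant to e-commerce,
such as customers, orders, and products.
"""
```

Then pass it to `m.add()`: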

```python Code
m.add("Yesterday, I ordered a laptop, the order id is 12345", user_id="alice", graph_prompt=graph_prompt)
```

If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:

<Snippet file="get-help.mdx" />
5 changes: 3 additions & 2 deletions mem0/memory/graph_memory.py
@@ -48,7 +48,7 @@ def __init__(self, config):
self.user_id = None
self.threshold = 0.7

- def add(self, data, filters):
+ def add(self, data, filters, graph_prompt):
"""
Adds data to the graph.

@@ -60,7 +60,8 @@
# retrieve the search results
search_output = self._search(data, filters)

- if self.config.graph_store.custom_prompt:
+ custom_prompt = graph_prompt if graph_prompt else self.config.graph_store.custom_prompt
+ if custom_prompt:
messages = [
{
"role": "system",
17 changes: 10 additions & 7 deletions mem0/memory/main.py
@@ -67,6 +67,7 @@ def add(
metadata=None,
filters=None,
prompt=None,
+ graph_prompt=None,
):
"""
Create a new memory.
@@ -79,6 +80,7 @@
metadata (dict, optional): Metadata to store with the memory. Defaults to None.
filters (dict, optional): Filters to apply to the search. Defaults to None.
prompt (str, optional): Prompt to use for memory deduction. Defaults to None.
+ graph_prompt (str, optional): Prompt to use for graph memory deduction. Defaults to None.

Returns:
dict: A dictionary containing the result of the memory addition operation.
@@ -111,8 +113,8 @@ def add(
messages = [{"role": "user", "content": messages}]

with concurrent.futures.ThreadPoolExecutor() as executor:
- future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters)
- future2 = executor.submit(self._add_to_graph, messages, filters)
+ future1 = executor.submit(self._add_to_vector_store, messages, metadata, filters, prompt)
+ future2 = executor.submit(self._add_to_graph, messages, filters, graph_prompt)

concurrent.futures.wait([future1, future2])

@@ -134,11 +136,12 @@
)
return {"message": "ok"}

- def _add_to_vector_store(self, messages, metadata, filters):
+ def _add_to_vector_store(self, messages, metadata, filters, prompt):
parsed_messages = parse_messages(messages)

- if self.custom_prompt:
-     system_prompt = self.custom_prompt
+ custom_prompt = prompt if prompt else self.custom_prompt
+ if custom_prompt:
+     system_prompt = custom_prompt
user_prompt = f"Input: {parsed_messages}"
else:
system_prompt, user_prompt = get_fact_retrieval_messages(parsed_messages)
@@ -230,7 +233,7 @@ def _add_to_vector_store(self, messages, metadata, filters):

return returned_memories

- def _add_to_graph(self, messages, filters):
+ def _add_to_graph(self, messages, filters, graph_prompt):
added_entities = []
if self.api_version == "v1.1" and self.enable_graph:
if filters["user_id"]:
@@ -242,7 +245,7 @@ def _add_to_graph(self, messages, filters):
else:
self.graph.user_id = "USER"
data = "\n".join([msg["content"] for msg in messages if "content" in msg and msg["role"] != "system"])
- added_entities = self.graph.add(data, filters)
+ added_entities = self.graph.add(data, filters, graph_prompt)

return added_entities

23 changes: 18 additions & 5 deletions tests/test_main.py
@@ -32,27 +32,40 @@ def memory_instance():
return Memory(config)


@pytest.mark.parametrize("version, enable_graph", [("v1.0", False), ("v1.1", True)])
def test_add(memory_instance, version, enable_graph):
@pytest.mark.parametrize(
"version, enable_graph, custom_prompt",
[
("v1.0", False, None),
("v1.1", True, None),
("v1.0", False, "CustomPrompt"),
("v1.1", True, "CustomPrompt"),
]
)
def test_add(memory_instance, version, enable_graph, custom_prompt):
memory_instance.config.version = version
memory_instance.enable_graph = enable_graph
memory_instance._add_to_vector_store = Mock(return_value=[{"memory": "Test memory", "event": "ADD"}])
memory_instance._add_to_graph = Mock(return_value=[])

- result = memory_instance.add(messages=[{"role": "user", "content": "Test message"}], user_id="test_user")
+ result = memory_instance.add(
+     messages=[{"role": "user", "content": "Test message"}],
+     user_id="test_user",
+     prompt=custom_prompt,
+     graph_prompt=custom_prompt
+ )

assert "results" in result
assert result["results"] == [{"memory": "Test memory", "event": "ADD"}]
assert "relations" in result
assert result["relations"] == []

memory_instance._add_to_vector_store.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, {"user_id": "test_user"}, custom_prompt
)

# Remove the conditional assertion for _add_to_graph
memory_instance._add_to_graph.assert_called_once_with(
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}
[{"role": "user", "content": "Test message"}], {"user_id": "test_user"}, custom_prompt
)

