Skip to content

Commit

Permalink
test: add tracking for tool execution in regression tests
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Nov 27, 2024
1 parent c33f45d commit d50cbf2
Showing 1 changed file with 50 additions and 8 deletions.
58 changes: 50 additions & 8 deletions tests/regression/planai/test_function_calling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Optional
from unittest.mock import create_autospec

import pytest
from pydantic import BaseModel, Field
Expand All @@ -24,6 +25,26 @@ class AIHistoryFact(BaseModel):
)


@pytest.fixture
def tracked_tools():
    """Provide a factory that spies on a tool's ``execute`` method.

    The factory replaces ``execute`` on the given tool object with an
    autospec'd mock that still delegates to the real implementation, so a
    test can assert on how the tool was invoked while the tool keeps
    producing its genuine return value.

    Every replaced ``execute`` is put back when the test finishes. The tools
    handed to the factory are typically shared module-level objects, so
    without the teardown a spy would leak into later tests and spies would
    nest each time the same tool is tracked again.

    Yields:
        A callable ``create_tracked_tool(original_tool)`` returning the tuple
        ``(original_tool, execution_tracker)``; ``original_tool`` is the same
        object that was passed in, now carrying the spy as ``execute``.
    """
    # (tool object, original execute) pairs to undo at teardown.
    patched = []

    def create_tracked_tool(original_tool):
        original_execute = original_tool.execute

        # The autospec'd mock enforces the real call signature; delegating
        # via side_effect keeps the real behaviour and return value, and
        # forwards positional as well as keyword arguments.
        execution_tracker = create_autospec(original_execute)
        execution_tracker.side_effect = original_execute

        # Replace only the execute attribute; everything else on the tool
        # is left untouched.
        original_tool.execute = execution_tracker
        patched.append((original_tool, original_execute))
        return original_tool, execution_tracker

    yield create_tracked_tool

    # Undo in reverse order so a tool tracked more than once within a single
    # test ends up with its true original execute.
    for tool_obj, original_execute in reversed(patched):
        tool_obj.execute = original_execute


@tool(name="get_flight_times")
def get_flight_times(departure: str, arrival: str) -> str:
"""Get the flight times between two cities.
Expand Down Expand Up @@ -99,10 +120,11 @@ def llm_client(request):
class TestFunctionCalling:
"""Regression tests for function calling using a real LLM client."""

def test_flight_time_query(self, llm_client):
def test_flight_time_query(self, llm_client, tracked_tools):
"""Test that the LLM correctly handles flight time queries using function calling."""
tracked_tool, execute_tracker = tracked_tools(get_flight_times)

llm_client.support_json_mode = False
tools = [get_flight_times]
question = "What is the flight time from JFK to LAX?"

messages = [
Expand All @@ -113,11 +135,17 @@ def test_flight_time_query(self, llm_client):
{"role": "user", "content": question},
]

response = llm_client.chat(messages=messages, tools=tools)
response = llm_client.chat(messages=messages, tools=[tracked_tool])

# Log the full response for debugging
print(f"\nFlight query response: {response}")

# Verify the tool was called with correct parameters
execute_tracker.assert_called_once()
call_kwargs = execute_tracker.call_args[1]
assert call_kwargs["departure"].upper() == "JFK"
assert call_kwargs["arrival"].upper() == "LAX"

# Basic assertions about the response
assert response is not None
assert isinstance(response, str)
Expand All @@ -129,10 +157,11 @@ def test_flight_time_query(self, llm_client):
term in response.lower() for term in ["jfk", "lax", "flight", "time"]
)

def test_ai_history_query(self, llm_client):
def test_ai_history_query(self, llm_client, tracked_tools):
"""Test that the LLM correctly handles AI history queries using vector search."""
tracked_tool, execute_tracker = tracked_tools(search_data_in_vector_db)

llm_client.support_json_mode = False
tools = [search_data_in_vector_db]
question = "When was Artificial Intelligence founded?"

messages = [
Expand All @@ -143,11 +172,16 @@ def test_ai_history_query(self, llm_client):
{"role": "user", "content": question},
]

response = llm_client.chat(messages=messages, tools=tools)
response = llm_client.chat(messages=messages, tools=[tracked_tool])

# Log the full response for debugging
print(f"\nAI history query response: {response}")

# Verify the tool was called with correct parameters
execute_tracker.assert_called_once()
call_kwargs = execute_tracker.call_args[1]
assert "query" in call_kwargs

# Basic assertions about the response
assert response is not None
assert isinstance(response, str)
Expand Down Expand Up @@ -210,8 +244,10 @@ def test_tool_selection(self, llm_client, query, expected_tool):
class TestStructuredOutput:
"""Regression tests for structured output generation using generate_pydantic."""

def test_flight_info_structured(self, llm_client):
def test_flight_info_structured(self, llm_client, tracked_tools):
"""Test generating structured flight information."""
tracked_tool, execute_tracker = tracked_tools(get_flight_times)

prompt = """Generate structured information about the flight from JFK to LAX.
Use the get_flight_times function to get accurate flight details."""

Expand All @@ -223,9 +259,15 @@ def test_flight_info_structured(self, llm_client):
prompt_template=prompt,
output_schema=FlightInfo,
system=system,
tools=[get_flight_times],
tools=[tracked_tool],
)

# Verify the tool was called with correct parameters
execute_tracker.assert_called_once()
call_kwargs = execute_tracker.call_args[1]
assert call_kwargs["departure"].upper() == "JFK"
assert call_kwargs["arrival"].upper() == "LAX"

# Verify the structured response
assert isinstance(response, FlightInfo)
assert response.departure_city == "JFK"
Expand Down

0 comments on commit d50cbf2

Please sign in to comment.