Commit a53e0d7

refactor: Update Graph class to write prompt optimization results to text and JSON files
provos committed Sep 12, 2024
1 parent 3daeefe commit a53e0d7
Showing 1 changed file with 50 additions and 14 deletions.
64 changes: 50 additions & 14 deletions src/planai/cli_optimize_prompt.py
@@ -5,6 +5,7 @@
 import re
 import sys
 from operator import attrgetter
+from pathlib import Path
 from textwrap import dedent
 from typing import Any, Dict, List, Optional, Type

@@ -427,9 +428,6 @@ def optimize_prompt(
     # for prompt optimization
     llm_class.debug_mode = False

-    # Make sure to pick the prompt from the upstream workers and reflect it in the cache key
-    inject_prompt_awareness(llm_class)
-
     # Attempt to retrieve the 'prompt' attribute
     if not hasattr(llm_class, "prompt"):
         print(f"'prompt' attribute not found in class '{class_name}'.")
@@ -516,7 +514,17 @@ def consume_work(self, task: llm_output_type):
     # create two new prompts
     input_tasks = []
     for i, example in enumerate(data[:2]):
-        task = create_input_task(module, task_name, example)
+        # test whether we can create the input task before we spin up the whole graph
+        if i == 0:
+            try:
+                task = create_input_task(module, task_name, example)
+                _ = llm_class.get_full_prompt(task)
+            except Exception as e:
+                print(
+                    f"Error creating input task from {example} - did you provide the right debug log: {e}"
+                )
+                exit(1)
+
         prompt_template = llm_class.prompt
         input_tasks.append(
             (
@@ -529,19 +537,47 @@
             )
         )

+    # Make sure to pick the prompt from the upstream workers and reflect it in the cache key
+    inject_prompt_awareness(llm_class)
+
     graph.run(initial_tasks=input_tasks, run_dashboard=False)

-    output = graph.get_tasks()
+    output = graph.get_output_tasks()
+    write_results(llm_class.name, output)

-    # Create a list to hold all the task data.
-    all_tasks_data = []
-    for task in output:
-        data = task.model_dump()
-        all_tasks_data.append(data)
-
-    # Write the list of task data to a JSON file.
-    with open("optimized_prompts.json", "w") as json_file:
-        json.dump(all_tasks_data, json_file, indent=2)
+
+def write_results(class_name: str, output: List[PromptCritique]):
+    """
+    Writes the results from prompt optimization to a text file and a JSON file.
+
+    Parameters:
+        class_name (str): The name of the worker class for which the prompt was optimized.
+        output (List[PromptCritique]): A list of PromptCritique objects containing the prompt data and scores.
+    """
+
+    def get_available_filename(base_name, ext):
+        """
+        Get the next available file name by checking existing files,
+        incrementing the version number if necessary.
+        """
+        version = 1
+        file_path = Path(f"{base_name}.v{version}.{ext}")
+        while file_path.exists():
+            version += 1
+            file_path = Path(f"{base_name}.v{version}.{ext}")
+        return file_path
+
+    for index, task in enumerate(output, start=1):
+        # Create the base file name prefixed with the class name and prompt number.
+        base_filename = f"{class_name}_prompt_{index}"
+
+        # Create the text file containing the prompt and score.
+        text_filename = get_available_filename(base_filename, "text")
+        text_filename.write_text(f"Score: {task.score}\n{task.prompt_template}")
+
+        # Create the JSON file dumping the whole content.
+        json_filename = get_available_filename(base_filename, "json")
+        json_filename.write_text(json.dumps(task.model_dump(), indent=2))


 def sanitize_prompt(original_template: str, prompt_template: str) -> str:
@@ -567,7 +603,7 @@ def inject_prompt_awareness(llm_class: LLMTaskWorker):
     def new_format_prompt(task: Task) -> str:
         input_prompt: Optional[ImprovedPrompt] = task.find_input_task(ImprovedPrompt)
         if input_prompt is None:
-            raise ValueError("No input task found")
+            raise ValueError("No input task found for ImprovedPrompt")
         with llm_class.lock:
             llm_class.prompt = input_prompt.prompt_template
         return original_format_prompt(task)
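The last hunk shows the heart of inject_prompt_awareness: it swaps the worker's format_prompt for a wrapper that first installs the prompt carried by the incoming task, holding a lock so concurrent tasks don't race on the shared attribute. A minimal standalone sketch of that wrap-and-swap pattern, using illustrative names rather than planai's actual API:

    import threading

    class Worker:
        def __init__(self):
            self.prompt = "original prompt"
            self.lock = threading.Lock()

        def format_prompt(self, task: dict) -> str:
            return f"{self.prompt}\n\nInput: {task['text']}"

    def inject_prompt_awareness(worker: Worker) -> None:
        # Keep a reference to the original bound method, then shadow it with a
        # wrapper that installs the prompt carried by the task itself.
        original_format_prompt = worker.format_prompt

        def new_format_prompt(task: dict) -> str:
            with worker.lock:
                worker.prompt = task["prompt_template"]
                return original_format_prompt(task)

        worker.format_prompt = new_format_prompt

    worker = Worker()
    inject_prompt_awareness(worker)
    print(worker.format_prompt({"prompt_template": "improved prompt", "text": "hi"}))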
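The new write_results never overwrites earlier runs: get_available_filename appends an incrementing .vN suffix until it finds an unused name. A small self-contained sketch of that scheme (the worker name AnalysisWorker is made up for illustration):

    from pathlib import Path

    def get_available_filename(base_name: str, ext: str) -> Path:
        # Try base_name.v1.ext first, bumping the version until the name is free.
        version = 1
        file_path = Path(f"{base_name}.v{version}.{ext}")
        while file_path.exists():
            version += 1
            file_path = Path(f"{base_name}.v{version}.{ext}")
        return file_path

    # The first call writes AnalysisWorker_prompt_1.v1.text; run again while
    # that file exists and the result lands in AnalysisWorker_prompt_1.v2.text.
    path = get_available_filename("AnalysisWorker_prompt_1", "text")
    path.write_text("Score: 0.9\n<prompt template goes here>")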
