Skip to content

Commit

Permalink
feat: visualization generation instruction and fewshots
Browse files Browse the repository at this point in the history
  • Loading branch information
cxxxxxn committed Apr 19, 2024
1 parent 0e94669 commit 32b5501
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 1 deletion.
12 changes: 11 additions & 1 deletion coml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
FIX_INSTRUCTION,
GENERATE_INSTRUCTION,
GENERATE_INSTRUCTION_COT,
GENERATE_INSTRUCTION_VIS_MATPLOTLIB,
GENERATE_INSTRUCTION_VIS_SEABORN,
SANITY_CHECK_INSTRUCTION,
SUGGEST_INSTRUCTION,
FixContext,
Expand Down Expand Up @@ -125,7 +127,9 @@ class CoMLAgent:
def __init__(
self,
llm: BaseChatModel,
prompt_version: Literal["v1", "v2", "kaggle", "leetcode"] = "v2",
prompt_version: Literal[
"v1", "v2", "kaggle", "leetcode", "matplotlib", "seaborn"
] = "v2",
prompt_validation: Callable[[list[BaseMessage]], bool] | None = None,
num_examples: float | int = 1.0,
message_style: Literal["chatgpt", "gemini"] = "chatgpt",
Expand Down Expand Up @@ -298,6 +302,12 @@ def generate_code(
shot["answer"] = shot.pop("answer_wo_intact")
if "rationale_wo_intact" in shot:
shot["rationale"] = shot.pop("rationale_wo_intact")

if self.prompt_version == "matplotlib":
generate_instruction = GENERATE_INSTRUCTION_VIS_MATPLOTLIB
elif self.prompt_version == "seaborn":
generate_instruction = GENERATE_INSTRUCTION_VIS_SEABORN

messages.append(SystemMessage(content=generate_instruction))

for shot in self._select_examples(request, fewshots):
Expand Down
24 changes: 24 additions & 0 deletions coml/prompt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,18 @@ def render_sanity_check_context(
- Think before you write. You should first understand the user's request, think about how to achieve it, and then write the code.
"""

# System prompt used when the agent is asked to generate matplotlib
# visualization code. This is deliberately a plain (non-f) string: nothing is
# interpolated, and a plain literal keeps any braces that may later be added
# to the prompt text verbatim instead of treating them as replacement fields.
GENERATE_INSTRUCTION_VIS_MATPLOTLIB = """You're a helpful assistant proficient in writing Python code for data visualization. Upon receiving relevant context, such as available variables and any pre-executed code, your goal is to complete the Python code to generate a visualization that meets the user's request.
Instructions:
- Utilize the `matplotlib` library to create the visualization and ensure include `plt.show()` to display the chart.
- You must return the generated code wrapped by ``` before and after it, and do not add any explanation.
"""

# The seaborn prompt is the matplotlib prompt with the library name swapped.
# NOTE: `str.replace` rewrites every occurrence of "matplotlib", so the base
# prompt must not mention that word in any other role (e.g. in an import path).
GENERATE_INSTRUCTION_VIS_SEABORN = GENERATE_INSTRUCTION_VIS_MATPLOTLIB.replace(
    "matplotlib", "seaborn"
)

FIX_INSTRUCTION = f"""{GENERATE_INSTRUCTION.rstrip()}
- If the user thinks the generated code is problematic, you should help fix it. The user will provide you with the exception message (if any), the output of the code (if any), and a hint (if any). You should provide a line-by-line explanation of the code, and point out what is wrong with the code. You should also provide the fixed code.
- If you think the provided problematic code is actually correct, you should first explain the code, and write "THE CODE IS CORRECT." (in upper case) in the observation section. The fixed code can be omitted.
Expand Down Expand Up @@ -459,6 +471,18 @@ def cached_generate_fewshots(prompt_version: str) -> list[GenerateContext]:
with open(
Path(__file__).parent / f"prompts/generate_fewshots_{prompt_version}.json"
) as f:
if prompt_version in ["matplotlib", "seaborn"]:
fewshots = json.load(f)
for shot in fewshots:
variables = {}
for name in shot["datasets"]:
dataset = pd.read_csv(
str(Path(__file__).parent / f"dataset/{name}.csv")
)
# todo: dataframe_format
variables[f"{name.split("/")[1]}_dataset"] = describe_variable(dataset)
shot["variables"] = variables
return fewshots
return json.load(f)


Expand Down
11 changes: 11 additions & 0 deletions coml/prompts/dataset/coffee_shop/member.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Member_ID,Name,Membership_card,Age,Time_of_purchase,Level_of_membership,Address
1,"Ashby, Lazale",Black,29,18,5,Hartford
2,"Breton, Robert",White,67,41,4,Waterbury
3,"Campbell, Jessie",Black,34,20,6,Hartford
4,"Cobb, Sedrick",Black,51,27,2,Waterbury
5,"Hayes, Steven",White,50,44,3,Cheshire
6,"Komisarjevsky, Joshua",White,33,26,2,Cheshire
7,"Peeler, Russell",Black,42,26,6,Bridgeport
8,"Reynolds, Richard",Black,45,24,1,Waterbury
9,"Rizzo, Todd",White,35,18,4,Waterbury
10,"Webb, Daniel",Black,51,27,22,Hartford
21 changes: 21 additions & 0 deletions coml/prompts/dataset/game_injury/game.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
stadium_id,id,Season,Date,Home_team,Away_team,Score,Competition
1,1,2007,18 May 2007,Quruvchi,Pakhtakor,1–1,League
2,2,2007,22 September 2007,Pakhtakor,Quruvchi,0–0,League
3,3,2007,9 December 2007,Pakhtakor,Quruvchi,0–0 (7:6),Cup
4,4,2008,10 July 2008,Pakhtakor,Quruvchi,1–1,League
5,5,2008,16 August 2008,Bunyodkor,Pakhtakor,1–1,League
6,6,2008,31 October 2008,Bunyodkor,Pakhtakor,3–1,Cup
7,7,2009,12 July 2009,Bunyodkor,Pakhtakor,2–1,League
8,8,2009,14 October 2009,Pakhtakor,Bunyodkor,0–0,League
9,9,2009,8 August 2009,Pakhtakor,Bunyodkor,1–0,Cup
10,10,2010,14 March 2010,Bunyodkor,Pakhtakor,2–1,League
10,11,2010,31 October 2010,Pakhtakor,Bunyodkor,0–0,League
10,12,2011,7 July 2011,Pakhtakor,Bunyodkor,0–0,League
1,13,2011,21 August 2011,Bunyodkor,Pakhtakor,2–1,League
2,14,2012,11 March 2012,Bunyodkor,Pakhtakor,–,Supercup
3,15,2012,26 June 2012,Bunyodkor,Pakhtakor,2–0,League
4,16,2012,9 August 2012,Pakhtakor,Bunyodkor,1–1,League
5,17,2012,22 August 2012,Bunyodkor,Pakhtakor,1–1,Cup
11,18,2012,25 November 2012,Pakhtakor,Bunyodkor,1–3,Cup
12,19,2013,30 June 2013,Pakhtakor,Bunyodkor,0–2,League
7,20,2013,8 August 2013,Bunyodkor,Pakhtakor,1–2,League
6 changes: 6 additions & 0 deletions coml/prompts/dataset/pilot_record/pilot.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Pilot_ID,Pilot_name,Rank,Age,Nationality,Position,Join_Year,Team
1,Patrick O'Bryant,13,33,United States,Center Team,2009,Bradley
2,Jermaine O'Neal,6,40,United States,Forward-Center Team,2008,Eau Claire High School
3,Dan O'Sullivan,45,37,United States,Center Team,1999,Fordham
4,Charles Oakley,34,22,United Kindom,Forward Team,2001,Virginia Union
5,Hakeem Olajuwon,34,32,Nigeria,Center Team,2010,Houston
13 changes: 13 additions & 0 deletions coml/prompts/dataset/scientist_1/AssignedTo.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Scientist,Project
123234877,AeH1
152934485,AeH3
222364883,Ast3
326587417,Ast3
332154719,Bte1
546523478,Che1
631231482,Ast3
654873219,Che1
745685214,AeH3
845657245,Ast1
845657246,Ast2
332569843,AeH4
15 changes: 15 additions & 0 deletions coml/prompts/dataset/scientist_1/Projects.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Code,Name,Hours
AeH1,Winds: Studying Bernoullis Principle,156
AeH2,Aerodynamics and Bridge Design,189
AeH3,Aerodynamics and Gas Mileage,256
AeH4,Aerodynamics and Ice Hockey,789
AeH5,Aerodynamics of a Football,98
AeH6,Aerodynamics of Air Hockey,89
Ast1,A Matter of Time,112
Ast2,A Puzzling Parallax,299
Ast3,Build Your Own Telescope,6546
Bte1,Juicy: Extracting Apple Juice with Pectinase,321
Bte2,A Magnetic Primer Designer,9684
Bte3,Bacterial Transformation Efficiency,321
Che1,A Silver-Cleaning Battery,545
Che2,A Soluble Separation Solution,778
47 changes: 47 additions & 0 deletions coml/prompts/generate_fewshots_matplotlib.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[
{
"datasets": [
"scientist_1/Projects",
"scientist_1/AssignedTo"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
],
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nplt.bar(project_counts.index, project_counts.values)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
},
{
"datasets": [
"coffee_shop/member"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"member_dataset = pd.read_csv('../dataset/coffee_shop/member.csv')"
],
"request": "A scatter chart showing the correlation between the age of the customer and the time of purchase colored by membership card.",
"answer": "# Group the dataset by Membership_card\ngroups = member_dataset.groupby('Membership_card')\n\n# Create a scatter chart for each Membership_card\nfor membership_card, group in groups:\n plt.scatter(group['Age'], group['Time_of_purchase'], label=membership_card)\n\n# Set the title and labels\nplt.title('Correlation between Age and Time of Purchase')\nplt.xlabel('Age')\nplt.ylabel('Time of Purchase')\nplt.legend(loc='upper left')\n\n# Show the plot\nplt.show()"
},
{
"datasets": [
"game_injury/game"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"game_dataset = pd.read_csv('../dataset/game_injury/game.csv')"
],
"request": "Show the number of games in each season and group by away team in a group line chart. The x-axis is season.",
"answer": "# group the dataset by season and away team\ngrouped = game_dataset.groupby(['Season', 'Away_team']).size().reset_index(name='counts')\n\n# create a pivot table with season as index and away team as columns\npivot_table = pd.pivot_table(grouped, values='counts', index=['Season'], columns=['Away_team'], fill_value=0)\n\n# create the line chart\npivot_table.plot(kind='line')\n\n# set the title and labels\nplt.title('Number of Games in Each Season by Away Team')\nplt.xlabel('Season')\nplt.ylabel('Number of Games')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\n\n# show the plot\nplt.show()"
},
{
"datasets": [
"pilot_record/pilot"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt",
"pilot_dataset = pd.read_csv('../dataset/pilot_record/pilot.csv')"
],
"request": "What is the proportion of positions of pilots? Show the result in a pie chart.",
"answer": "# Count the number of each position\nposition_counts = pilot_dataset['Position'].value_counts()\n\n# Create a pie chart\nplt.figure(figsize=(8,6))\nplt.pie(position_counts, labels = position_counts.index, autopct='%1.1f%%')\n\n# Set the title\nplt.title('Proportion of Positions')\n\n# Show the plot\nplt.show()"
}
]
14 changes: 14 additions & 0 deletions coml/prompts/generate_fewshots_seaborn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[
{
"datasets": [
"scientist_1/Projects",
"scientist_1/AssignedTo"
],
"codes": [
"import pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns",
"Projects_dataset = pd.read_csv('../dataset/scientist_1/Projects.csv')\nAssignedTo_dataset = pd.read_csv('../dataset/scientist_1/AssignedTo.csv')"
],
"request": "What are the names of projects that require more than 300 hours, and how many scientists are assigned to each? Plot the result in a bar chart.",
"answer": "# Merge the two datasets on the project code\nmerged_data = pd.merge(Projects_dataset, AssignedTo_dataset, left_on='Code', right_on='Project')\n\n# Filter projects that require more than 300 hours\nfiltered_data = merged_data[merged_data['Hours'] > 300]\n\n# Count the number of scientists assigned to each project\nproject_counts = filtered_data['Name'].value_counts()\n\n# Plot a bar chart\nplt.figure(figsize=(10, 8))\nsns.barplot(project_counts)\nplt.xticks(rotation=5)\nplt.xlabel('Project Name')\nplt.ylabel('Number of Scientists')\nfrom matplotlib.ticker import MaxNLocator\nplt.gca().yaxis.set_major_locator(MaxNLocator(integer=True))\nplt.title('Number of Scientists Assigned to Each Project')\nplt.show()"
}
]

0 comments on commit 32b5501

Please sign in to comment.