Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anthropic support #112

Merged
merged 24 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .github/workflows/e2e_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
name: e2e tests
on:
pull_request:
branches:
- main
jobs:
e2e-tests:
runs-on: ubuntu-latest
Expand Down Expand Up @@ -71,6 +69,7 @@ jobs:
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}
PINATA_GATEWAY_TOKEN: ${{ secrets.PINATA_GATEWAY_TOKEN }}
PINATA_API_JWT: ${{ secrets.PINATA_API_JWT }}
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
- name: "OpenAI gpt-4-turbo-preview"
run:
cd contracts && npx hardhat openai --contract-address ${{ env.TEST_CONTRACT_ADDRESS }} --model gpt-4-turbo-preview --message "Who is the president of USA?" --network ${{ env.NETWORK }}
Expand Down Expand Up @@ -111,6 +110,11 @@ jobs:
cd contracts && npx hardhat groq --contract-address ${{ env.TEST_CONTRACT_ADDRESS }} --model gemma-7b-it --message "Who is the president of USA?" --network ${{ env.NETWORK }}
env:
PRIVATE_KEY_LOCALHOST: ${{ secrets.PRIVATE_KEY }}
- name: "Anthropic claude-3-5-sonnet-20240620"
run:
cd contracts && npx hardhat llm --contract-address ${{ env.TEST_CONTRACT_ADDRESS }} --model claude-3-5-sonnet-20240620 --message "Who is the president of USA?" --network ${{ env.NETWORK }}
env:
PRIVATE_KEY_LOCALHOST: ${{ secrets.PRIVATE_KEY }}
- name: "OpenAI Image Generation"
run:
cd contracts && npx hardhat image_generation --contract-address ${{ env.TEST_CONTRACT_ADDRESS }} --query "Red rose" --network ${{ env.NETWORK }}
Expand Down
86 changes: 77 additions & 9 deletions contracts/contracts/ChatGpt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ contract ChatGpt {

// @notice Address of the oracle contract
address public oracleAddress;

// @notice Configuration for the LLM request
IOracle.LlmRequest private config;

// @notice CID of the knowledge base
string public knowledgeBase;

// @notice Mapping from chat ID to the tool currently running
mapping(uint => string) public toolRunning;

// @notice Event emitted when the oracle address is updated
event OracleAddressUpdated(address indexed newOracleAddress);

Expand All @@ -45,6 +51,22 @@ contract ChatGpt {
owner = msg.sender;
oracleAddress = initialOracleAddress;
knowledgeBase = knowledgeBaseCID;

config = IOracle.LlmRequest({
model : "claude-3-5-sonnet-20240620",
frequencyPenalty : 21, // > 20 for null
logitBias : "", // empty str for null
maxTokens : 1000, // 0 for null
presencePenalty : 21, // > 20 for null
responseFormat : "{\"type\":\"text\"}",
seed : 0, // null
stop : "", // null
temperature : 10, // Example temperature (scaled up, 10 means 1.0), > 20 means null
topP : 101, // Percentage 0-100, > 100 means null
tools : "[{\"type\":\"function\",\"function\":{\"name\":\"web_search\",\"description\":\"Search the internet\",\"parameters\":{\"type\":\"object\",\"properties\":{\"query\":{\"type\":\"string\",\"description\":\"Search query\"}},\"required\":[\"query\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"code_interpreter\",\"description\":\"Evaluates python code in a sandbox environment. The environment resets on every execution. You must send the whole script every time and print your outputs. Script should be pure python code that can be evaluated. It should be in python format NOT markdown. The code should NOT be wrapped in backticks. All python packages including requests, matplotlib, scipy, numpy, pandas, etc are available. Output can only be read from stdout, and stdin. Do not use things like plot.show() as it will not work. print() any output and results so you can capture the output.\",\"parameters\":{\"type\":\"object\",\"properties\":{\"code\":{\"type\":\"string\",\"description\":\"The pure python script to be evaluated. The contents will be in main.py. It should not be in markdown format.\"}},\"required\":[\"code\"]}}}]",
toolChoice : "auto", // "none" or "auto"
user : "" // null
});
}

// @notice Ensures the caller is the contract owner
Expand Down Expand Up @@ -92,7 +114,7 @@ contract ChatGpt {
);
} else {
// Otherwise, create an LLM call
IOracle(oracleAddress).createLlmCall(currentId);
IOracle(oracleAddress).createLlmCall(currentId, config);
}
emit ChatCreated(msg.sender, currentId);

Expand All @@ -105,8 +127,8 @@ contract ChatGpt {
// @dev Called by teeML oracle
function onOracleLlmResponse(
uint runId,
string memory response,
string memory /*errorMessage*/
IOracle.LlmResponse memory response,
string memory errorMessage
) public onlyOracle {
ChatRun storage run = chatRuns[runId];
require(
Expand All @@ -115,10 +137,48 @@ contract ChatGpt {
);

Message memory newMessage;
newMessage.content = response;
newMessage.role = "assistant";
run.messages.push(newMessage);
run.messagesCount++;
if (!compareStrings(errorMessage, "")) {
newMessage.role = "assistant";
newMessage.content = errorMessage;
run.messages.push(newMessage);
run.messagesCount++;
} else {
if (!compareStrings(response.functionName, "")) {
toolRunning[runId] = response.functionName;
IOracle(oracleAddress).createFunctionCall(runId, response.functionName, response.functionArguments);
} else {
toolRunning[runId] = "";
}
newMessage.role = "assistant";
newMessage.content = response.content;
run.messages.push(newMessage);
run.messagesCount++;
}
}

// @notice Handles the response from the oracle for a function call
// @param runId The ID of the chat run
// @param response The response from the oracle
// @param errorMessage Any error message
// @dev Called by teeML oracle
function onOracleFunctionResponse(
uint runId,
string memory response,
string memory errorMessage
) public onlyOracle {
require(
!compareStrings(toolRunning[runId], ""),
"No function to respond to"
);
ChatRun storage run = chatRuns[runId];
if (compareStrings(errorMessage, "")) {
Message memory newMessage;
newMessage.role = "user";
newMessage.content = response;
run.messages.push(newMessage);
run.messagesCount++;
IOracle(oracleAddress).createLlmCall(runId, config);
}
}

// @notice Handles the response from the oracle for a knowledge base query
Expand Down Expand Up @@ -155,7 +215,7 @@ contract ChatGpt {
lastMessage.content = newContent;

// Call LLM
IOracle(oracleAddress).createLlmCall(runId);
IOracle(oracleAddress).createLlmCall(runId, config);
}

// @notice Adds a new message to an existing chat run
Expand Down Expand Up @@ -186,7 +246,7 @@ contract ChatGpt {
);
} else {
// Otherwise, create an LLM call
IOracle(oracleAddress).createLlmCall(runId);
IOracle(oracleAddress).createLlmCall(runId, config);
}
}

Expand All @@ -213,4 +273,12 @@ contract ChatGpt {
}
return roles;
}

// @notice Compares two strings for equality
// @param a The first string
// @param b The second string
// @return True if the strings are equal, false otherwise
function compareStrings(string memory a, string memory b) private pure returns (bool) {
return (keccak256(abi.encodePacked((a))) == keccak256(abi.encodePacked((b))));
}
}
8 changes: 6 additions & 2 deletions contracts/contracts/ChatOracle.sol
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ contract ChatOracle is IOracle {
// @dev Default is OpenAI
mapping(uint => string) public promptType;

// @notice Mapping of prompt ID to the LLM configuration
mapping(uint => IOracle.LlmRequest) public llmConfigurations;

// @notice Mapping of prompt ID to the OpenAI configuration
mapping(uint => IOracle.OpenAiRequest) public openAiConfigurations;

Expand Down Expand Up @@ -178,7 +181,7 @@ contract ChatOracle is IOracle {
// @notice Creates a new LLM call
// @param promptCallbackId The callback ID for the LLM call
// @return The ID of the created prompt
function createLlmCall(uint promptCallbackId) public returns (uint) {
function createLlmCall(uint promptCallbackId, IOracle.LlmRequest memory request) public returns (uint) {
uint promptId = promptsCount;
callbackAddresses[promptId] = msg.sender;
promptCallbackIds[promptId] = promptCallbackId;
Expand All @@ -187,6 +190,7 @@ contract ChatOracle is IOracle {

promptsCount++;

llmConfigurations[promptId] = request;
emit PromptAdded(promptId, promptCallbackId, msg.sender);

return promptId;
Expand All @@ -201,7 +205,7 @@ contract ChatOracle is IOracle {
function addResponse(
uint promptId,
uint promptCallBackId,
string memory response,
IOracle.LlmResponse memory response,
string memory errorMessage
) public onlyWhitelisted {
require(!isPromptProcessed[promptId], "Prompt already processed");
Expand Down
36 changes: 36 additions & 0 deletions contracts/contracts/GroqChatGpt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ contract GroqChatGpt {
// @notice Address of the oracle contract
address public oracleAddress;

// @notice Mapping from chat ID to the tool currently running
mapping(uint => string) public toolRunning;

// @notice Event emitted when the oracle address is updated
event OracleAddressUpdated(address indexed newOracleAddress);

Expand All @@ -56,6 +59,8 @@ contract GroqChatGpt {
stop : "", // null
temperature : 10, // Example temperature (scaled up, 10 means 1.0), > 20 means null
topP : 101, // Percentage 0-100, > 100 means null
tools : "[{\"type\":\"function\",\"function\":{\"name\":\"web_search\",\"description\":\"Search the internet\",\"parameters\":{\"type\":\"object\",\"properties\":{\"query\":{\"type\":\"string\",\"description\":\"Search query\"}},\"required\":[\"query\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"code_interpreter\",\"description\":\"Evaluates python code in a sandbox environment. The environment resets on every execution. You must send the whole script every time and print your outputs. Script should be pure python code that can be evaluated. It should be in python format NOT markdown. The code should NOT be wrapped in backticks. All python packages including requests, matplotlib, scipy, numpy, pandas, etc are available. Output can only be read from stdout, and stdin. Do not use things like plot.show() as it will not work. print() any output and results so you can capture the output.\",\"parameters\":{\"type\":\"object\",\"properties\":{\"code\":{\"type\":\"string\",\"description\":\"The pure python script to be evaluated. The contents will be in main.py. It should not be in markdown format.\"}},\"required\":[\"code\"]}}}]",
toolChoice : "auto", // "none" or "auto"
user : "" // null
});
}
Expand Down Expand Up @@ -122,6 +127,12 @@ contract GroqChatGpt {
run.messages.push(newMessage);
run.messagesCount++;
} else {
if (!compareStrings(response.functionName, "")) {
toolRunning[runId] = response.functionName;
IOracle(oracleAddress).createFunctionCall(runId, response.functionName, response.functionArguments);
} else {
toolRunning[runId] = "";
}
Message memory newMessage;
newMessage.role = "assistant";
newMessage.content = response.content;
Expand All @@ -130,6 +141,31 @@ contract GroqChatGpt {
}
}

// @notice Handles the response from the oracle for a function call
// @param runId The ID of the chat run
// @param response The response from the oracle
// @param errorMessage Any error message
// @dev Called by teeML oracle
function onOracleFunctionResponse(
uint runId,
string memory response,
string memory errorMessage
) public onlyOracle {
require(
!compareStrings(toolRunning[runId], ""),
"No function to respond to"
);
ChatRun storage run = chatRuns[runId];
if (compareStrings(errorMessage, "")) {
Message memory newMessage;
newMessage.role = "user";
newMessage.content = response;
run.messages.push(newMessage);
run.messagesCount++;
IOracle(oracleAddress).createGroqLlmCall(runId, config);
}
}

// @notice Adds a new message to an existing chat run
// @param message The new message to add
// @param runId The ID of the chat run
Expand Down
47 changes: 26 additions & 21 deletions contracts/contracts/OpenAiChatGpt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ contract OpenAiChatGpt {
// @notice Address of the oracle contract
address public oracleAddress;

// @notice Mapping from chat ID to the tool currently running
mapping(uint => string) public toolRunning;

// @notice Event emitted when the oracle address is updated
event OracleAddressUpdated(address indexed newOracleAddress);

Expand All @@ -46,19 +49,19 @@ contract OpenAiChatGpt {
chatRunsCount = 0;

config = IOracle.OpenAiRequest({
model : "gpt-4-turbo-preview",
frequencyPenalty : 21, // > 20 for null
logitBias : "", // empty str for null
maxTokens : 1000, // 0 for null
presencePenalty : 21, // > 20 for null
responseFormat : "{\"type\":\"text\"}",
seed : 0, // null
stop : "", // null
temperature : 10, // Example temperature (scaled up, 10 means 1.0), > 20 means null
topP : 101, // Percentage 0-100, > 100 means null
tools : "[{\"type\":\"function\",\"function\":{\"name\":\"web_search\",\"description\":\"Search the internet\",\"parameters\":{\"type\":\"object\",\"properties\":{\"query\":{\"type\":\"string\",\"description\":\"Search query\"}},\"required\":[\"query\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"code_interpreter\",\"description\":\"Evaluates python code in a sandbox environment. The environment resets on every execution. You must send the whole script every time and print your outputs. Script should be pure python code that can be evaluated. It should be in python format NOT markdown. The code should NOT be wrapped in backticks. All python packages including requests, matplotlib, scipy, numpy, pandas, etc are available. Output can only be read from stdout, and stdin. Do not use things like plot.show() as it will not work. print() any output and results so you can capture the output.\",\"parameters\":{\"type\":\"object\",\"properties\":{\"code\":{\"type\":\"string\",\"description\":\"The pure python script to be evaluated. The contents will be in main.py. It should not be in markdown format.\"}},\"required\":[\"code\"]}}}]",
toolChoice : "auto", // "none" or "auto"
user : "" // null
model : "gpt-4-turbo-preview",
frequencyPenalty : 21, // > 20 for null
logitBias : "", // empty str for null
maxTokens : 1000, // 0 for null
presencePenalty : 21, // > 20 for null
responseFormat : "{\"type\":\"text\"}",
seed : 0, // null
stop : "", // null
temperature : 10, // Example temperature (scaled up, 10 means 1.0), > 20 means null
topP : 101, // Percentage 0-100, > 100 means null
tools : "[{\"type\":\"function\",\"function\":{\"name\":\"web_search\",\"description\":\"Search the internet\",\"parameters\":{\"type\":\"object\",\"properties\":{\"query\":{\"type\":\"string\",\"description\":\"Search query\"}},\"required\":[\"query\"]}}},{\"type\":\"function\",\"function\":{\"name\":\"code_interpreter\",\"description\":\"Evaluates python code in a sandbox environment. The environment resets on every execution. You must send the whole script every time and print your outputs. Script should be pure python code that can be evaluated. It should be in python format NOT markdown. The code should NOT be wrapped in backticks. All python packages including requests, matplotlib, scipy, numpy, pandas, etc are available. Output can only be read from stdout, and stdin. Do not use things like plot.show() as it will not work. print() any output and results so you can capture the output.\",\"parameters\":{\"type\":\"object\",\"properties\":{\"code\":{\"type\":\"string\",\"description\":\"The pure python script to be evaluated. The contents will be in main.py. It should not be in markdown format.\"}},\"required\":[\"code\"]}}}]",
toolChoice : "auto", // "none" or "auto"
user : "" // null
});
}

Expand Down Expand Up @@ -126,15 +129,17 @@ contract OpenAiChatGpt {
run.messages.push(newMessage);
run.messagesCount++;
} else {
if (compareStrings(response.content, "")) {
if (!compareStrings(response.functionName, "")) {
toolRunning[runId] = response.functionName;
IOracle(oracleAddress).createFunctionCall(runId, response.functionName, response.functionArguments);
} else {
Message memory newMessage;
newMessage.role = "assistant";
newMessage.content = response.content;
run.messages.push(newMessage);
run.messagesCount++;
toolRunning[runId] = "";
}
Message memory newMessage;
kgrofelnik marked this conversation as resolved.
Show resolved Hide resolved
newMessage.role = "assistant";
newMessage.content = response.content;
run.messages.push(newMessage);
run.messagesCount++;
}
}

Expand All @@ -148,11 +153,11 @@ contract OpenAiChatGpt {
string memory response,
string memory errorMessage
) public onlyOracle {
ChatRun storage run = chatRuns[runId];
require(
compareStrings(run.messages[run.messagesCount - 1].role, "user"),
!compareStrings(toolRunning[runId], ""),
"No function to respond to"
);
ChatRun storage run = chatRuns[runId];
if (compareStrings(errorMessage, "")) {
Message memory newMessage;
newMessage.role = "user";
Expand Down
Loading