-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add tests for chatgpt-ruby gem methods
- Loading branch information
Showing
4 changed files
with
215 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
require 'minitest/autorun' | ||
require 'chatgpt/client' | ||
|
||
class TestChatGPTClassify < Minitest::Test | ||
def setup | ||
api_key = ENV['API_KEY'] | ||
@client = ChatGPT::Client.new(api_key) | ||
end | ||
|
||
def test_classify_returns_valid_response | ||
text = "Is this a valid question?" | ||
|
||
response_body = { | ||
"data" => [ | ||
{ | ||
"label": "Valid Question" | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.classify(text) | ||
refute_nil response | ||
assert_equal "Valid Question", response | ||
end | ||
end | ||
|
||
def test_classify_with_custom_params | ||
text = "Is this a valid question?" | ||
params = { model: 'text-davinci-002' } | ||
|
||
response_body = { | ||
"data" => [ | ||
{ | ||
"label": "Valid Question" | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.classify(text, params) | ||
refute_nil response | ||
assert_equal "Valid Question", response | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
require 'minitest/autorun' | ||
require 'chatgpt/client' | ||
|
||
class TestChatGPTGenerateAnswers < Minitest::Test | ||
def setup | ||
api_key = ENV['API_KEY'] | ||
@client = ChatGPT::Client.new(api_key) | ||
end | ||
|
||
def test_generate_answers_returns_valid_response | ||
prompt = "What is the capital of France?" | ||
documents = ["Paris is the capital of France.", "France is a country in Europe."] | ||
|
||
response_body = { | ||
"data" => [ | ||
{ | ||
"answer" => "Paris" | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.generate_answers(prompt, documents) | ||
refute_nil response | ||
assert_equal "Paris", response | ||
end | ||
end | ||
|
||
def test_generate_answers_with_custom_params | ||
prompt = "What is the capital of France?" | ||
documents = ["Paris is the capital of France.", "France is a country in Europe."] | ||
params = { model: 'text-davinci-002', max_tokens: 10 } | ||
|
||
response_body = { | ||
"data" => [ | ||
{ | ||
"answer" => "Paris" | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.generate_answers(prompt, documents, params) | ||
refute_nil response | ||
assert_equal "Paris", response | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
require 'minitest/autorun' | ||
require 'chatgpt/client' | ||
|
||
class TestChatGPTGenerateSummaries < Minitest::Test | ||
def setup | ||
api_key = ENV['API_KEY'] | ||
@client = ChatGPT::Client.new(api_key) | ||
end | ||
|
||
def test_generate_summaries_returns_valid_response | ||
documents = ["This is a long text about apples.", "This is another long text about oranges."] | ||
|
||
response_body = { | ||
"choices" => [ | ||
{ | ||
"text" => "Summary: Apples and oranges." | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.generate_summaries(documents) | ||
refute_nil response | ||
assert_equal "Summary: Apples and oranges.", response | ||
end | ||
end | ||
|
||
def test_generate_summaries_with_custom_params | ||
documents = ["This is a long text about apples.", "This is another long text about oranges."] | ||
params = { | ||
model: 'text-davinci-002', | ||
max_tokens: 30, | ||
temperature: 0.8, | ||
top_p: 1.0, | ||
frequency_penalty: 0.1, | ||
presence_penalty: 0.1 | ||
} | ||
|
||
response_body = { | ||
"choices" => [ | ||
{ | ||
"text" => "Summary: Apples and oranges." | ||
} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.generate_summaries(documents, params) | ||
refute_nil response | ||
assert_equal "Summary: Apples and oranges.", response | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
require 'minitest/autorun' | ||
require 'chatgpt/client' | ||
|
||
class TestChatGPTSearch < Minitest::Test | ||
def setup | ||
api_key = ENV['API_KEY'] | ||
@client = ChatGPT::Client.new(api_key) | ||
end | ||
|
||
def test_search_returns_valid_response | ||
documents = ["Apple", "Orange", "Banana", "Grape"] | ||
query = "fruit" | ||
|
||
response_body = { | ||
"data" => [ | ||
{"id": "0", "score": 1.0}, | ||
{"id": "1", "score": 0.5}, | ||
{"id": "2", "score": 0.3}, | ||
{"id": "3", "score": 0.2} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.search(documents, query) | ||
assert_equal 4, response.length | ||
assert response[0]['id'] | ||
assert response[0]['score'] | ||
end | ||
end | ||
|
||
def test_search_with_custom_params | ||
documents = ["Apple", "Orange", "Banana", "Grape"] | ||
query = "fruit" | ||
params = { engine: 'davinci', max_rerank: 100 } | ||
|
||
response_body = { | ||
"data" => [ | ||
{"id": "0", "score": 1.0}, | ||
{"id": "1", "score": 0.5}, | ||
{"id": "2", "score": 0.3}, | ||
{"id": "3", "score": 0.2} | ||
] | ||
} | ||
|
||
response_object = RestClient::Response.new(response_body.to_json) | ||
|
||
RestClient.stub :post, response_object do | ||
response = @client.search(documents, query, params) | ||
assert_equal 4, response.length | ||
assert response[0]['id'] | ||
assert response[0]['score'] | ||
end | ||
end | ||
end |