-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3 from ajaynomics/ajaynomics/text-classification
support for text classification models
- Loading branch information
Showing
12 changed files
with
125 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
class Cloudflare::AI::Results::TextClassification < Cloudflare::AI::Result | ||
# Empty seam kept for consistency with other result objects that have more complexity. | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,6 @@ | |
|
||
module Cloudflare | ||
module AI | ||
VERSION = "0.3.0" | ||
VERSION = "0.4.0" | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
require "test_helper" | ||
require_relative "test_helpers" | ||
|
||
module Cloudflare::AI::Clients | ||
class TextClassificationTest < Minitest::Test | ||
include Cloudflare::AI::Clients::TestHelpers | ||
|
||
def test_successful_request | ||
stub_successful_response | ||
response = @client.classify(text: "This is a happy thought", model_name: @model_name) | ||
|
||
assert response.is_a? Cloudflare::AI::Results::TextClassification | ||
assert response.success? | ||
end | ||
|
||
def test_unsuccessful_request | ||
stub_unsuccessful_response | ||
response = @client.classify(text: "This won't work", model_name: @model_name) | ||
|
||
assert response.is_a? Cloudflare::AI::Results::TextClassification | ||
assert response.failure? | ||
end | ||
|
||
def test_uses_default_model_if_not_provided | ||
model_name = Cloudflare::AI::Models.text_classification.first | ||
@url = @client.send(:service_url_for, account_id: @account_id, model_name: model_name) | ||
|
||
stub_successful_response | ||
assert @client.classify(text: "This will run with the default model") # Webmock will raise an error if the request was to wrong model | ||
end | ||
|
||
private | ||
|
||
def stub_successful_response | ||
stub_request(:post, @url) | ||
.to_return(status: 200, body: {success: true}.to_json) | ||
end | ||
|
||
def default_model_name | ||
Cloudflare::AI::Models.text_embedding.first | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
require "test_helper" | ||
|
||
class Cloudflare::AI::Results::TextClassificationTest < Minitest::Test | ||
def test_successful_result | ||
result = Cloudflare::AI::Results::TextClassification.new(successful_response_json) | ||
assert result.success? | ||
refute result.failure? | ||
|
||
assert_equal successful_response_json["result"], result.result | ||
end | ||
|
||
def test_to_json | ||
result = Cloudflare::AI::Results::TextClassification.new(successful_response_json) | ||
assert_equal successful_response_json.to_json, result.to_json | ||
end | ||
|
||
private | ||
|
||
def successful_response_json | ||
{ | ||
result: [{label: "positive", score: 0.9999998807907104}, {label: "negative", score: 1.1920928955078125e-7}], | ||
success: true, | ||
errors: [], | ||
messages: [] | ||
}.deep_stringify_keys | ||
end | ||
end |