Skip to content

Commit

Permalink
Merge pull request #3 from ajaynomics/ajaynomics/text-classification
Browse files Browse the repository at this point in the history
support for text classification models
  • Loading branch information
ajaynomics authored Jan 22, 2024
2 parents 5d3d86b + 75c4ac3 commit cbc97ae
Show file tree
Hide file tree
Showing 12 changed files with 125 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
cloudflare-ai (0.3.0)
cloudflare-ai (0.4.0)
activemodel (~> 7.0)
activesupport (~> 7.0)
event_stream_parser (~> 1.0)
Expand Down
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ Thiis gem provides a client that wraps around [Cloudflare's REST API](https://de
client = Cloudflare::AI::Client.new(account_id: ENV["CLOUDFLARE_ACCOUNT_ID"], api_token: ENV["CLOUDFLARE_API_TOKEN"])
```

### Model selection
The model name is an optional parameter to every one of the client methods described below.
For example, if an example is documented as
```ruby
result = client.complete(prompt: "Hello my name is")
```
this is implicitly the same as
```ruby
result = client.complete(prompt: "Hello my name is", model: "@cf/meta/llama-2-7b-chat-fp16")
```
The full list of supported models is available here: [models.rb](lib/cloudflare/ai/models.rb).
More information is available [in the cloudflare documentation](https://developers.cloudflare.com/workers-ai/models/).
The default model used is the first enumerated model in the applicable set in [models.rb](lib/cloudflare/ai/models.rb).

### Text generation (chat / scoped prompt)
```ruby
Expand Down Expand Up @@ -120,7 +133,16 @@ result = client.embed(text: ["Hello", "World"])
```

#### Result object
All invocations of the `embedding` methods return a `Cloudflare::AI::Results::TextEmbedding`.
All invocations of the `embed` methods return a `Cloudflare::AI::Results::TextEmbedding`.

### Text classification
```ruby
result = client.classify(text: "You meanie!")
p result.result # => [{"label"=>"NEGATIVE", "score"=>0.6647962927818298}, {"label"=>"POSITIVE", "score"=>0.3352036774158478}]
```

#### Result object
All invocations of the `classify` methods return a `Cloudflare::AI::Results::TextClassification`.

# Logging

Expand Down
7 changes: 7 additions & 0 deletions lib/cloudflare/ai/client.rb
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ def chat(messages:, model_name: default_text_generation_model_name, &block)
post_streamable_request(url, payload, &block)
end

def classify(text:, model_name: Cloudflare::AI::Models.text_classification.first)
url = service_url_for(account_id: account_id, model_name: model_name)
payload = {text: text}.to_json

Cloudflare::AI::Results::TextClassification.new(connection.post(url, payload).body)
end

def complete(prompt:, model_name: default_text_generation_model_name, &block)
url = service_url_for(account_id: account_id, model_name: model_name)
stream = block ? true : false
Expand Down
3 changes: 3 additions & 0 deletions lib/cloudflare/ai/results/text_classification.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class Cloudflare::AI::Results::TextClassification < Cloudflare::AI::Result
# Empty seam kept for consistency with other result objects that have more complexity.
end
2 changes: 1 addition & 1 deletion lib/cloudflare/ai/version.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

module Cloudflare
module AI
VERSION = "0.3.0"
VERSION = "0.4.0"
end
end
6 changes: 5 additions & 1 deletion test/cloudflare/ai/clients/test_helpers.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ def set_service_url_for_model(model_name)
@url = @client.send(:service_url_for, account_id: @account_id, model_name: model_name)
end

def stub_response_for_unsuccessful_completion
def stub_successful_response
stub_request(:post, @url).to_return(status: 200, body: {success: true}.to_json)
end

def stub_unsuccessful_response
stub_request(:post, @url)
.to_return(status: 200, body: {success: false, errors: [{code: 10000, message: "Some error"}]}.to_json)
end
Expand Down
43 changes: 43 additions & 0 deletions test/cloudflare/ai/clients/text_classification_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
require "test_helper"
require_relative "test_helpers"

module Cloudflare::AI::Clients
class TextClassificationTest < Minitest::Test
include Cloudflare::AI::Clients::TestHelpers

def test_successful_request
stub_successful_response
response = @client.classify(text: "This is a happy thought", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextClassification
assert response.success?
end

def test_unsuccessful_request
stub_unsuccessful_response
response = @client.classify(text: "This won't work", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextClassification
assert response.failure?
end

def test_uses_default_model_if_not_provided
model_name = Cloudflare::AI::Models.text_classification.first
@url = @client.send(:service_url_for, account_id: @account_id, model_name: model_name)

stub_successful_response
assert @client.classify(text: "This will run with the default model") # Webmock will raise an error if the request was to wrong model
end

private

def stub_successful_response
stub_request(:post, @url)
.to_return(status: 200, body: {success: true}.to_json)
end

def default_model_name
Cloudflare::AI::Models.text_embedding.first
end
end
end
19 changes: 7 additions & 12 deletions test/cloudflare/ai/clients/text_embedding_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,27 @@
require_relative "test_helpers"

module Cloudflare::AI::Clients
class TextEmbeddiingTest < Minitest::Test
class TextEmbeddingTest < Minitest::Test
include Cloudflare::AI::Clients::TestHelpers

def test_successful_request_with_string_input
stub_response_for_successful_embedding_of_string
stub_successful_response
response = @client.embed(text: "hello", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextEmbedding
assert response.success?
end

def test_successful_request_with_array_input
stub_response_for_successful_embedding_of_array
stub_successful_response
response = @client.embed(text: ["hello", "jello"], model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextEmbedding
assert response.success?
end

def test_unsuccessful_request
stub_response_for_unsuccessful_completion
stub_unsuccessful_response
response = @client.embed(text: "hello", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextEmbedding
Expand All @@ -33,20 +33,15 @@ def test_uses_default_model_if_not_provided
model_name = Cloudflare::AI::Models.text_embedding.first
@url = @client.send(:service_url_for, account_id: @account_id, model_name: model_name)

stub_response_for_successful_embedding_of_string
stub_successful_response
assert @client.embed(text: "hello") # Webmock will raise an error if the request was to wrong model
end

private

def stub_response_for_successful_embedding_of_string
def stub_successful_response
stub_request(:post, @url)
.to_return(status: 200, body: {result: {shape: [1, 4], data: [0.1, 0.2, 0.3, 0.4]}, success: true}.to_json)
end

def stub_response_for_successful_embedding_of_array
stub_request(:post, @url)
.to_return(status: 200, body: {result: {shape: [2, 4], data: [[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4]]}, success: true}.to_json)
.to_return(status: 200, body: {success: true}.to_json)
end

def default_model_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ class ChatCompletionTest < Minitest::Test
include Cloudflare::AI::Clients::TextGeneration::TestHelpers

def test_successful_request
stub_response_for_successful_completion
stub_successful_response
response = @client.chat(messages: messages_fixture, model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextGeneration
assert response.success?
end

def test_unsuccessful_request
stub_response_for_unsuccessful_completion
stub_unsuccessful_response
response = @client.chat(messages: messages_fixture, model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextGeneration
Expand All @@ -28,13 +28,13 @@ def test_uses_default_model_if_not_provided
model_name = Cloudflare::AI::Models.text_generation.first
@url = @client.send(:service_url_for, account_id: @account_id, model_name: model_name)

stub_response_for_successful_completion
stub_successful_response
assert @client.chat(messages: messages_fixture) # Webmock will raise an error if the request was to wrong model
end

def test_handle_streaming_from_cloudflare_to_client_if_block_given
set_service_url_for_model(Cloudflare::AI::Models.text_generation.first)
stub_response_for_successful_completion
stub_successful_response

inner_streaming_response_from_cloudflare_handled = false
outer_streaming_response_relayed = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ class PromptCompletionTest < Minitest::Test
include Cloudflare::AI::Clients::TextGeneration::TestHelpers

def test_successful_request
stub_response_for_successful_completion
stub_successful_response
response = @client.complete(prompt: "Happy song", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextGeneration
assert response.success?
end

def test_unsuccessful_request
stub_response_for_unsuccessful_completion
stub_unsuccessful_response
response = @client.complete(prompt: "Sad song", model_name: @model_name)

assert response.is_a? Cloudflare::AI::Results::TextGeneration
Expand All @@ -26,14 +26,14 @@ def test_unsuccessful_request

def test_uses_default_model_if_not_provided
set_service_url_for_model(default_model_name)
stub_response_for_successful_completion
stub_successful_response

assert @client.complete(prompt: "Default song") # Webmock will raise an error if the request was to wrong model
end

def test_handle_streaming_from_cloudflare_to_client_if_block_given
set_service_url_for_model(default_model_name)
stub_response_for_successful_completion
stub_successful_response

inner_streaming_data_received_from_cloudflare = false
outer_streaming_data_relayed_to_client_block = false
Expand Down
10 changes: 0 additions & 10 deletions test/cloudflare/ai/clients/text_generation/test_helpers.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@ module TextGeneration
module TestHelpers
private

def stub_response_for_successful_completion(response: "Happy song")
stub_request(:post, @url)
.to_return(status: 200, body: {result: {response: response}, success: true}.to_json)
end

def stub_response_for_unsuccessful_completion
stub_request(:post, @url)
.to_return(status: 200, body: {success: false, errors: [{code: 10000, message: "Some error"}]}.to_json)
end

def default_model_name
Cloudflare::AI::Models.text_generation.first
end
Expand Down
27 changes: 27 additions & 0 deletions test/cloudflare/ai/results/text_classification_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
require "test_helper"

class Cloudflare::AI::Results::TextClassificationTest < Minitest::Test
def test_successful_result
result = Cloudflare::AI::Results::TextClassification.new(successful_response_json)
assert result.success?
refute result.failure?

assert_equal successful_response_json["result"], result.result
end

def test_to_json
result = Cloudflare::AI::Results::TextClassification.new(successful_response_json)
assert_equal successful_response_json.to_json, result.to_json
end

private

def successful_response_json
{
result: [{label: "positive", score: 0.9999998807907104}, {label: "negative", score: 1.1920928955078125e-7}],
success: true,
errors: [],
messages: []
}.deep_stringify_keys
end
end

0 comments on commit cbc97ae

Please sign in to comment.