Skip to content

Commit

Permalink
Use generative-ai-swift tests in Vertex AI (#12585)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Mar 19, 2024
1 parent aa474f5 commit c68e1e3
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 130 deletions.
7 changes: 3 additions & 4 deletions .github/workflows/vertexai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ jobs:
- name: Xcode
run: sudo xcode-select -s /Applications/${{ matrix.xcode }}.app/Contents/Developer
- name: Initialize xcodebuild
run: xcodebuild -list
# TODO: Add unit tests and switch from `spmbuildonly` to `spm`.
- name: Build
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAI ${{ matrix.target }} spmbuildonly
run: scripts/setup_spm_tests.sh
- name: Build and run tests
run: scripts/third_party/travis/retry.sh scripts/build.sh FirebaseVertexAIUnit ${{ matrix.target }} spm
11 changes: 9 additions & 2 deletions FirebaseVertexAI/Tests/Unit/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
// limitations under the License.

import Foundation
@testable import GoogleGenerativeAI
import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
final class ChatTests: XCTestCase {
var urlSession: URLSession!
Expand Down Expand Up @@ -46,7 +47,13 @@ final class ChatTests: XCTestCase {
return (response, fileURL.lines)
}

let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
let model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
urlSession: urlSession
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = chat.sendMessageStream(input)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"reason": "API_KEY_INVALID",
"domain": "googleapis.com",
"metadata": {
"service": "generativelanguage.googleapis.com"
"service": "staging-firebaseml.sandbox.googleapis.com"
}
},
{
Expand Down

This file was deleted.

This file was deleted.

90 changes: 42 additions & 48 deletions FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

@testable import GoogleGenerativeAI
import XCTest

@testable import FirebaseVertexAI

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, *)
final class GenerativeModelTests: XCTestCase {
let testPrompt = "What sorts of questions can I ask you?"
Expand All @@ -32,7 +33,13 @@ final class GenerativeModelTests: XCTestCase {
let configuration = URLSessionConfiguration.default
configuration.protocolClasses = [MockURLProtocol.self]
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
urlSession: urlSession
)
}

override func tearDown() {
Expand Down Expand Up @@ -163,6 +170,8 @@ final class GenerativeModelTests: XCTestCase {
// Model name is prefixed with "models/".
name: "models/test-model",
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -181,10 +190,13 @@ final class GenerativeModelTests: XCTestCase {
do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.internalError; no error thrown.")
} catch let GenerateContentError.invalidAPIKey(message) {
XCTAssertEqual(message, "API key not valid. Please pass a valid API key.")
} catch let GenerateContentError.internalError(error as RPCError) {
XCTAssertEqual(error.httpResponseCode, 400)
XCTAssertEqual(error.status, .invalidArgument)
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
return
} catch {
XCTFail("Should throw GenerateContentError.invalidAPIKey; error thrown: \(error)")
XCTFail("Should throw GenerateContentError.internalError(RPCError); error thrown: \(error)")
}
}

Expand Down Expand Up @@ -342,24 +354,6 @@ final class GenerativeModelTests: XCTestCase {
}
}

func testGenerateContent_failure_unsupportedUserLocation() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unsupported-user-location",
withExtension: "json",
statusCode: 400
)

do {
_ = try await model.generateContent(testPrompt)
XCTFail("Should throw GenerateContentError.unsupportedUserLocation; no error thrown.")
} catch GenerateContentError.unsupportedUserLocation {
return
}

XCTFail("Expected an unsupported user location error.")
}

func testGenerateContent_failure_nonHTTPResponse() async throws {
MockURLProtocol.requestHandler = try nonHTTPRequestHandler()

Expand Down Expand Up @@ -468,6 +462,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -490,8 +485,10 @@ final class GenerativeModelTests: XCTestCase {
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
} catch GenerateContentError.invalidAPIKey {
// invalidAPIKey error is as expected, nothing else to check.
} catch let GenerateContentError.internalError(error as RPCError) {
XCTAssertEqual(error.httpResponseCode, 400)
XCTAssertEqual(error.status, .invalidArgument)
XCTAssertEqual(error.message, "API key not valid. Please pass a valid API key.")
return
}

Expand Down Expand Up @@ -747,26 +744,6 @@ final class GenerativeModelTests: XCTestCase {
XCTFail("Expected an internal decoding error.")
}

func testGenerateContentStream_failure_unsupportedUserLocation() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-unsupported-user-location",
withExtension: "json",
statusCode: 400
)

let stream = model.generateContentStream(testPrompt)
do {
for try await content in stream {
XCTFail("Unexpected content in stream: \(content)")
}
} catch GenerateContentError.unsupportedUserLocation {
return
}

XCTFail("Expected an unsupported user location error.")
}

func testGenerateContentStream_requestOptions_customTimeout() async throws {
let expectedTimeout = 150.0
MockURLProtocol
Expand All @@ -780,6 +757,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand Down Expand Up @@ -837,6 +815,7 @@ final class GenerativeModelTests: XCTestCase {
name: "my-model",
apiKey: "API_KEY",
requestOptions: requestOptions,
appCheck: nil,
urlSession: urlSession
)

Expand All @@ -851,23 +830,38 @@ final class GenerativeModelTests: XCTestCase {
let modelName = "my-model"
let modelResourceName = "models/\(modelName)"

model = GenerativeModel(name: modelName, apiKey: "API_KEY")
model = GenerativeModel(
name: modelName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_modelsPrefix() async throws {
let modelResourceName = "models/my-model"

model = GenerativeModel(name: modelResourceName, apiKey: "API_KEY")
model = GenerativeModel(
name: modelResourceName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_tunedModelsPrefix() async throws {
let tunedModelResourceName = "tunedModels/my-model"

model = GenerativeModel(name: tunedModelResourceName, apiKey: "API_KEY")
model = GenerativeModel(
name: tunedModelResourceName,
apiKey: "API_KEY",
requestOptions: RequestOptions(),
appCheck: nil
)

XCTAssertEqual(model.modelResourceName, tunedModelResourceName)
}
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/PartsRepresentableTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import CoreGraphics
import CoreImage
import GoogleGenerativeAI
import FirebaseVertexAI
import XCTest
#if canImport(UIKit)
import UIKit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import GoogleGenerativeAI
import FirebaseCore
import FirebaseVertexAI
import XCTest
#if canImport(AppKit)
import AppKit // For NSImage extensions.
Expand All @@ -21,8 +22,9 @@ import XCTest
#endif

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
final class GoogleGenerativeAITests: XCTestCase {
final class VertexAIAPITests: XCTestCase {
func codeSamples() async throws {
let app = FirebaseApp.app()
let config = GenerationConfig(temperature: 0.2,
topP: 0.1,
topK: 16,
Expand All @@ -32,16 +34,40 @@ final class GoogleGenerativeAITests: XCTestCase {
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Permutations without optional arguments.
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY")
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", safetySettings: filters)
let _ = GenerativeModel(name: "gemini-1.0-pro", apiKey: "API_KEY", generationConfig: config)

// All arguments passed.
let genAI = GenerativeModel(name: "gemini-1.0-pro",
apiKey: "API_KEY",
generationConfig: config, // Optional
safetySettings: filters // Optional
// TODO: Change `genAI` to `_` when safetySettings and generationConfig are added to public API.
let genAI = VertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
let _ = VertexAI.generativeModel(
app: app!,
modelName: "gemini-1.0-pro",
location: "us-central1"
)

// TODO: Add safetySettings to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// safetySettings: filters
// )
// TODO: Add generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let _ = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config
// )

// All arguments passed.
// TODO: Add safetySettings and generationConfig to public API.
// TODO: Add permutation with `app` specified.
// let genAI = VertexAI.generativeModel(
// modelName: "gemini-1.0-pro",
// location: "us-central1",
// generationConfig: config, // Optional
// safetySettings: filters // Optional
// )

// Full Typed Usage
let pngData = Data() // ....
let contents = [ModelContent(role: "user",
Expand Down
9 changes: 9 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,15 @@ let package = Package(
],
path: "FirebaseVertexAI/Sources"
),
.testTarget(
name: "FirebaseVertexAIUnit",
dependencies: ["FirebaseVertexAI"],
path: "FirebaseVertexAI/Tests/Unit",
resources: [
.process("CountTokenResponses"),
.process("GenerateContentResponses"),
]
),
] + firestoreTargets(),
cLanguageStandard: .c99,
cxxLanguageStandard: CXXLanguageStandard.gnucxx14
Expand Down
Loading

0 comments on commit c68e1e3

Please sign in to comment.