Skip to content

Commit

Permalink
Add ClientInterface and use request header instead of query parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
erdemkose committed Dec 24, 2023
1 parent 6fae105 commit aa176c7
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 26 deletions.
19 changes: 8 additions & 11 deletions src/Client.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace GeminiAPI;

use GeminiAPI\ClientInterface as GeminiClientInterface;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\GenerateContentRequest;
Expand All @@ -15,20 +16,19 @@
use Http\Discovery\Psr17FactoryDiscovery;
use Http\Discovery\Psr18ClientDiscovery;
use Psr\Http\Client\ClientExceptionInterface;
use Psr\Http\Client\ClientInterface;
use Psr\Http\Client\ClientInterface as HttpClientInterface;
use Psr\Http\Message\RequestFactoryInterface;
use Psr\Http\Message\StreamFactoryInterface;
use RuntimeException;

use function json_decode;

class Client
class Client implements GeminiClientInterface
{
private string $baseUrl = 'https://generativelanguage.googleapis.com';

public function __construct(
private readonly string $apiKey,
private ?ClientInterface $client = null,
private ?HttpClientInterface $client = null,
private ?RequestFactoryInterface $requestFactory = null,
private ?StreamFactoryInterface $streamFactory = null,
) {
Expand Down Expand Up @@ -106,13 +106,10 @@ private function doRequest(RequestInterface $request): string
throw new RuntimeException('Missing client or factory for Gemini API operation');
}

$uri = sprintf(
'%s/v1/%s?key=%s',
$this->baseUrl,
$request->getOperation(),
$this->apiKey,
);
$httpRequest = $this->requestFactory->createRequest($request->getHttpMethod(), $uri);
$uri = "{$this->baseUrl}/v1/{$request->getOperation()}";
$httpRequest = $this->requestFactory
->createRequest($request->getHttpMethod(), $uri)
->withAddedHeader(self::API_KEY_HEADER_NAME, $this->apiKey);

$payload = $request->getHttpPayload();
if (!empty($payload)) {
Expand Down
26 changes: 26 additions & 0 deletions src/ClientInterface.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<?php

declare(strict_types=1);

namespace GeminiAPI;

use GeminiAPI\Enums\ModelName;
use GeminiAPI\Requests\CountTokensRequest;
use GeminiAPI\Requests\GenerateContentRequest;
use GeminiAPI\Responses\CountTokensResponse;
use GeminiAPI\Responses\GenerateContentResponse;
use GeminiAPI\Responses\ListModelsResponse;

/**
* @since v1.1.0
*/
interface ClientInterface
{
public const API_KEY_HEADER_NAME = 'x-goog-api-key';

public function countTokens(CountTokensRequest $request): CountTokensResponse;
public function generateContent(GenerateContentRequest $request): GenerateContentResponse;
public function generativeModel(ModelName $modelName): GenerativeModel;
public function listModels(): ListModelsResponse;
public function withBaseUrl(string $baseUrl): self;
}
37 changes: 22 additions & 15 deletions tests/Unit/ClientTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace GeminiAPI\Tests\Unit;

use GeminiAPI\Client;
use GeminiAPI\ClientInterface as GeminiAPIClientInterface;
use GeminiAPI\Enums\ModelName;
use GeminiAPI\GenerativeModel;
use GeminiAPI\Requests\CountTokensRequest;
Expand All @@ -14,7 +15,7 @@
use GuzzleHttp\Psr7\Response;
use GuzzleHttp\Psr7\Utils;
use PHPUnit\Framework\TestCase;
use Psr\Http\Client\ClientInterface;
use Psr\Http\Client\ClientInterface as HttpClientInterface;
use Psr\Http\Message\RequestFactoryInterface;
use Psr\Http\Message\StreamFactoryInterface;

Expand All @@ -24,7 +25,7 @@ public function testConstructor()
{
$client = new Client(
'test-api-key',
$this->createMock(ClientInterface::class),
$this->createMock(HttpClientInterface::class),
);
self::assertInstanceOf(Client::class, $client);
}
Expand All @@ -33,7 +34,7 @@ public function testWithBaseUrl()
{
$client = new Client(
'test-api-key',
$this->createMock(ClientInterface::class),
$this->createMock(HttpClientInterface::class),
);
$client = $client->withBaseUrl('test-base-url');
self::assertInstanceOf(Client::class, $client);
Expand All @@ -43,7 +44,7 @@ public function testGeminiPro()
{
$client = new Client(
'test-api-key',
$this->createMock(ClientInterface::class),
$this->createMock(HttpClientInterface::class),
);
$model = $client->geminiPro();
self::assertInstanceOf(GenerativeModel::class, $model);
Expand All @@ -54,7 +55,7 @@ public function testGeminiProVision()
{
$client = new Client(
'test-api-key',
$this->createMock(ClientInterface::class),
$this->createMock(HttpClientInterface::class),
);
$model = $client->geminiProVision();
self::assertInstanceOf(GenerativeModel::class, $model);
Expand All @@ -65,7 +66,7 @@ public function testGenerativeModel()
{
$client = new Client(
'test-api-key',
$this->createMock(ClientInterface::class),
$this->createMock(HttpClientInterface::class),
);
$model = $client->generativeModel(ModelName::Embedding);
self::assertInstanceOf(GenerativeModel::class, $model);
Expand All @@ -76,7 +77,7 @@ public function testGenerateContent()
{
$httpRequest = new Request(
'POST',
'https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent?key=test-api-key',
'https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent',
);
$httpResponse = new Response(
body: <<<BODY
Expand Down Expand Up @@ -115,17 +116,19 @@ public function testGenerateContent()
$requestFactory = $this->createMock(RequestFactoryInterface::class);
$requestFactory->expects(self::once())
->method('createRequest')
->with('POST', 'https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent?key=test-api-key')
->with('POST', 'https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent')
->willReturn($httpRequest);

$httpRequest = $httpRequest->withAddedHeader(GeminiAPIClientInterface::API_KEY_HEADER_NAME, 'test-api-key');

$stream = Utils::streamFor('{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"this is a text"}],"role":"user"}]}');
$streamFactory = $this->createMock(StreamFactoryInterface::class);
$streamFactory->expects(self::once())
->method('createStream')
->with('{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"this is a text"}],"role":"user"}]}')
->willReturn($stream);

$httpClient = $this->createMock(ClientInterface::class);
$httpClient = $this->createMock(HttpClientInterface::class);
$httpClient->expects(self::once())
->method('sendRequest')
->with($httpRequest->withBody($stream))
Expand All @@ -149,7 +152,7 @@ public function testCountTokens()
{
$httpRequest = new Request(
'POST',
'https://generativelanguage.googleapis.com/v1/models/gemini-pro:countTokens?key=test-api-key',
'https://generativelanguage.googleapis.com/v1/models/gemini-pro:countTokens',
);
$httpResponse = new Response(
body: <<<BODY
Expand All @@ -161,17 +164,19 @@ public function testCountTokens()
$requestFactory = $this->createMock(RequestFactoryInterface::class);
$requestFactory->expects(self::once())
->method('createRequest')
->with('POST', 'https://generativelanguage.googleapis.com/v1/models/gemini-pro:countTokens?key=test-api-key')
->with('POST', 'https://generativelanguage.googleapis.com/v1/models/gemini-pro:countTokens')
->willReturn($httpRequest);

$httpRequest = $httpRequest->withAddedHeader(GeminiAPIClientInterface::API_KEY_HEADER_NAME, 'test-api-key');

$stream = Utils::streamFor('{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"this is a text"}],"role":"user"}]}');
$streamFactory = $this->createMock(StreamFactoryInterface::class);
$streamFactory->expects(self::once())
->method('createStream')
->with('{"model":"models\/gemini-pro","contents":[{"parts":[{"text":"this is a text"}],"role":"user"}]}')
->willReturn($stream);

$httpClient = $this->createMock(ClientInterface::class);
$httpClient = $this->createMock(HttpClientInterface::class);
$httpClient->expects(self::once())
->method('sendRequest')
->with($httpRequest->withBody($stream))
Expand All @@ -195,7 +200,7 @@ public function testListModels()
{
$httpRequest = new Request(
'GET',
'https://generativelanguage.googleapis.com/v1/models?key=test-api-key',
'https://generativelanguage.googleapis.com/v1/models',
);
$httpResponse = new Response(
body: <<<BODY
Expand Down Expand Up @@ -238,10 +243,12 @@ public function testListModels()
$requestFactory = $this->createMock(RequestFactoryInterface::class);
$requestFactory->expects(self::once())
->method('createRequest')
->with('GET', 'https://generativelanguage.googleapis.com/v1/models?key=test-api-key')
->with('GET', 'https://generativelanguage.googleapis.com/v1/models')
->willReturn($httpRequest);

$httpClient = $this->createMock(ClientInterface::class);
$httpRequest = $httpRequest->withAddedHeader(GeminiAPIClientInterface::API_KEY_HEADER_NAME, 'test-api-key');

$httpClient = $this->createMock(HttpClientInterface::class);
$httpClient->expects(self::once())
->method('sendRequest')
->with($httpRequest)
Expand Down

0 comments on commit aa176c7

Please sign in to comment.