Skip to content

Commit

Permalink
Refactoring around Query and tests for it.
Browse files Browse the repository at this point in the history
  • Loading branch information
cincuranet committed Oct 1, 2024
1 parent dabb4e3 commit 2c1f1a4
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 43 deletions.
8 changes: 4 additions & 4 deletions ChromaDB.Client.Tests/CollectionClientGetTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -283,24 +283,24 @@ public async Task GetWhereOperatorIncludeAll()
}

[Test]
public async Task GetWhereDocumentIncludeDocument()
public async Task GetWhereDocumentIncludeDocuments()
{
using var httpClient = new ChromaDBHttpClient(ConfigurationOptions);
var client = await Init(httpClient);
var result = await client.Get(new CollectionGetRequest()
{
WhereDocument = new Dictionary<string, object> { { "$not_contains", Doc2[^1] } },
Include = [],
Include = ["documents"],
});
Assert.That(result.Success, Is.True);
Assert.That(result.Data!.Count, Is.EqualTo(1));
Assert.That(result.Data![0].Id, Is.EqualTo(Id1));
Assert.That(result.Data![0].Embeddings, Is.Null);
Assert.That(result.Data![0].Metadata, Is.Null);
Assert.That(result.Data![0].Document, Is.Null);
Assert.That(result.Data![0].Document, Is.EqualTo(Doc1));
}

static readonly string Id1 = "Id1";
static readonly string Id1 = "id1";
static readonly string Id2 = "id2";
static readonly List<float> Embeddings1 = [1, 2, 3];
static readonly List<float> Embeddings2 = [1.4f, 1.5f, 99.33f];
Expand Down
103 changes: 103 additions & 0 deletions ChromaDB.Client.Tests/CollectionClientQueryTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using ChromaDB.Client.Models.Requests;
using ChromaDB.Client.Services.Implementations;
using NUnit.Framework;

namespace ChromaDB.Client.Tests;

[TestFixture]
public class CollectionClientQueryTests : ChromaDBTestsBase
{
[Test]
public async Task SimpleQuerySingle()
{
using var httpClient = new ChromaDBHttpClient(ConfigurationOptions);
var client = await Init(httpClient);
var result = await client.Query(new CollectionQueryRequest()
{
QueryEmbeddings = [Embeddings1],
Include = ["distances", "embeddings"],
});
Assert.That(result.Success, Is.True);
Assert.That(result.Data!.Count, Is.EqualTo(1));
Assert.That(result.Data![0].Count, Is.EqualTo(2));
Assert.That(result.Data![0].Select(x => x.Distance), Has.Some.Not.EqualTo(0));
Assert.That(result.Data![0][0].Embeddings, Is.Not.Null.And.Not.Empty);
Assert.That(result.Data![0][1].Embeddings, Is.Not.Null.And.Not.Empty);
}

[Test]
public async Task SimpleQueryMultiple()
{
using var httpClient = new ChromaDBHttpClient(ConfigurationOptions);
var client = await Init(httpClient);
var result = await client.Query(new CollectionQueryRequest()
{
QueryEmbeddings = [Embeddings1, Embeddings2],
Include = ["distances", "embeddings"],
});
Assert.That(result.Success, Is.True);
Assert.That(result.Data!.Count, Is.EqualTo(2));
Assert.That(result.Data![0].Count, Is.EqualTo(2));
Assert.That(result.Data![0].Select(x => x.Distance), Has.Some.Not.EqualTo(0));
Assert.That(result.Data![0][0].Embeddings, Is.Not.Null.And.Not.Empty);
Assert.That(result.Data![0][1].Embeddings, Is.Not.Null.And.Not.Empty);
Assert.That(result.Data![1].Count, Is.EqualTo(2));
Assert.That(result.Data![1].Select(x => x.Distance), Has.Some.Not.EqualTo(0));
Assert.That(result.Data![1][0].Embeddings, Is.Not.Null.And.Not.Empty);
Assert.That(result.Data![1][1].Embeddings, Is.Not.Null.And.Not.Empty);
}

[Test]
public async Task QuerySingleNResults1()
{
using var httpClient = new ChromaDBHttpClient(ConfigurationOptions);
var client = await Init(httpClient);
var result = await client.Query(new CollectionQueryRequest()
{
QueryEmbeddings = [Embeddings1],
Include = ["distances", "embeddings"],
NResults = 1,
});
Assert.That(result.Success, Is.True);
Assert.That(result.Data!.Count, Is.EqualTo(1));
Assert.That(result.Data![0].Count, Is.EqualTo(1));
}

static readonly string Id1 = "id1";
static readonly string Id2 = "id2";
static readonly List<float> Embeddings1 = [1, 2, 3];
static readonly List<float> Embeddings2 = [1.4f, 1.5f, 99.33f];
static readonly string MetadataKey1 = "key1";
static readonly string MetadataKey2 = "key2";
static readonly Dictionary<string, object> Metadata1 = new()
{
{ MetadataKey1, "1" },
{ MetadataKey2, 1 },
};
static readonly Dictionary<string, object> Metadata2 = new()
{
{ MetadataKey1, "2" },
{ MetadataKey2, 2 },
};
static readonly string Doc1 = "Doc1";
static readonly string Doc2 = "Doc2";

async Task<ChromaDBCollectionClient> Init(ChromaDBHttpClient httpClient)
{
var name = $"collection{Random.Shared.Next()}";
var client = new ChromaDBClient(ConfigurationOptions, httpClient);
var collectionResponse = await client.CreateCollection(new CreateCollectionRequest { Name = name });
Assert.That(collectionResponse.Success, Is.True);
var collection = collectionResponse.Data!;
var collectionClient = new ChromaDBCollectionClient(collection, httpClient);
var addResponse = await collectionClient.Add(new CollectionAddRequest()
{
Ids = [Id1, Id2],
Embeddings = [Embeddings1, Embeddings2],
Metadatas = [Metadata1, Metadata2],
Documents = [Doc1, Doc2],
});
Assert.That(addResponse.Success, Is.True);
return collectionClient;
}
}
2 changes: 1 addition & 1 deletion ChromaDB.Client/Common/Mappers/CollectionEntryMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace ChromaDB.Client.Common.Mappers;

public static class CollectionEntryMapper
{
public static List<CollectionEntry> Map(this CollectionEntriesResponse response)
public static List<CollectionEntry> Map(this CollectionEntriesGetResponse response)
{
return response.Ids
.Select((id, i) => new CollectionEntry(id)
Expand Down
24 changes: 24 additions & 0 deletions ChromaDB.Client/Common/Mappers/CollectionQueryEntryMapper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using ChromaDB.Client.Models;
using ChromaDB.Client.Models.Responses;

namespace ChromaDB.Client.Common.Mappers;

public static class CollectionQueryEntryMapper
{
public static List<List<CollectionQueryEntry>> Map(this CollectionEntriesQueryResponse response)
{
return response.Ids
.Select((_, i) => response.Ids[i]
.Select((id, j) => new CollectionQueryEntry(id)
{
Distance = response.Distances[i][j],
Metadata = response.Metadatas?[i][j],
Embeddings = response.Embeddings?[i][j],
Document = response.Documents?[i][j],
Uris = response.Uris?[i][j],
Data = response.Data,
})
.ToList())
.ToList();
}
}
17 changes: 17 additions & 0 deletions ChromaDB.Client/Models/CollectionQueryEntry.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace ChromaDB.Client.Models;

public class CollectionQueryEntry
{
public string Id { get; }
public float Distance { get; init; }
public Dictionary<string, object>? Metadata { get; init; }
public List<float>? Embeddings { get; init; }
public string? Document { get; init; }
public List<string?>? Uris { get; init; }
public dynamic? Data { get; init; }

public CollectionQueryEntry(string id)
{
Id = id;
}
}
21 changes: 6 additions & 15 deletions ChromaDB.Client/Models/Requests/CollectionQueryRequest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,18 @@ namespace ChromaDB.Client.Models.Requests;

public class CollectionQueryRequest
{
[JsonPropertyName("ids")]
public required List<string> Ids { get; init; }
[JsonPropertyName("query_embeddings")]
public List<List<float>>? QueryEmbeddings { get; init; }

[JsonPropertyName("n_results")]
public int NResults { get; init; } = 10;

[JsonPropertyName("where")]
public IDictionary<string, object>? Where { get; init; }

[JsonPropertyName("where_document")]
public IDictionary<string, object>? WhereDocument { get; init; }

[JsonPropertyName("query_embeddings")]
public required List<List<float>> QueryEmbeddings { get; init; }

[JsonPropertyName("sort")]
public string? Sort { get; init; }

[JsonPropertyName("limit")]
public int? Limit { get; init; }

[JsonPropertyName("offset")]
public int? Offset { get; init; }

[JsonPropertyName("include")]
public required List<string> Include { get; init; }
public List<string> Include { get; init; } = ["metadatas", "documents", "distances"];
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

namespace ChromaDB.Client.Models.Responses;

public class CollectionEntriesResponse
public class CollectionEntriesGetResponse
{
[JsonPropertyName("ids")]
public required List<string> Ids { get; init; }
Expand Down
19 changes: 11 additions & 8 deletions ChromaDB.Client/Models/Responses/CollectionEntriesQueryResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,21 @@ public class CollectionEntriesQueryResponse
[JsonPropertyName("ids")]
public required List<List<string>> Ids { get; init; }

[JsonPropertyName("embeddings")]
public required List<List<List<float>>> Embeddings { get; init; }
[JsonPropertyName("distances")]
public required List<List<float>> Distances { get; init; }

[JsonPropertyName("metadatas")]
public required List<List<Dictionary<string, object>>> Metadatas { get; init; }
public required List<List<Dictionary<string, object>>>? Metadatas { get; init; }

[JsonPropertyName("embeddings")]
public required List<List<List<float>>>? Embeddings { get; init; }

[JsonPropertyName("documents")]
public required List<List<string?>>? Documents { get; init; }

[JsonPropertyName("uris")]
public required List<List<string?>> Uris { get; init; }
public required List<List<List<string?>>>? Uris { get; init; }

[JsonPropertyName("data")]
public required dynamic Data { get; init; }

[JsonPropertyName("distances")]
public required List<List<float>> Distances { get; init; }
public required dynamic? Data { get; init; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,18 @@ public async Task<BaseResponse<List<CollectionEntry>>> Get(CollectionGetRequest
{
RequestQueryParams requestParams = new RequestQueryParams()
.Insert("{collection_id}", _collection.Id);
BaseResponse<CollectionEntriesResponse> response = await _httpClient.Post<CollectionGetRequest, CollectionEntriesResponse>("collections/{collection_id}/get", request, requestParams);
BaseResponse<CollectionEntriesGetResponse> response = await _httpClient.Post<CollectionGetRequest, CollectionEntriesGetResponse>("collections/{collection_id}/get", request, requestParams);
List<CollectionEntry> entries = response.Data?.Map() ?? [];
return new BaseResponse<List<CollectionEntry>>(entries, response.StatusCode, response.ReasonPhrase);
}

public async Task<BaseResponse<CollectionEntriesQueryResponse>> Query(CollectionQueryRequest request)
public async Task<BaseResponse<List<List<CollectionQueryEntry>>>> Query(CollectionQueryRequest request)
{
RequestQueryParams requestParams = new RequestQueryParams()
.Insert("{collection_id}", _collection.Id);
return await _httpClient.Post<CollectionQueryRequest, CollectionEntriesQueryResponse>("collections/{collection_id}/query", request, requestParams);
var response = await _httpClient.Post<CollectionQueryRequest, CollectionEntriesQueryResponse>("collections/{collection_id}/query", request, requestParams);
List<List<CollectionQueryEntry>> entries = response.Data?.Map() ?? [];
return new BaseResponse<List<List<CollectionQueryEntry>>>(entries, response.StatusCode, response.ReasonPhrase);
}

public async Task<BaseResponse<BaseResponse.None>> Add(CollectionAddRequest request)
Expand Down Expand Up @@ -75,7 +77,7 @@ public async Task<BaseResponse<List<CollectionEntry>>> Peek(CollectionPeekReques
{
RequestQueryParams requestParams = new RequestQueryParams()
.Insert("{collection_id}", _collection.Id);
BaseResponse<CollectionEntriesResponse> response = await _httpClient.Post<CollectionPeekRequest, CollectionEntriesResponse>("collections/{collection_id}/get", request, requestParams);
BaseResponse<CollectionEntriesGetResponse> response = await _httpClient.Post<CollectionPeekRequest, CollectionEntriesGetResponse>("collections/{collection_id}/get", request, requestParams);
List<CollectionEntry> entries = response.Data?.Map() ?? [];
return new BaseResponse<List<CollectionEntry>>(entries, response.StatusCode, response.ReasonPhrase);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using ChromaDB.Client.Models;
using ChromaDB.Client.Models.Requests;
using ChromaDB.Client.Models.Responses;

namespace ChromaDB.Client.Services.Interfaces;

Expand All @@ -9,7 +8,7 @@ public interface IChromaDBCollectionClient
Collection Collection { get; }

Task<BaseResponse<List<CollectionEntry>>> Get(CollectionGetRequest request);
Task<BaseResponse<CollectionEntriesQueryResponse>> Query(CollectionQueryRequest request);
Task<BaseResponse<List<List<CollectionQueryEntry>>>> Query(CollectionQueryRequest request);
Task<BaseResponse<BaseResponse.None>> Add(CollectionAddRequest request);
Task<BaseResponse<BaseResponse.None>> Update(CollectionUpdateRequest request);
Task<BaseResponse<BaseResponse.None>> Upsert(CollectionUpsertRequest request);
Expand Down
18 changes: 10 additions & 8 deletions Samples/ChromaDB.Client.Sample/Program.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
using ChromaDB.Client;
using ChromaDB.Client.Models;
using ChromaDB.Client.Models.Requests;
using ChromaDB.Client.Models.Responses;
using ChromaDB.Client.Services.Implementations;
using ChromaDB.Client.Services.Interfaces;

Expand All @@ -24,23 +23,26 @@
{
foreach (var entry in getResponse.Data!)
{
Console.WriteLine(entry.Id);
Console.WriteLine($"ID: {entry.Id}");
}
}

BaseResponse<CollectionEntriesQueryResponse> queryResponse = await string5Client.Query(new CollectionQueryRequest()
BaseResponse<List<List<CollectionQueryEntry>>> queryResponse = await string5Client.Query(new CollectionQueryRequest()
{
Ids = ["340a36ad-c38a-406c-be38-250174aee5a4"],
Include = ["metadatas", "documents", "embeddings"],
QueryEmbeddings =
[
[1f, 0.5f, 0f, -0.5f, -1f]
[1f, 0.5f, 0f, -0.5f, -1f],
[1.5f, 0f, 2f, -1f, -1.5f],
],
Include = ["metadatas", "distances"],
});
if (queryResponse.Success)
{
foreach (var id in queryResponse.Data!.Ids)
foreach (var item in queryResponse.Data!)
{
Console.WriteLine(id);
foreach (var entry in item)
{
Console.WriteLine($"ID: {entry.Id} | Distance: {entry.Distance}");
}
}
}

0 comments on commit 2c1f1a4

Please sign in to comment.