From fae388a21a6e54417454b82fc7a5e1ca6c90ff1b Mon Sep 17 00:00:00 2001 From: Masajiro Iwasaki Date: Thu, 31 Oct 2024 09:17:38 +0900 Subject: [PATCH] resolve the issue #171 --- VERSION | 2 +- lib/NGT/Common.h | 2 ++ lib/NGT/Index.cpp | 36 ++++++++++++++++++------ lib/NGT/Index.h | 10 +++---- lib/NGT/NGTQ/QbgCli.cpp | 3 +- lib/NGT/PrimitiveComparator.h | 53 +++++++++++++++++++++++------------ 6 files changed, 73 insertions(+), 33 deletions(-) diff --git a/VERSION b/VERSION index 276cbf9..f90b1af 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.3.0 +2.3.2 diff --git a/lib/NGT/Common.h b/lib/NGT/Common.h index e45fc14..dd91742 100644 --- a/lib/NGT/Common.h +++ b/lib/NGT/Common.h @@ -59,6 +59,7 @@ namespace NGT { class quint8 { public: + quint8(){} quint8(uint8_t v):value(v){} quint8 &operator=(uint8_t v) { value = v; return *this; } operator uint8_t() const { return value; } @@ -67,6 +68,7 @@ namespace NGT { }; class qsint8 { public: + qsint8(){} qsint8(int8_t v):value(v){} qsint8 &operator=(int8_t v) { value = v; return *this; } operator int8_t() const { return value; } diff --git a/lib/NGT/Index.cpp b/lib/NGT/Index.cpp index a09f51d..abbed55 100644 --- a/lib/NGT/Index.cpp +++ b/lib/NGT/Index.cpp @@ -142,7 +142,17 @@ NGT::Index::createGraphAndTree(const string &database, NGT::Property &prop, cons StdOstreamRedirector redirector(redirect); redirector.begin(); try { - loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + if (idx->getObjectSpace().isQintObjectType()) { + idx->saveIndex(database); + idx->close(); + auto append = true; + auto refinement = false; + if (!dataFile.empty()) { + appendFromTextObjectFile(database, dataFile, dataSize, append, refinement, prop.threadPoolSize); + } + } else { + loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + } } catch(Exception &err) { delete idx; redirector.end(); @@ -169,7 +179,17 @@ NGT::Index::createGraph(const string &database, NGT::Property &prop, const strin StdOstreamRedirector redirector(redirect); redirector.begin(); try { - loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + if (idx->getObjectSpace().isQintObjectType()) { + idx->saveIndex(database); + idx->close(); + auto append = true; + auto refinement = false; + if (!dataFile.empty()) { + appendFromTextObjectFile(database, dataFile, dataSize, append, refinement, prop.threadPoolSize); + } + } else { + loadAndCreateIndex(*idx, database, dataFile, prop.threadPoolSize, dataSize); + } } catch(Exception &err) { delete idx; redirector.end(); @@ -248,10 +268,10 @@ NGT::Index::append(const string &database, const float *data, size_t dataSize, s } void -NGT::Index::appendFromRefinementObjectFile(const std::string &indexPath) { +NGT::Index::appendFromRefinementObjectFile(const std::string &indexPath, size_t threadSize) { NGT::Index index(indexPath); index.appendFromRefinementObjectFile(); - index.createIndex(); + index.createIndex(threadSize); index.save(); index.close(); } @@ -439,12 +459,12 @@ NGT::Index::insertFromRefinementObjectFile() { void NGT::Index::appendFromTextObjectFile(const std::string &indexPath, const std::string &data, size_t dataSize, - bool append, bool refinement) { + bool append, bool refinement, size_t threadSize) { //#define APPEND_TEST NGT::Index index(indexPath); index.appendFromTextObjectFile(data, dataSize, append, refinement); - index.createIndex(); + index.createIndex(threadSize); index.save(); index.close(); } @@ -612,10 +632,10 @@ NGT::Index::appendFromTextObjectFile(const std::string &data, size_t dataSize, b void NGT::Index::appendFromBinaryObjectFile(const std::string &indexPath, const std::string &data, - size_t dataSize, bool append, bool refinement) { + size_t dataSize, bool append, bool refinement, size_t threadSize) { NGT::Index index(indexPath); index.appendFromBinaryObjectFile(data, dataSize, append, refinement); - index.createIndex(); + index.createIndex(threadSize); index.save(); index.close(); } diff --git a/lib/NGT/Index.h b/lib/NGT/Index.h index df73076..276e3e5 100644 --- a/lib/NGT/Index.h +++ b/lib/NGT/Index.h @@ -552,14 +552,14 @@ namespace NGT { #endif static void append(const std::string &index, const std::string &dataFile, size_t threadSize, size_t dataSize); static void append(const std::string &index, const float *data, size_t dataSize, size_t threadSize); - static void appendFromRefinementObjectFile(const std::string &index); + static void appendFromRefinementObjectFile(const std::string &index, size_t threadSize = 0); void appendFromRefinementObjectFile(); void insertFromRefinementObjectFile(); - static void appendFromTextObjectFile(const std::string &index, const std::string &data, - size_t dataSize, bool append = true, bool refinement = false); + static void appendFromTextObjectFile(const std::string &index, const std::string &data, size_t dataSize, + bool append = true, bool refinement = false, size_t threadSize = 0); void appendFromTextObjectFile(const std::string &data, size_t dataSize, bool append = true, bool refinement = false); - static void appendFromBinaryObjectFile(const std::string &index, const std::string &data, - size_t dataSize, bool append = true, bool refinement = false); + static void appendFromBinaryObjectFile(const std::string &index, const std::string &data, size_t dataSize, + bool append = true, bool refinement = false, size_t threadSize = 0); void appendFromBinaryObjectFile(const std::string &data, size_t dataSize, bool apend = true, bool refinement = false); static void remove(const std::string &database, std::vector &objects, bool force = false); static void exportIndex(const std::string &database, const std::string &file); diff --git a/lib/NGT/NGTQ/QbgCli.cpp b/lib/NGT/NGTQ/QbgCli.cpp index 17ad2b7..c8e74f6 100644 --- a/lib/NGT/NGTQ/QbgCli.cpp +++ b/lib/NGT/NGTQ/QbgCli.cpp @@ -97,7 +97,8 @@ class QbgCliBuildParameters : public QBG::BuildParameters { transform(clusterDataType.begin(), clusterDataType.end(), clusterDataType.begin(), ::tolower); if (clusterDataType == "-" || clusterDataType == "pq4") { creation.localClusterDataType = NGTQ::ClusterDataTypePQ4; - } else if (clusterDataType == "sqsu8" || clusterDataType == "sqs8" || clusterDataType == "sq8") { + } else if (clusterDataType == "sqsu8" || clusterDataType == "sqs8" || clusterDataType == "sq8" || + clusterDataType == "qsu8" || clusterDataType == "qs8") { creation.localClusterDataType = NGTQ::ClusterDataTypeSQSU8; } else if (clusterDataType == "nq") { creation.localClusterDataType = NGTQ::ClusterDataTypeNQ; diff --git a/lib/NGT/PrimitiveComparator.h b/lib/NGT/PrimitiveComparator.h index 669fead..1836db7 100644 --- a/lib/NGT/PrimitiveComparator.h +++ b/lib/NGT/PrimitiveComparator.h @@ -111,21 +111,21 @@ namespace NGT { #if defined(NGT_NO_AVX) template inline static double compareL2(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) { - const OBJECT_TYPE *last = a + size; - const OBJECT_TYPE *lastgroup = last - 3; + auto *last = a + size; + auto *lastgroup = last - 3; COMPARE_TYPE diff0, diff1, diff2, diff3; double d = 0.0; while (a < lastgroup) { - diff0 = static_cast(a[0] - b[0]); - diff1 = static_cast(a[1] - b[1]); - diff2 = static_cast(a[2] - b[2]); - diff3 = static_cast(a[3] - b[3]); + diff0 = static_cast(a[0]) - b[0]; + diff1 = static_cast(a[1]) - b[1]; + diff2 = static_cast(a[2]) - b[2]; + diff3 = static_cast(a[3]) - b[3]; d += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; a += 4; b += 4; } while (a < last) { - diff0 = static_cast(*a++ - *b++); + diff0 = static_cast(*a++) - static_cast(*b++); d += diff0 * diff0; } return sqrt(static_cast(d)); @@ -148,6 +148,9 @@ namespace NGT { return compareL2(a, b, size); } #endif + inline static double compareL2(const quint8 *a, const quint8 *b, size_t size) { + return compareL2(a, b, size); + } #else inline static double compareL2(const float *a, const float *b, size_t size) { const float *last = a + size; @@ -407,7 +410,7 @@ namespace NGT { inline static double compareL2(const qsint8 *a, const quint8 *b, size_t size) { NGTThrowException("Not supported."); - return 0.00; + return 0.0; } template @@ -422,15 +425,15 @@ namespace NGT { template static double compareL1(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) { - const OBJECT_TYPE *last = a + size; - const OBJECT_TYPE *lastgroup = last - 3; + auto *last = a + size; + auto *lastgroup = last - 3; COMPARE_TYPE diff0, diff1, diff2, diff3; double d = 0.0; while (a < lastgroup) { - diff0 = (COMPARE_TYPE)(a[0] - b[0]); - diff1 = (COMPARE_TYPE)(a[1] - b[1]); - diff2 = (COMPARE_TYPE)(a[2] - b[2]); - diff3 = (COMPARE_TYPE)(a[3] - b[3]); + diff0 = (COMPARE_TYPE)(a[0]) - b[0]; + diff1 = (COMPARE_TYPE)(a[1]) - b[1]; + diff2 = (COMPARE_TYPE)(a[2]) - b[2]; + diff3 = (COMPARE_TYPE)(a[3]) - b[3]; d += absolute(diff0) + absolute(diff1) + absolute(diff2) + absolute(diff3); a += 4; b += 4; @@ -464,6 +467,12 @@ namespace NGT { return compareL1(a, b, size); } #endif + inline static double compareL1(const quint8 *a, const quint8 *b, size_t size) { + return compareL1(a, b, size); + } + inline static double compareL1(const qsint8 *a, const qsint8 *b, size_t size) { + return compareL1(a, b, size); + } #else inline static double compareL1(const float *a, const float *b, size_t size) { __m256 sum = _mm256_setzero_ps(); @@ -732,6 +741,14 @@ namespace NGT { return sum; } + inline static double compareDotProduct(const qsint8 *a, const quint8 *b, size_t size) { + double sum = 0.0; + for (size_t loc = 0; loc < size; loc++) { + sum += static_cast(a[loc]) * static_cast(b[loc]); + } + return sum; + } + template inline static double compareCosine(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) { double normA = 0.0; @@ -1153,6 +1170,7 @@ namespace NGT { inline static double compareCosine(const qsint8 *a, const qsint8 *b, size_t size) { return compareCosine(reinterpret_cast(a), reinterpret_cast(b), size); } +#endif // #if defined(NGT_NO_AVX) inline static double compareNormalizedCosineSimilarity(const float *a, const float *b, size_t size) { auto v = 1.0 - compareDotProduct(a, b, size); @@ -1182,7 +1200,6 @@ namespace NGT { auto v = max - compareDotProduct(a, b, size); return v; } -#endif // #if defined(NGT_NO_AVX) template inline static double compareAngleDistance(const OBJECT_TYPE *a, const OBJECT_TYPE *b, size_t size) { @@ -1512,14 +1529,14 @@ namespace NGT { class L1Qsint8 { public: inline static double compare(const void *a, const void *b, size_t size) { - NGTThrowException("Not supported."); + return PrimitiveComparator::compareL1((const qsint8*)a, (const qsint8*)b, size); } }; class CosineSimilarityQsint8 { public: inline static double compare(const void *a, const void *b, size_t size) { - NGTThrowException("Not supported."); + return PrimitiveComparator::compareCosineSimilarity((const qsint8*)a, (const qsint8*)b, size); } }; @@ -1564,7 +1581,7 @@ namespace NGT { class NormalizedCosineSimilarityQsint8 { public: inline static double compare(const void *a, const void *b, size_t size) { - float max = 127.0 * 127.0 * size; + float max = 127.0 * 255.0; auto d = max - PrimitiveComparator::compareDotProduct((const qsint8*)a, (const qsint8*)b, size); return d; }