From c166253540d81fcd02ce485573ee3562c5ca1ffc Mon Sep 17 00:00:00 2001 From: congqixia Date: Mon, 9 Sep 2024 22:49:06 +0800 Subject: [PATCH] fix: [2.4] Make legacy non-lexicographic branch break swtich (#36126) Cherry-pick from master pr: #36125 Related to #35941 Previous PR: #36034 This patch makes the switch branching logic correct and make the unit test work for cases which does not select the whole dataset. Signed-off-by: Congqi Xia --- internal/core/src/index/StringIndexMarisa.cpp | 4 +- internal/core/unittest/test_string_index.cpp | 47 +++++++++++++++---- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/internal/core/src/index/StringIndexMarisa.cpp b/internal/core/src/index/StringIndexMarisa.cpp index ba714e4e2587d..557b604877958 100644 --- a/internal/core/src/index/StringIndexMarisa.cpp +++ b/internal/core/src/index/StringIndexMarisa.cpp @@ -458,7 +458,6 @@ StringIndexMarisa::Range(std::string value, OpType op) { } ids.push_back(agent.key().id()); } - break; } else { // lexicographic order is not guaranteed, check all values while (trie_.predictive_search(agent)) { @@ -469,6 +468,7 @@ StringIndexMarisa::Range(std::string value, OpType op) { } }; } + break; } case OpType::LessEqual: { if (in_lexico_order) { @@ -480,7 +480,6 @@ StringIndexMarisa::Range(std::string value, OpType op) { } ids.push_back(agent.key().id()); } - break; } else { // lexicographic order is not guaranteed, check all values while (trie_.predictive_search(agent)) { @@ -491,6 +490,7 @@ StringIndexMarisa::Range(std::string value, OpType op) { } }; } + break; } default: PanicInfo( diff --git a/internal/core/unittest/test_string_index.cpp b/internal/core/unittest/test_string_index.cpp index f26a59645cbe1..b58b88d618738 100644 --- a/internal/core/unittest/test_string_index.cpp +++ b/internal/core/unittest/test_string_index.cpp @@ -21,6 +21,7 @@ #include "test_utils/indexbuilder_test_utils.h" #include "test_utils/AssertUtils.h" #include +#include #include "test_utils/storage_test_utils.h" constexpr int64_t nb = 100; @@ -83,39 +84,67 @@ TEST_F(StringIndexMarisaTest, NotIn) { TEST_F(StringIndexMarisaTest, Range) { auto index = milvus::index::CreateStringIndexMarisa(); std::vector strings(nb); + std::vector counts(10); for (int i = 0; i < nb; ++i) { - strings[i] = std::to_string(std::rand() % 10); + int val = std::rand() % 10; + counts[val]++; + strings[i] = std::to_string(val); } index->Build(nb, strings.data()); { + // [0...9] auto bitset = index->Range("0", milvus::OpType::GreaterEqual); ASSERT_EQ(bitset.size(), nb); ASSERT_EQ(Count(bitset), nb); } { - auto bitset = index->Range("90", milvus::OpType::LessThan); + // [5...9] + int expect = std::accumulate(counts.begin() + 5, counts.end(), 0); + auto bitset = index->Range("5", milvus::OpType::GreaterEqual); ASSERT_EQ(bitset.size(), nb); - ASSERT_EQ(Count(bitset), nb); + ASSERT_EQ(Count(bitset), expect); } { - auto bitset = index->Range("9", milvus::OpType::LessEqual); + // [6...9] + int expect = std::accumulate(counts.begin() + 6, counts.end(), 0); + auto bitset = index->Range("5", milvus::OpType::GreaterThan); ASSERT_EQ(bitset.size(), nb); - ASSERT_EQ(Count(bitset), nb); + ASSERT_EQ(Count(bitset), expect); } { - auto bitset = index->Range("0", true, "9", true); + // [0...3] + int expect = std::accumulate(counts.begin(), counts.begin() + 4, 0); + auto bitset = index->Range("4", milvus::OpType::LessThan); ASSERT_EQ(bitset.size(), nb); - ASSERT_EQ(Count(bitset), nb); + ASSERT_EQ(Count(bitset), expect); } { - auto bitset = index->Range("0", true, "90", false); + // [0...4] + int expect = std::accumulate(counts.begin(), counts.begin() + 5, 0); + auto bitset = index->Range("4", milvus::OpType::LessEqual); ASSERT_EQ(bitset.size(), nb); - ASSERT_EQ(Count(bitset), nb); + ASSERT_EQ(Count(bitset), expect); + } + + { + // [2...8] + int expect = std::accumulate(counts.begin() + 2, counts.begin() + 9, 0); + auto bitset = index->Range("2", true, "8", true); + ASSERT_EQ(bitset.size(), nb); + ASSERT_EQ(Count(bitset), expect); + } + + { + // [0...8] + int expect = std::accumulate(counts.begin(), counts.begin() + 9, 0); + auto bitset = index->Range("0", true, "9", false); + ASSERT_EQ(bitset.size(), nb); + ASSERT_EQ(Count(bitset), expect); } }