Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DisjunctionMaxQuery shouldn't depend on disjunct order for equals checks #783

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/Lucene.Net.Tests.QueryParser/Xml/TestParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,21 @@ public void TestBooleanQueryXML()
public void TestDisjunctionMaxQueryXML()
{
Query q = Parse("DisjunctionMaxQuery.xml");
assertTrue(q is DisjunctionMaxQuery);
DisjunctionMaxQuery d = (DisjunctionMaxQuery)q;
assertEquals(0.0f, d.TieBreakerMultiplier, 0.0001f);
assertEquals(2, d.Disjuncts.size());
DisjunctionMaxQuery ndq = (DisjunctionMaxQuery)d.Disjuncts[1];
assertEquals(1.2f, ndq.TieBreakerMultiplier, 0.0001f);
assertEquals(1, ndq.Disjuncts.size());
// assertTrue(q is DisjunctionMaxQuery);
// DisjunctionMaxQuery d = (DisjunctionMaxQuery)q;
// assertEquals(0.0f, d.TieBreakerMultiplier, 0.0001f);
// assertEquals(2, d.Disjuncts.size());
// DisjunctionMaxQuery ndq = (DisjunctionMaxQuery)d.Disjuncts[1];
// assertEquals(1.2f, ndq.TieBreakerMultiplier, 0.0001f);
// assertEquals(1, ndq.Disjuncts.size());
Query expected =
new DisjunctionMaxQuery(
new List<Query>{
new TermQuery(new Term("a", "merger")),
new DisjunctionMaxQuery(
new List<Query>{new TermQuery(new Term("b", "verger"))}, 1.2f)},
0.0f);
assertEquals(expected, q);
}

[Test]
Expand Down
13 changes: 13 additions & 0 deletions src/Lucene.Net.Tests/Search/TestDisjunctionMaxQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
using System.Globalization;
using Lucene.Net.Documents;
using Lucene.Net.Index.Extensions;
using Lucene.Net.Support;
using NUnit.Framework;
using System.Collections.Generic;
using Assert = Lucene.Net.TestFramework.Assert;
using Console = Lucene.Net.Util.SystemConsole;

Expand Down Expand Up @@ -535,6 +537,17 @@ public virtual void TestBooleanSpanQuery()
directory.Dispose();
}

[Test]
public void TestDisjunctOrderAndEquals()
{
// the order that disjuncts are provided in should not matter for equals() comparisons
Query sub1 = Tq("hed", "albino");
Query sub2 = Tq("hed", "elephant");
Query q1 = new DisjunctionMaxQuery(new List<Query>{sub1, sub2}, 1.0f);
Query q2 = new DisjunctionMaxQuery(new List<Query>{sub2, sub1}, 1.0f);
assertEquals(q1, q2);
}

/// <summary>
/// macro </summary>
protected internal virtual Query Tq(string f, string t)
Expand Down
114 changes: 114 additions & 0 deletions src/Lucene.Net.Tests/Search/TestMultiset.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
using Lucene.Net.Search;
using Lucene.Net.Support;
using Lucene.Net.Util;
using NUnit.Framework;
using J2N.Collections.Generic;

namespace Lucene.Net
{
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for Additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

public class TestMultiset : LuceneTestCase
{
[Test]
public void TestDuplicatesMatter() {
Multiset<int> s1 = new Multiset<int>();
Multiset<int> s2 = new Multiset<int>();
assertEquals(s1.size(), s2.size());
assertEquals(s1, s2);

s1.Add(42);
s2.Add(42);
assertEquals(s1, s2);

s2.Add(42);
assertFalse(s1.equals(s2));

s1.Add(43);
s1.Add(43);
s2.Add(43);
assertEquals(s1.size(), s2.size());
assertFalse(s1.equals(s2));
}

private static Dictionary<T, int> ToCountMap<T>(Multiset<T> set) {
Dictionary<T, int> map = new();
int recomputedSize = 0;

foreach (T element in set) {
Add(map, element);
recomputedSize += 1;
}
assertEquals(set.toString(), recomputedSize, set.size());
return map;
}

private static void Add<T>(Dictionary<T, int> map, T element)
{
map.TryGetValue(element, out int value);
map.Put(element, value + 1);
}

private static void Remove<T>(Dictionary<T, int> map, T element) {
if (element is null)
{
return;
}

map.TryGetValue(element, out int cnt);
switch (cnt)
{
case 0:
return;
case 1:
map.Remove(element);
break;
default:
map.Put((T)element, cnt - 1);
break;
}
}

[Test]
public void TestRandom() {
Dictionary<int, int> reference = new();
Multiset<int> multiset = new();
int iters = AtLeast(100);
for (int i = 0; i < iters; ++i) {
int value = Random.Next(10);
switch (Random.Next(10)) {
case 0:
case 1:
case 2:
Remove(reference, value);
multiset.Remove(value);
break;
case 3:
reference.Clear();
multiset.Clear();
break;
default:
Add(reference, value);
multiset.Add(value);
break;
}
assertEquals(reference, ToCountMap(multiset));
}
}
}
}
61 changes: 26 additions & 35 deletions src/Lucene.Net/Search/DisjunctionMaxQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ namespace Lucene.Net.Search
/// <para/>
/// Collection initializer note: To create and populate a <see cref="DisjunctionMaxQuery"/>
/// in a single statement, you can use the following example as a guide:
///
///
/// <code>
/// var disjunctionMaxQuery = new DisjunctionMaxQuery(0.1f) {
/// new TermQuery(new Term("field1", "albino")),
/// new TermQuery(new Term("field1", "albino")),
/// new TermQuery(new Term("field2", "elephant"))
/// };
/// </code>
Expand All @@ -61,7 +61,8 @@ public class DisjunctionMaxQuery : Query, IEnumerable<Query>
/// <summary>
/// The subqueries
/// </summary>
private IList<Query> disjuncts = new JCG.List<Query>();
// private IList<Query> disjuncts = new JCG.List<Query>();
private Multiset<Query> disjuncts = new();

/// <summary>
/// Multiple of the non-max disjunct scores added into our final score. Non-zero values support tie-breaking.
Expand Down Expand Up @@ -119,7 +120,7 @@ IEnumerator IEnumerable.GetEnumerator()
}

/// <returns> The disjuncts. </returns>
public virtual IList<Query> Disjuncts => disjuncts;
public virtual Multiset<Query> Disjuncts => disjuncts;

/// <returns> Tie breaker value for multiple matches. </returns>
public virtual float TieBreakerMultiplier => tieBreakerMultiplier;
Expand Down Expand Up @@ -244,9 +245,12 @@ public override Weight CreateWeight(IndexSearcher searcher)
public override Query Rewrite(IndexReader reader)
{
int numDisjunctions = disjuncts.Count;
var it = disjuncts.GetEnumerator();
if (numDisjunctions == 1)
{
Query singleton = disjuncts[0];
it.MoveNext();
Query singleton = it.Current;
it.Dispose();
Query result = singleton.Rewrite(reader);
if (Boost != 1.0f)
{
Expand All @@ -258,28 +262,17 @@ public override Query Rewrite(IndexReader reader)
}
return result;
}
DisjunctionMaxQuery clone = null;
for (int i = 0; i < numDisjunctions; i++)
// DisjunctionMaxQuery clone = null;
bool actuallyRewritten = false;
IList<Query> rewrittenDisjuncts = new JCG.List<Query>();;
foreach (var sub in disjuncts)
{
Query clause = disjuncts[i];
Query rewrite = clause.Rewrite(reader);
if (rewrite != clause)
{
if (clone is null)
{
clone = (DisjunctionMaxQuery)this.Clone();
}
clone.disjuncts[i] = rewrite;
}
}
if (clone != null)
{
return clone;
}
else
{
return this;
Query rewrittenSub = sub.Rewrite(reader);
actuallyRewritten |= !rewrittenSub.Equals(sub);
rewrittenDisjuncts.Add(rewrittenSub);
}

return actuallyRewritten ? new DisjunctionMaxQuery(rewrittenDisjuncts, tieBreakerMultiplier) : this;
}

/// <summary>
Expand All @@ -288,7 +281,7 @@ public override Query Rewrite(IndexReader reader)
public override object Clone()
{
DisjunctionMaxQuery clone = (DisjunctionMaxQuery)base.Clone();
clone.disjuncts = new JCG.List<Query>(this.disjuncts);
clone.disjuncts = new Multiset<Query>(this.disjuncts);
return clone;
}

Expand All @@ -313,10 +306,8 @@ public override string ToString(string field)
{
StringBuilder buffer = new StringBuilder();
buffer.Append('(');
int numDisjunctions = disjuncts.Count;
for (int i = 0; i < numDisjunctions; i++)
foreach (var subquery in disjuncts)
{
Query subquery = disjuncts[i];
if (subquery is BooleanQuery) // wrap sub-bools in parens
{
buffer.Append('(');
Expand All @@ -327,11 +318,11 @@ public override string ToString(string field)
{
buffer.Append(subquery.ToString(field));
}
if (i != numDisjunctions - 1)
{
buffer.Append(" | ");
}

buffer.Append(" | ");
}

buffer.Remove(buffer.Length - 3, 3);
buffer.Append(')');
if (tieBreakerMultiplier != 0.0f)
{
Expand Down Expand Up @@ -368,8 +359,8 @@ public override bool Equals(object o)
/// <returns> the hash code </returns>
public override int GetHashCode()
{
return J2N.BitConversion.SingleToInt32Bits(Boost)
+ J2N.BitConversion.SingleToInt32Bits(tieBreakerMultiplier)
return J2N.BitConversion.SingleToInt32Bits(Boost)
+ J2N.BitConversion.SingleToInt32Bits(tieBreakerMultiplier)
+ disjuncts.GetHashCode();
}
}
Expand Down
Loading