From 4c55ed496e4e97647b006fee219758cc1dc23663 Mon Sep 17 00:00:00 2001 From: Kathleen DeRusso Date: Wed, 16 Oct 2024 16:15:15 -0400 Subject: [PATCH] Applying Jim's suggested changes --- .../retriever/CompoundRetrieverBuilder.java | 5 + .../search/retriever/RetrieverBuilder.java | 4 + .../retriever/RetrieverBuilderWrapper.java | 136 +++++++++++++ .../retriever/QueryRuleRetrieverBuilder.java | 190 ++++++------------ .../TextSimilarityRankRetrieverBuilder.java | 16 -- 5 files changed, 208 insertions(+), 143 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 8b6a2c4e7b078..203f3facfd025 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -163,6 +163,11 @@ public final QueryBuilder topDocsQuery() { throw new IllegalStateException("Should not be called, missing a rewrite?"); } + @Override + public final QueryBuilder explainQuery() { + throw new IllegalStateException("Should not be called, missing a rewrite?"); + } + @Override public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { throw new IllegalStateException("Should not be called, missing a rewrite?"); diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 882d44adb79c3..fa8d2ee1cbefc 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -218,6 +218,10 @@ public void setRankDocs(RankDoc[] rankDocs) { this.rankDocs = rankDocs; } + public RankDoc[] getRankDocs() { + return rankDocs; + } + /** * Gets the filters for this retriever. */ diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java new file mode 100644 index 0000000000000..7c9e71dbf41f1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilderWrapper.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; + +/** + * A wrapper that can be used to modify the behaviour of an existing {@link RetrieverBuilder}. + */ +public abstract class RetrieverBuilderWrapper extends RetrieverBuilder { + protected final RetrieverBuilder in; + + protected RetrieverBuilderWrapper(RetrieverBuilder in) { + this.in = in; + } + + protected abstract T clone(RetrieverBuilder sub); + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + var inRewrite = in.rewrite(ctx); + if (inRewrite != in) { + return clone(inRewrite); + } + return this; + } + + @Override + public QueryBuilder topDocsQuery() { + return in.topDocsQuery(); + } + + @Override + public RetrieverBuilder minScore(Float minScore) { + return in.minScore(minScore); + } + + @Override + public List getPreFilterQueryBuilders() { + return in.preFilterQueryBuilders; + } + + @Override + public ActionRequestValidationException validate( + SearchSourceBuilder source, + ActionRequestValidationException validationException, + boolean allowPartialSearchResults + ) { + return in.validate(source, validationException, allowPartialSearchResults); + } + + @Override + public RetrieverBuilder retrieverName(String retrieverName) { + return in.retrieverName(retrieverName); + } + + @Override + public void setRankDocs(RankDoc[] rankDocs) { + in.setRankDocs(rankDocs); + } + + @Override + public RankDoc[] getRankDocs() { + return in.getRankDocs(); + } + + @Override + public boolean isCompound() { + return in.isCompound(); + } + + @Override + public QueryBuilder explainQuery() { + return in.explainQuery(); + } + + @Override + public Float minScore() { + return in.minScore(); + } + + @Override + public boolean isFragment() { + return in.isFragment(); + } + + @Override + public String toString() { + return in.toString(); + } + + @Override + public String retrieverName() { + return in.retrieverName(); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { + in.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed); + } + + @Override + public String getName() { + return in.getName(); + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + in.doToXContent(builder, params); + } + + @Override + protected boolean doEquals(Object o) { + return in.equals(o); + } + + @Override + protected int doHashCode() { + return in.doHashCode(); + } +} diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index ada1471f1e694..37a456482bdee 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.application.rules.retriever; +import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchRequest; @@ -22,10 +23,8 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.rank.RankDoc; -import org.elasticsearch.search.retriever.CompoundRetrieverBuilder; -import org.elasticsearch.search.retriever.RankDocsRetrieverBuilder; -import org.elasticsearch.search.retriever.RetrieverBuilder; -import org.elasticsearch.search.retriever.RetrieverParserContext; +import org.elasticsearch.search.retriever.*; +import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; import org.elasticsearch.search.sort.ShardDocSortField; @@ -51,7 +50,7 @@ /** * A query rule retriever applies query rules defined in one or more rulesets to the underlying retriever. */ -public final class QueryRuleRetrieverBuilder extends RetrieverBuilder { +public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder { public static final String NAME = "rule"; public static final NodeFeature QUERY_RULE_RETRIEVERS_SUPPORTED = new NodeFeature("query_rule_retriever_supported"); @@ -97,29 +96,30 @@ public static QueryRuleRetrieverBuilder fromXContent(XContentParser parser, Retr private final List rulesetIds; private final Map matchCriteria; - private final CompoundRetrieverBuilder.RetrieverSource subRetriever; - private final int rankWindowSize; - private boolean executed = false; public QueryRuleRetrieverBuilder( List rulesetIds, Map matchCriteria, - RetrieverBuilder subRetriever, + RetrieverBuilder retrieverBuilder, int rankWindowSize ) { - this(rulesetIds, matchCriteria, new CompoundRetrieverBuilder.RetrieverSource(subRetriever, null), rankWindowSize); + super(new ArrayList<>(), rankWindowSize); + this.rulesetIds = rulesetIds; + this.matchCriteria = matchCriteria; + addChild(new QueryRuleRetrieverBuilderWrapper(retrieverBuilder)); } - private QueryRuleRetrieverBuilder( + public QueryRuleRetrieverBuilder( List rulesetIds, Map matchCriteria, - CompoundRetrieverBuilder.RetrieverSource subRetriever, - int rankWindowSize + List retrieverSource, + int rankWindowSize, + String retrieverName ) { - this.subRetriever = subRetriever; + super(retrieverSource, rankWindowSize); this.rulesetIds = rulesetIds; this.matchCriteria = matchCriteria; - this.rankWindowSize = rankWindowSize; + this.retrieverName = retrieverName; } @Override @@ -128,100 +128,21 @@ public String getName() { } @Override - public boolean isCompound() { - return executed == false; - } - - @Override - public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { - if (ctx.getPointInTimeBuilder() == null) { - throw new IllegalStateException("PIT is required"); - } - - if (executed) { - return this; - } - - // Rewrite prefilters - var newPreFilters = rewritePreFilters(ctx); - if (newPreFilters != preFilterQueryBuilders) { - var ret = new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, subRetriever, rankWindowSize); - ret.preFilterQueryBuilders = newPreFilters; - return ret; - } - - // Rewrite retriever sources - var newRetriever = subRetriever.retriever().rewrite(ctx); - if (newRetriever != subRetriever.retriever()) { - return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newRetriever, rankWindowSize); - } else { - var newSource = subRetriever.source() != null - ? subRetriever.source() - : createSearchSourceBuilder(ctx.getPointInTimeBuilder(), newRetriever); - var rewrittenSource = newSource.rewrite(ctx); - if (rewrittenSource != subRetriever.source()) { - return new QueryRuleRetrieverBuilder( - rulesetIds, - matchCriteria, - new CompoundRetrieverBuilder.RetrieverSource(newRetriever, rewrittenSource), - rankWindowSize - ); - } - } - - // execute searches - final SetOnce results = new SetOnce<>(); - final SearchRequest searchRequest = new SearchRequest().source(subRetriever.source()); - // The can match phase can reorder shards, so we disable it to ensure the stable ordering - searchRequest.setPreFilterShardSize(Integer.MAX_VALUE); - ctx.registerAsyncAction((client, listener) -> { - client.execute(TransportSearchAction.TYPE, searchRequest, new ActionListener<>() { - @Override - public void onResponse(SearchResponse resp) { - var rankDocs = getRankDocs(resp); - results.set(rankDocs); - listener.onResponse(null); - } - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }); - }); - - executed = true; - return new RankDocsRetrieverBuilder(rankWindowSize, List.of(this), results::get, newPreFilters); - } - - @Override - public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) { - throw new IllegalStateException("Should not be called, missing a rewrite?"); - } - protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { - var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) - .trackTotalHits(false) - .storedFields(new StoredFieldsContext(false)) - .size(rankWindowSize); - retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - // TODO: ensure that the inner sort is by relevance and throw an error otherwise. - - QueryBuilder query = sourceBuilder.query(); - if (query != null && query instanceof RuleQueryBuilder == false) { - QueryBuilder ruleQuery = new RuleQueryBuilder(query, matchCriteria, rulesetIds); - sourceBuilder.query(ruleQuery); - } + var ret = super.createSearchSourceBuilder(pit, retrieverBuilder); + checkValidSort(ret.sorts()); + ret.query(new RuleQueryBuilder(ret.query(), matchCriteria, rulesetIds)); + return ret; + } - // Record the shard id in the sort result - List> sortBuilders = sourceBuilder.sorts() != null ? new ArrayList<>(sourceBuilder.sorts()) : new ArrayList<>(); + private static void checkValidSort(List> sortBuilders) { if (sortBuilders.isEmpty()) { - sortBuilders.add(new ScoreSortBuilder()); + return; } - sortBuilders.add(new FieldSortBuilder(FieldSortBuilder.SHARD_DOC_FIELD_NAME)); - sourceBuilder.sort(sortBuilders); - return sourceBuilder; + if (sortBuilders.get(0) instanceof ScoreSortBuilder == false) { + throw new IllegalArgumentException("Rule retrievers can only sort documents by relevance score, got: " + sortBuilders); + } } @Override @@ -233,41 +154,56 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } @Override - public QueryBuilder topDocsQuery() { - QueryBuilder topDocsQuery = subRetriever.source().query(); - if (preFilterQueryBuilders.isEmpty()) { - topDocsQuery.queryName(this.retrieverName); - return topDocsQuery; + protected QueryRuleRetrieverBuilder clone(List newChildRetrievers) { + return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName); + } + + @Override + protected RankDoc[] combineInnerRetrieverResults(List rankResults) { + assert rankResults.size() == 1; + ScoreDoc[] scoreDocs = rankResults.getFirst(); + RankDoc[] rankDocs = new RankDoc[scoreDocs.length]; + for (int i = 0; i < scoreDocs.length; i++) { + ScoreDoc scoreDoc = scoreDocs[i]; + rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex); + rankDocs[i].rank = i + 1; } - var ret = new BoolQueryBuilder().filter(topDocsQuery).queryName(this.retrieverName); - preFilterQueryBuilders.stream().forEach(ret::filter); - return subRetriever.source().query(); + return rankDocs; } @Override public boolean doEquals(Object o) { QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o; - return Objects.equals(rulesetIds, that.rulesetIds) - && Objects.equals(matchCriteria, that.matchCriteria) - && subRetriever.equals(that.subRetriever); + return super.doEquals(o) && Objects.equals(rulesetIds, that.rulesetIds) && Objects.equals(matchCriteria, that.matchCriteria); } @Override public int doHashCode() { - return Objects.hash(subRetriever, rulesetIds, matchCriteria); + return Objects.hash(super.doHashCode(), rulesetIds, matchCriteria); } - private RankDoc[] getRankDocs(SearchResponse searchResponse) { - int size = searchResponse.getHits().getHits().length; - RankDoc[] docs = new RankDoc[size]; - for (int i = 0; i < size; i++) { - var hit = searchResponse.getHits().getAt(i); - long sortValue = (long) hit.getRawSortValues()[hit.getRawSortValues().length - 1]; - int doc = ShardDocSortField.decodeDoc(sortValue); - int shardRequestIndex = ShardDocSortField.decodeShardRequestIndex(sortValue); - docs[i] = new RankDoc(doc, hit.getScore(), shardRequestIndex); - docs[i].rank = i + 1; + class QueryRuleRetrieverBuilderWrapper extends RetrieverBuilderWrapper { + protected QueryRuleRetrieverBuilderWrapper(RetrieverBuilder in) { + super(in); + } + + @Override + protected QueryRuleRetrieverBuilderWrapper clone(RetrieverBuilder in) { + return new QueryRuleRetrieverBuilderWrapper(in); + } + + @Override + public QueryBuilder topDocsQuery() { + return new RuleQueryBuilder(in.topDocsQuery(), matchCriteria, rulesetIds); + } + + @Override + public QueryBuilder explainQuery() { + return new RankDocsQueryBuilder( + in.getRankDocs(), + new QueryBuilder[] { new RuleQueryBuilder(in.explainQuery(), matchCriteria, rulesetIds) }, + true + ); } - return docs; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index 8bccf6e7d1022..fc78cb8889e83 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -158,12 +158,6 @@ protected RankDoc[] combineInnerRetrieverResults(List rankResults) { return textSimilarityRankDocs; } - @Override - public QueryBuilder explainQuery() { - // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever - return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true); - } - @Override protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit) @@ -175,16 +169,6 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, } retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true); - // apply the pre-filters - if (preFilterQueryBuilders.size() > 0) { - QueryBuilder query = sourceBuilder.query(); - BoolQueryBuilder newQuery = new BoolQueryBuilder(); - if (query != null) { - newQuery.must(query); - } - preFilterQueryBuilders.forEach(newQuery::filter); - sourceBuilder.query(newQuery); - } sourceBuilder.rankBuilder( new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore) );