Skip to content

Commit

Permalink
Applying Jim's suggested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Oct 16, 2024
1 parent d80feee commit 4c55ed4
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 143 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ public final QueryBuilder topDocsQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
public final QueryBuilder explainQuery() {
throw new IllegalStateException("Should not be called, missing a rewrite?");
}

@Override
public final void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
throw new IllegalStateException("Should not be called, missing a rewrite?");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ public void setRankDocs(RankDoc[] rankDocs) {
this.rankDocs = rankDocs;
}

public RankDoc[] getRankDocs() {
return rankDocs;
}

/**
* Gets the filters for this retriever.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever;

import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;

/**
* A wrapper that can be used to modify the behaviour of an existing {@link RetrieverBuilder}.
*/
public abstract class RetrieverBuilderWrapper<T extends RetrieverBuilder> extends RetrieverBuilder {
protected final RetrieverBuilder in;

protected RetrieverBuilderWrapper(RetrieverBuilder in) {
this.in = in;
}

protected abstract T clone(RetrieverBuilder sub);

@Override
public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
var inRewrite = in.rewrite(ctx);
if (inRewrite != in) {
return clone(inRewrite);
}
return this;
}

@Override
public QueryBuilder topDocsQuery() {
return in.topDocsQuery();
}

@Override
public RetrieverBuilder minScore(Float minScore) {
return in.minScore(minScore);
}

@Override
public List<QueryBuilder> getPreFilterQueryBuilders() {
return in.preFilterQueryBuilders;
}

@Override
public ActionRequestValidationException validate(
SearchSourceBuilder source,
ActionRequestValidationException validationException,
boolean allowPartialSearchResults
) {
return in.validate(source, validationException, allowPartialSearchResults);
}

@Override
public RetrieverBuilder retrieverName(String retrieverName) {
return in.retrieverName(retrieverName);
}

@Override
public void setRankDocs(RankDoc[] rankDocs) {
in.setRankDocs(rankDocs);
}

@Override
public RankDoc[] getRankDocs() {
return in.getRankDocs();
}

@Override
public boolean isCompound() {
return in.isCompound();
}

@Override
public QueryBuilder explainQuery() {
return in.explainQuery();
}

@Override
public Float minScore() {
return in.minScore();
}

@Override
public boolean isFragment() {
return in.isFragment();
}

@Override
public String toString() {
return in.toString();
}

@Override
public String retrieverName() {
return in.retrieverName();
}

@Override
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
in.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed);
}

@Override
public String getName() {
return in.getName();
}

@Override
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
in.doToXContent(builder, params);
}

@Override
protected boolean doEquals(Object o) {
return in.equals(o);
}

@Override
protected int doHashCode() {
return in.doHashCode();
}
}
Loading

0 comments on commit 4c55ed4

Please sign in to comment.