Add partial aggregation pushdown support for Vertica

vlad-lyutenko committed Nov 25, 2024
1 parent e8300c0 commit 521f1d4
Showing 4 changed files with 114 additions and 10 deletions.
io/trino/plugin/vertica/ImplementAvgBigint.java
@@ -0,0 +1,26 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.vertica;

import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint;

public class ImplementAvgBigint
extends BaseImplementAvgBigint
{
@Override
protected String getRewriteFormatExpression()
{
return "avg(CAST(%s AS double precision))";
}
}
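
For context, a minimal sketch of what this rewrite produces; the column name below is a hypothetical example, and the actual substitution happens inside BaseImplementAvgBigint:

    // Sketch only: BaseImplementAvgBigint substitutes the rewritten argument
    // expression into the format string returned by getRewriteFormatExpression().
    // For a hypothetical bigint column "nationkey", the expression pushed down
    // to Vertica becomes:
    String pushedDown = "avg(CAST(%s AS double precision))".formatted("\"nationkey\"");
    // -> avg(CAST("nationkey" AS double precision))

Casting to double precision before averaging matches Trino's avg(bigint) semantics, which produce a double rather than an integer average.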
io/trino/plugin/vertica/VerticaClient.java
@@ -15,13 +15,16 @@

import com.google.common.collect.ImmutableSet;
import com.google.inject.Inject;
import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.plugin.base.mapping.IdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcJoinCondition;
import io.trino.plugin.jdbc.JdbcMetadata;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
@@ -31,10 +34,22 @@
import io.trino.plugin.jdbc.PreparedQuery;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint;
import io.trino.plugin.jdbc.aggregation.ImplementCorr;
import io.trino.plugin.jdbc.aggregation.ImplementCount;
import io.trino.plugin.jdbc.aggregation.ImplementCountAll;
import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct;
import io.trino.plugin.jdbc.aggregation.ImplementCovariancePop;
import io.trino.plugin.jdbc.aggregation.ImplementCovarianceSamp;
import io.trino.plugin.jdbc.aggregation.ImplementMinMax;
import io.trino.plugin.jdbc.aggregation.ImplementRegrIntercept;
import io.trino.plugin.jdbc.aggregation.ImplementRegrSlope;
import io.trino.plugin.jdbc.aggregation.ImplementSum;
import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
import io.trino.plugin.jdbc.expression.ParameterizedExpression;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.SchemaTableName;
@@ -121,6 +136,7 @@ public class VerticaClient

private final boolean statisticsEnabled;
private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;

@Inject
public VerticaClient(
@@ -143,6 +159,30 @@ public VerticaClient(
.map("$greater_than(left: supported_type, right: supported_type)").to("left > right")
.map("$greater_than_or_equal(left: supported_type, right: supported_type)").to("left >= right")
.build();

JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
this.connectorExpressionRewriter,
ImmutableSet.<AggregateFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementMinMax(true))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementCountDistinct(bigintTypeHandle, true))
.add(new ImplementSum(VerticaClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgBigint())
.add(new ImplementCovarianceSamp())
.add(new ImplementCovariancePop())
.add(new ImplementCorr())
.add(new ImplementRegrIntercept())
.add(new ImplementRegrSlope())
.build());
}

@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}
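
A hedged usage sketch (not part of this commit; table and column names are assumptions) of what these rules enable — an aggregate that matches a rule is rewritten into a JdbcExpression and executed by Vertica:

    // Illustrative test-style assertion, assuming a "nation" table with a
    // bigint column "nationkey": min() matches ImplementMinMax and count(*)
    // matches ImplementCountAll, so the whole aggregation runs in Vertica.
    assertThat(query("SELECT count(*), min(nationkey) FROM nation"))
            .isFullyPushedDown();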

@Override
@@ -435,4 +475,9 @@ public OptionalInt getMaxColumnNameLength(ConnectorSession session)
{
return this.getMaxColumnNameLengthFromDatabaseMetaData(session);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
}
}
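
A brief sketch of the decimal mapping above (precision and scale are illustrative): ImplementSum is given this function for decimal inputs, so the pushed-down sum keeps the Trino precision and scale:

    // Hypothetical example: a sum over DECIMAL(18, 3) yields a NUMERIC
    // type handle carrying the same precision and scale.
    Optional<JdbcTypeHandle> handle = toTypeHandle(DecimalType.createDecimalType(18, 3));
    // -> JdbcTypeHandle(Types.NUMERIC, "decimal", precision=18, scale=3)
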
@@ -13,10 +13,12 @@
*/
package io.trino.plugin.vertica;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
import io.trino.plugin.jdbc.JoinOperator;
import io.trino.spi.connector.JoinCondition;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryRunner;
import io.trino.testing.TestingConnectorBehavior;
@@ -69,7 +71,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
case SUPPORTS_JOIN_PUSHDOWN -> true;
case SUPPORTS_ARRAY,
SUPPORTS_ADD_COLUMN_WITH_COMMENT,
SUPPORTS_AGGREGATION_PUSHDOWN,
// Vertica returns NaN for stddev and variance functions when the input has a single value, but Trino expects NULL
SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV,
SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE,
SUPPORTS_COMMENT_ON_COLUMN,
SUPPORTS_COMMENT_ON_TABLE,
SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT,
@@ -371,6 +375,44 @@ public void testInsertIntoNotNullColumn()
abort("TODO Enable this test");
}

@Test
@Override
public void testNumericAggregationPushdown()
{
String schemaName = getSession().getSchema().orElseThrow();
// empty table
try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_num_agg_pd", ImmutableList.of())) {
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTable.getName())).isNotFullyPushedDown(AggregationNode.class);
}

try (TestTable testTable = createAggregationTestTable(schemaName + ".test_num_agg_pd",
ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) {
assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown();
assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown();
// Vertica truncates instead of rounding when casting decimal values. The only workaround would be to
// cast to a higher scale and then round back to the original scale, but that approach looks error-prone
assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isNotFullyPushedDown(AggregationNode.class);

// smoke testing of more complex cases
// WHERE on aggregation column
assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown();
// WHERE on non-aggregation column
assertThat(query("SELECT min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110")).isFullyPushedDown();
// GROUP BY
assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " GROUP BY short_decimal")).isFullyPushedDown();
// GROUP BY with WHERE on both grouping and aggregation column
assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown();
// GROUP BY with WHERE on grouping column
assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown();
// GROUP BY with WHERE on aggregation column
assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown();
}
}
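
The rounding mismatch noted above can be sketched with plain BigDecimal arithmetic (illustrative values, not taken from the connector):

    import java.math.BigDecimal;
    import java.math.RoundingMode;

    // Vertica truncates when casting a decimal to a lower scale, while Trino
    // rounds half-up, so avg() over decimals cannot be pushed down verbatim.
    BigDecimal value = new BigDecimal("111.6665");
    BigDecimal verticaStyle = value.setScale(3, RoundingMode.DOWN);    // 111.666
    BigDecimal trinoStyle = value.setScale(3, RoundingMode.HALF_UP);   // 111.667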

@Override
protected OptionalInt maxSchemaNameLength()
{
@@ -277,15 +277,6 @@ protected void testCaseColumnNames(String tableName)
}
}

@Test
@Override
public void testStatsWithAggregationPushdown()
{
assertThatThrownBy(super::testStatsWithAggregationPushdown)
.hasMessageContaining("Plan does not match");
abort("Aggregate pushdown is unsupported in Vertica connector");
}

@Test
@Override
public void testStatsWithTopNPushdown()
