From 521f1d4995cd35406ddf4a67e69a1fc93380aaf2 Mon Sep 17 00:00:00 2001
From: Vladyslav Lyutenko
Date: Fri, 22 Nov 2024 19:51:42 +0100
Subject: [PATCH] Add partial aggregation pushdown support for Vertica

---
 .../plugin/vertica/ImplementAvgBigint.java    | 26 +++++++++++
 .../trino/plugin/vertica/VerticaClient.java   | 45 +++++++++++++++++++
 .../vertica/TestVerticaConnectorTest.java     | 44 +++++++++++++++++-
 .../vertica/TestVerticaTableStatistics.java   |  9 ----
 4 files changed, 114 insertions(+), 10 deletions(-)
 create mode 100644 plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/ImplementAvgBigint.java

diff --git a/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/ImplementAvgBigint.java b/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/ImplementAvgBigint.java
new file mode 100644
index 0000000000000..15decfef83018
--- /dev/null
+++ b/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/ImplementAvgBigint.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.vertica;
+
+import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint;
+
+public class ImplementAvgBigint
+        extends BaseImplementAvgBigint
+{
+    @Override
+    protected String getRewriteFormatExpression()
+    {
+        return "avg(CAST(%s AS double precision))";
+    }
+}
diff --git a/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/VerticaClient.java b/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/VerticaClient.java
index 1ff5e328e9285..be5fe151f8966 100644
--- a/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/VerticaClient.java
+++ b/plugin/trino-vertica/src/main/java/io/trino/plugin/vertica/VerticaClient.java
@@ -15,6 +15,8 @@
 
 import com.google.common.collect.ImmutableSet;
 import com.google.inject.Inject;
+import io.trino.plugin.base.aggregation.AggregateFunctionRewriter;
+import io.trino.plugin.base.aggregation.AggregateFunctionRule;
 import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
 import io.trino.plugin.base.mapping.IdentifierMapping;
 import io.trino.plugin.jdbc.BaseJdbcClient;
@@ -22,6 +24,7 @@
 import io.trino.plugin.jdbc.ColumnMapping;
 import io.trino.plugin.jdbc.ConnectionFactory;
 import io.trino.plugin.jdbc.JdbcColumnHandle;
+import io.trino.plugin.jdbc.JdbcExpression;
 import io.trino.plugin.jdbc.JdbcJoinCondition;
 import io.trino.plugin.jdbc.JdbcMetadata;
 import io.trino.plugin.jdbc.JdbcStatisticsConfig;
@@ -31,10 +34,22 @@
 import io.trino.plugin.jdbc.PreparedQuery;
 import io.trino.plugin.jdbc.QueryBuilder;
 import io.trino.plugin.jdbc.WriteMapping;
+import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint;
+import io.trino.plugin.jdbc.aggregation.ImplementCorr;
+import io.trino.plugin.jdbc.aggregation.ImplementCount;
+import io.trino.plugin.jdbc.aggregation.ImplementCountAll;
+import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct;
+import io.trino.plugin.jdbc.aggregation.ImplementCovariancePop;
+import io.trino.plugin.jdbc.aggregation.ImplementCovarianceSamp;
+import io.trino.plugin.jdbc.aggregation.ImplementMinMax;
+import io.trino.plugin.jdbc.aggregation.ImplementRegrIntercept;
+import io.trino.plugin.jdbc.aggregation.ImplementRegrSlope;
+import io.trino.plugin.jdbc.aggregation.ImplementSum;
 import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder;
 import io.trino.plugin.jdbc.expression.ParameterizedExpression;
 import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
 import io.trino.spi.TrinoException;
+import io.trino.spi.connector.AggregateFunction;
 import io.trino.spi.connector.ColumnHandle;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.connector.SchemaTableName;
@@ -121,6 +136,7 @@ public class VerticaClient
 
     private final boolean statisticsEnabled;
     private final ConnectorExpressionRewriter<ParameterizedExpression> connectorExpressionRewriter;
+    private final AggregateFunctionRewriter<JdbcExpression, ?> aggregateFunctionRewriter;
 
     @Inject
     public VerticaClient(
@@ -143,6 +159,30 @@ public VerticaClient(
                 .map("$greater_than(left: supported_type, right: supported_type)").to("left > right")
                 .map("$greater_than_or_equal(left: supported_type, right: supported_type)").to("left >= right")
                 .build();
+
+        JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
+        this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>(
+                this.connectorExpressionRewriter,
+                ImmutableSet.<AggregateFunctionRule<JdbcExpression, ParameterizedExpression>>builder()
+                        .add(new ImplementCountAll(bigintTypeHandle))
+                        .add(new ImplementMinMax(true))
+                        .add(new ImplementCount(bigintTypeHandle))
+                        .add(new ImplementCountDistinct(bigintTypeHandle, true))
+                        .add(new ImplementSum(VerticaClient::toTypeHandle))
+                        .add(new ImplementAvgFloatingPoint())
+                        .add(new ImplementAvgBigint())
+                        .add(new ImplementCovarianceSamp())
+                        .add(new ImplementCovariancePop())
+                        .add(new ImplementCorr())
+                        .add(new ImplementRegrIntercept())
+                        .add(new ImplementRegrSlope())
+                        .build());
+    }
+
+    @Override
+    public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
+    {
+        return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
     }
 
     @Override
@@ -435,4 +475,9 @@ public OptionalInt getMaxColumnNameLength(ConnectorSession session)
     {
         return this.getMaxColumnNameLengthFromDatabaseMetaData(session);
     }
+
+    private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
+    {
+        return Optional.of(new JdbcTypeHandle(Types.NUMERIC, Optional.of("decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
+    }
 }
diff --git a/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaConnectorTest.java b/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaConnectorTest.java
index 19b1ac487abb4..09819c7099fac 100644
--- a/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaConnectorTest.java
+++ b/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaConnectorTest.java
@@ -13,10 +13,12 @@
  */
 package io.trino.plugin.vertica;
 
+import com.google.common.collect.ImmutableList;
 import io.trino.Session;
 import io.trino.plugin.jdbc.BaseJdbcConnectorTest;
 import io.trino.plugin.jdbc.JoinOperator;
 import io.trino.spi.connector.JoinCondition;
+import io.trino.sql.planner.plan.AggregationNode;
 import io.trino.testing.MaterializedResult;
 import io.trino.testing.QueryRunner;
 import io.trino.testing.TestingConnectorBehavior;
@@ -69,7 +71,9 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior)
             case SUPPORTS_JOIN_PUSHDOWN -> true;
             case SUPPORTS_ARRAY,
                  SUPPORTS_ADD_COLUMN_WITH_COMMENT,
-                 SUPPORTS_AGGREGATION_PUSHDOWN,
+                 // Vertica returns NaN for stddev functions in case of single value for example, but trino expects null
+                 SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV,
+                 SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE,
                  SUPPORTS_COMMENT_ON_COLUMN,
                  SUPPORTS_COMMENT_ON_TABLE,
                  SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT,
@@ -371,6 +375,44 @@ public void testInsertIntoNotNullColumn()
         abort("TODO Enable this test");
     }
 
+    @Test
+    @Override
+    public void testNumericAggregationPushdown()
+    {
+        String schemaName = getSession().getSchema().orElseThrow();
+        // empty table
+        try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_num_agg_pd", ImmutableList.of())) {
+            assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
+            assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
+            assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTable.getName())).isFullyPushedDown();
+            assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTable.getName())).isNotFullyPushedDown(AggregationNode.class);
+        }
+
+        try (TestTable testTable = createAggregationTestTable(schemaName + ".test_num_agg_pd",
+                ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) {
+            assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTable.getName())).isFullyPushedDown();
+            assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTable.getName())).isFullyPushedDown();
+            assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTable.getName())).isFullyPushedDown();
+            // Vertica just truncates when we cast decimal values instead of rounding, the only way to overcome it:
+            // cast to higher scale and then round to original scale, but it looks error-prone
+            assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTable.getName())).isNotFullyPushedDown(AggregationNode.class);
+
+            // smoke testing of more complex cases
+            // WHERE on aggregation column
+            assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown();
+            // WHERE on non-aggregation column
+            assertThat(query("SELECT min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110")).isFullyPushedDown();
+            // GROUP BY
+            assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " GROUP BY short_decimal")).isFullyPushedDown();
+            // GROUP BY with WHERE on both grouping and aggregation column
+            assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown();
+            // GROUP BY with WHERE on grouping column
+            assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown();
+            // GROUP BY with WHERE on aggregation column
+            assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTable.getName() + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown();
+        }
+    }
+
     @Override
     protected OptionalInt maxSchemaNameLength()
     {
diff --git a/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaTableStatistics.java b/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaTableStatistics.java
index 3df522ca38613..8c2637d6f2d2e 100644
--- a/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaTableStatistics.java
+++ b/plugin/trino-vertica/src/test/java/io/trino/plugin/vertica/TestVerticaTableStatistics.java
@@ -277,15 +277,6 @@ protected void testCaseColumnNames(String tableName)
         }
     }
 
-    @Test
-    @Override
-    public void testStatsWithAggregationPushdown()
-    {
-        assertThatThrownBy(super::testStatsWithAggregationPushdown)
-                .hasMessageContaining("Plan does not match");
-        abort("Aggregate pushdown is unsupported in Vertica connector");
-    }
-
     @Test
     @Override
     public void testStatsWithTopNPushdown()