diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index b64ae363a591c..9c5ef6437d011 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -289,24 +289,30 @@ public PreparedStatement prepareStatement( .map(jdbcType -> getWriteFunction(client, session, connection, jdbcType, parameter.getType())) .orElseGet(() -> getWriteFunction(client, session, parameter.getType())); Class javaType = writeFunction.getJavaType(); - Object value = parameter.getValue() - // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and - // nullable domains are handled explicitly, with SQL syntax. - .orElseThrow(() -> new VerifyException("Value is missing")); - if (javaType == boolean.class) { - ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); - } - else if (javaType == long.class) { - ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); - } - else if (javaType == double.class) { - ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); - } - else if (javaType == Slice.class) { - ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); + + if (parameter.getValue().isEmpty()) { + writeFunction.setNull(statement, parameterIndex); } else { - ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + Object value = parameter.getValue() + // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and + // nullable domains are handled explicitly, with SQL syntax. + .orElseThrow(() -> new VerifyException("Value is missing")); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); + } + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); + } + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); + } + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); + } + else { + ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + } } } diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java index c34a6ba33e41e..2b5d3aab7f7e7 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java @@ -353,6 +353,14 @@ public void testUpdateMultipleCondition() .hasMessage(MODIFYING_NON_TRANSACTIONAL_TABLE_MESSAGE); } + @Test + @Override + public void testUpdateWithNullValues() + { + assertThatThrownBy(super::testUpdateWithNullValues) + .hasMessage(MODIFYING_NON_TRANSACTIONAL_TABLE_MESSAGE); + } + @Test @Override public void testRowLevelUpdate() diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java index 73faaa95d02f4..a674689d46cad 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java @@ -951,6 +951,38 @@ public void testUpdate() @Override public void testUpdateMultipleCondition() {} + /** + * This test fails intermittently because Kudu doesn't have strong enough + * semantics to support writing from multiple threads. + */ + @Test + @Disabled + @Override + public void testUpdateWithNullValues() + { + withTableName("test_update_nulls", tableName -> { + assertUpdate(createTableForWrites("CREATE TABLE %s %s".formatted(tableName, NATION_COLUMNS))); + assertUpdate("INSERT INTO " + tableName + " SELECT * FROM nation", 25); + assertUpdate("UPDATE " + tableName + " SET nationkey = 100 WHERE regionkey = 2", 5); + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey = 100", "VALUES 5"); + + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey IS NULL", "VALUES 0"); + assertUpdate("UPDATE " + tableName + " SET nationkey = NULL WHERE regionkey = 2", 5); + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey IS NULL", "VALUES 5"); + }); + + withTableName("test_update_nulls", tableName -> { + assertUpdate(createKuduTableForWrites("CREATE TABLE %s %s".formatted(tableName, "(not_null_col INTEGER, nullable_col1 INTEGER, nullable_col2 INTEGER)"))); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 1, 1), (2, 2, 2)", 2); + assertUpdate("UPDATE " + tableName + " SET nullable_col1 = 10 WHERE not_null_col = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, 10, 1), (2, 2, 2)"); + + assertUpdate("UPDATE " + tableName + " SET nullable_col2 = null WHERE not_null_col = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (1, 10, null), (2, 2, 2)"); + assertQueryFails("UPDATE " + tableName + " SET not_null_col = TRY(1 / 0) WHERE not_null_col = 2", "NULL value not allowed for NOT NULL column: not_null_col"); + }); + } + /** * This test fails intermittently because Kudu doesn't have strong enough * semantics to support writing from multiple threads. diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java index 2eaa9869bcc87..1efcaec843e07 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseConnectorTest.java @@ -4896,6 +4896,43 @@ public void testUpdateMultipleCondition() } } + @Test + public void testUpdateWithNullValues() + { + skipTestUnless(hasBehavior(SUPPORTS_UPDATE)); + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_nulls", "AS SELECT * FROM nation")) { + String tableName = table.getName(); + assertUpdate("UPDATE " + tableName + " SET nationkey = 100 WHERE regionkey = 2", 5); + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey = 100", "VALUES 5"); + + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey IS NULL", "VALUES 0"); + assertUpdate("UPDATE " + tableName + " SET nationkey = NULL WHERE regionkey = 2", 5); + assertQuery("SELECT count(*) FROM " + tableName + " WHERE nationkey IS NULL", "VALUES 5"); + } + + if (!hasBehavior(SUPPORTS_NOT_NULL_CONSTRAINT)) { + return; + } + + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_update_nulls", "(nullable_col1 INTEGER, nullable_col2 INTEGER, not_null_col INTEGER NOT NULL)")) { + String tableName = table.getName(); + assertUpdate("INSERT INTO " + tableName + " VALUES (1, 1, 1), (2, 2, 2)", 2); + assertUpdate("UPDATE " + tableName + " SET nullable_col1 = 10 WHERE not_null_col = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (10, 1, 1), (2, 2, 2)"); + + assertUpdate("UPDATE " + tableName + " SET nullable_col2 = null WHERE not_null_col = 1", 1); + assertQuery("SELECT * FROM " + tableName, "VALUES (10, null, 1), (2, 2, 2)"); + + if (hasBehavior(SUPPORTS_ROW_LEVEL_UPDATE)) { + assertQueryFails("UPDATE " + tableName + " SET not_null_col = TRY(1 / 0) WHERE not_null_col = 2", "NULL value not allowed for NOT NULL column: not_null_col"); + } + else { + assertQueryFails("UPDATE " + tableName + " SET not_null_col = TRY(1 / 0) WHERE not_null_col = 2", MODIFYING_ROWS_MESSAGE); + } + } + } + @Test public void testRowLevelUpdate() {