diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java index b64ae363a591c..9c5ef6437d011 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultQueryBuilder.java @@ -289,24 +289,30 @@ public PreparedStatement prepareStatement( .map(jdbcType -> getWriteFunction(client, session, connection, jdbcType, parameter.getType())) .orElseGet(() -> getWriteFunction(client, session, parameter.getType())); Class javaType = writeFunction.getJavaType(); - Object value = parameter.getValue() - // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and - // nullable domains are handled explicitly, with SQL syntax. - .orElseThrow(() -> new VerifyException("Value is missing")); - if (javaType == boolean.class) { - ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); - } - else if (javaType == long.class) { - ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); - } - else if (javaType == double.class) { - ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); - } - else if (javaType == Slice.class) { - ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); + + if (parameter.getValue().isEmpty()) { + writeFunction.setNull(statement, parameterIndex); } else { - ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + Object value = parameter.getValue() + // The value must be present, since DefaultQueryBuilder never creates null parameters. Values coming from Domain's ValueSet are non-null, and + // nullable domains are handled explicitly, with SQL syntax. + .orElseThrow(() -> new VerifyException("Value is missing")); + if (javaType == boolean.class) { + ((BooleanWriteFunction) writeFunction).set(statement, parameterIndex, (boolean) value); + } + else if (javaType == long.class) { + ((LongWriteFunction) writeFunction).set(statement, parameterIndex, (long) value); + } + else if (javaType == double.class) { + ((DoubleWriteFunction) writeFunction).set(statement, parameterIndex, (double) value); + } + else if (javaType == Slice.class) { + ((SliceWriteFunction) writeFunction).set(statement, parameterIndex, (Slice) value); + } + else { + ((ObjectWriteFunction) writeFunction).set(statement, parameterIndex, value); + } } } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index 1f45eb0e287b0..498dad544a470 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -1868,6 +1868,10 @@ public void testConstantUpdateWithVarcharEqualityPredicates() } assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 = 'A'", 1); assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a'), (20, 'A')"); + + // test update set to null + assertUpdate("UPDATE " + table.getName() + " SET col1 = null WHERE col2 = 'A'", 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (1, 'a'), (null, 'A')"); } } @@ -1883,6 +1887,10 @@ public void testConstantUpdateWithVarcharInequalityPredicates() assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 != 'A'", 1); assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (2, 'A')"); + + // test update set to null + assertUpdate("UPDATE " + table.getName() + " SET col1 = null WHERE col2 != 'A'", 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (null, 'a'), (2, 'A')"); } } @@ -1902,6 +1910,13 @@ public void testConstantUpdateWithVarcharGreaterAndLowerPredicate() assertUpdate("UPDATE " + table.getName() + " SET col1 = 20 WHERE col2 < 'a'", 1); assertQuery("SELECT * FROM " + table.getName(), "VALUES (20, 'a'), (20, 'A')"); + + // test update set to null + assertUpdate("UPDATE " + table.getName() + " SET col1 = null WHERE col2 > 'A'", 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (null, 'a'), (2, 'A')"); + + assertUpdate("UPDATE " + table.getName() + " SET col1 = null WHERE col2 < 'a'", 1); + assertQuery("SELECT * FROM " + table.getName(), "VALUES (null, 'a'), (null, 'A')"); } }