[SPARK-49992][SQL] Session level collation should not impact DDL queries #48436

Open · wants to merge 12 commits into base: master
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DefaultStringType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, VariantType, YearMonthIntervalType}

class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
protected def typedVisit[T](ctx: ParseTree): T = {
@@ -74,7 +74,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
case (TIMESTAMP_LTZ, Nil) => TimestampType
case (STRING, Nil) =>
typeCtx.children.asScala.toSeq match {
-case Seq(_) => SqlApiConf.get.defaultStringType
+case Seq(_) => DefaultStringType()
case Seq(_, ctx: CollateClauseContext) =>
val collationName = visitCollateClause(ctx)
val collationId = CollationFactory.collationNameToId(collationName)
26 changes: 26 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -105,6 +105,32 @@ abstract class DataType extends AbstractDataType {
*/
private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)

/**
* Recursively applies the provided partial function `f` to transform this DataType tree.
*/
private[spark] def transformRecursively(f: PartialFunction[DataType, DataType]): DataType = {
this match {
case _ if f.isDefinedAt(this) =>
f(this)

case ArrayType(elementType, containsNull) =>
ArrayType(elementType.transformRecursively(f), containsNull)

case MapType(keyType, valueType, valueContainsNull) =>
MapType(
keyType.transformRecursively(f),
valueType.transformRecursively(f),
valueContainsNull)

case StructType(fields) =>
StructType(fields.map { field =>
field.copy(dataType = field.dataType.transformRecursively(f))
})

case _ => this
}
}

final override private[sql] def defaultConcreteType: DataType = this

override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
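A minimal usage sketch of the new transformRecursively helper (illustration only, not part of the diff; the method is private[spark], so it only compiles inside Spark's own packages):

import org.apache.spark.sql.types._

// Hypothetical nested schema, used only for illustration.
val schema = StructType(Seq(
  StructField("name", StringType),
  StructField("tags", ArrayType(StringType)),
  StructField("props", MapType(StringType, IntegerType))))

// Replace every StringType leaf, however deeply nested, with BinaryType.
val transformed = schema.transformRecursively {
  case _: StringType => BinaryType
}
// transformed: struct<name: binary, tags: array<binary>, props: map<binary, int>>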
@@ -21,6 +21,7 @@ import org.json4s.JsonAST.{JString, JValue}

import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.SqlApiConf

/**
* The data type representing `String` values. Please use the singleton `DataTypes.StringType`.
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.CollationFactory
* The id of collation for this StringType.
*/
@Stable
-class StringType private (val collationId: Int) extends AtomicType with Serializable {
+class StringType private[sql] (val collationId: Int) extends AtomicType with Serializable {

/**
* Support for Binary Equality implies that strings are considered equal only if they are byte
@@ -77,6 +78,10 @@ class StringType private (val collationId: Int) extends AtomicType with Serializable {
if (isUTF8BinaryCollation) "string"
else s"string collate ${CollationFactory.fetchCollation(collationId).collationName}"

override def toString: String =
if (isUTF8BinaryCollation) "StringType"
else s"StringType($collationId)"

// Due to backwards compatibility and compatibility with other readers
// all string types are serialized in json as regular strings and
// the collation information is written to struct field metadata
@@ -109,3 +114,19 @@ case object StringType extends StringType(0) {
new StringType(collationId)
}
}

/**
 * The result type of string literals, column definitions without an explicit collation, casts to
 * string, and certain expressions that produce strings but whose output type does not depend on
 * the types of their children. The idea is for this type to behave like a string with the
 * session's default collation, while still being distinguishable from a regular string type,
 * because in some places (e.g. DDL commands) the default string type is not the one carrying the
 * session collation.
 */
private[spark] class DefaultStringType private (collationId: Int)
  extends StringType(collationId)

private[spark] object DefaultStringType {
def apply(): DefaultStringType = {
new DefaultStringType(SqlApiConf.get.defaultStringType.collationId)
}
}
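A hedged sketch of the intended semantics (internal API, shown as if called from Spark-internal code): the instance captures the session's default collation at construction time, compares equal to the plain StringType with the same collation id, and yet stays distinguishable by pattern matching.

import org.apache.spark.sql.types.{DefaultStringType, StringType}

val dt = DefaultStringType()   // collation id read from SqlApiConf.get.defaultStringType

// Equality is inherited from StringType and based on the collation id; the parser's
// canEagerlyEvaluateInlineTable check below relies on exactly this.
val sessionCollationUnchanged = dt == StringType

// Analyzer rules such as ReplaceDefaultStringType can still tell the two apart:
dt.isInstanceOf[DefaultStringType]           // true
StringType.isInstanceOf[DefaultStringType]   // false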
@@ -314,6 +314,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveProcedures ::
BindProcedures ::
ResolveTableSpec ::
ReplaceDefaultStringType ::
ResolveAliases ::
ResolveSubquery ::
ResolveSubqueryColumnAliases ::
@@ -24,8 +24,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DefaultStringType, MapType, StringType}

object CollationTypeCasts extends TypeCoercionRule {
override val transform: PartialFunction[Expression, Expression] = {
@@ -196,7 +195,7 @@
)
}
else {
-implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
+implicitTypes.headOption.map(StringType(_)).getOrElse(DefaultStringType())
}
}
}
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, ColumnDefinition, LogicalPlan, QualifiedColType, ReplaceColumns, V1DDLCommand, V2CreateTablePlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{DataType, DefaultStringType, StringType}
import org.apache.spark.sql.util.SchemaUtils

/**
 * Replaces default string types in DDL commands. DDL commands should eventually default to the
 * collation of the target object; since that is not implemented yet, we simply use UTF8_BINARY
 * for now.
 */
object ReplaceDefaultStringType extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan resolveOperators {
case createTable: V2CreateTablePlan =>
transformPlan(createTable, StringType)

case v1Ddl: V1DDLCommand =>
transformPlan(v1Ddl, StringType)

case addCols: AddColumns =>
[Review comment, Contributor] So we assume this rule ReplaceDefaultStringType will run before
ResolveSessionCatalog? Otherwise these commands may be converted to v1 and can't be matched here.

[Contributor Author] Yes, however I am having some issues there. For example, CREATE VIEW can be
converted here:
https://github.com/stefankandic/spark/blob/b1ff7672cba12750d41d803f0faeb3487d934601/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala#L1563
I talked with Max a bit about this issue in the comment above; do you think there is a good way
to catch v1 stuff as well?

[Contributor Author] After some more digging I think it's best we handle views separately, as they
need to be treated differently; it's not enough to just replace create/alter view, since we still
need to resolve default types correctly once the view is queried.

addCols.copy(columnsToAdd = replaceColumnTypes(addCols.columnsToAdd, StringType))

case replaceCols: ReplaceColumns =>
replaceCols.copy(columnsToAdd = replaceColumnTypes(replaceCols.columnsToAdd, StringType))

case alter: AlterColumn
if alter.dataType.isDefined && SchemaUtils.hasDefaultStringType(alter.dataType.get) =>
alter.copy(dataType = Some(replaceDefaultStringType(alter.dataType.get, StringType)))
}
}

private def transformPlan(plan: LogicalPlan, newType: StringType): LogicalPlan = {
plan resolveOperators { operator =>
operator.transformExpressionsUp { expression =>
transformExpression(expression, newType)
}
}
}

private def transformExpression(expression: Expression, newType: StringType): Expression = {
expression match {
case columnDef: ColumnDefinition if SchemaUtils.hasDefaultStringType(columnDef.dataType) =>
columnDef.copy(dataType = replaceDefaultStringType(columnDef.dataType, newType))

case cast: Cast if SchemaUtils.hasDefaultStringType(cast.dataType) =>
cast.copy(dataType = replaceDefaultStringType(cast.dataType, newType))

case Literal(value, dt) if SchemaUtils.hasDefaultStringType(dt) =>
val replaced = replaceDefaultStringType(dt, newType)
Literal(value, replaced)

case other => other
}
}

private def replaceDefaultStringType(dataType: DataType, newType: StringType): DataType = {
dataType.transformRecursively {
case _: DefaultStringType => newType
}
}

private def replaceColumnTypes(
colTypes: Seq[QualifiedColType],
newType: StringType): Seq[QualifiedColType] = {
colTypes.map {
case colWithDefault if SchemaUtils.hasDefaultStringType(colWithDefault.dataType) =>
val replaced = replaceDefaultStringType(colWithDefault.dataType, newType)
colWithDefault.copy(dataType = replaced)

case col => col
}
}
}
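The net effect of this rule on a column's schema, sketched with assumed inputs (illustration only, not part of the diff):

import org.apache.spark.sql.types._

val colType = StructType(Seq(
  StructField("c1", DefaultStringType()),         // no explicit collation in the DDL
  StructField("c2", StringType("UTF8_LCASE"))))   // explicit COLLATE clause

// The same replacement replaceDefaultStringType performs:
val pinned = colType.transformRecursively {
  case _: DefaultStringType => StringType
}
// c1 becomes plain StringType (UTF8_BINARY); c2 keeps UTF8_LCASE, because an
// explicitly collated StringType is not a DefaultStringType.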
@@ -2071,15 +2071,40 @@ class AstBuilder extends DataTypeAstBuilder
Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
}

-val unresolvedTable = UnresolvedInlineTable(aliases, rows.toSeq)
-val table = if (conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)) {
+val rowSeq = rows.toSeq
+val unresolvedTable = UnresolvedInlineTable(aliases, rowSeq)
+val table = if (canEagerlyEvaluateInlineTable(rowSeq)) {
EvaluateUnresolvedInlineTable.evaluate(unresolvedTable)
} else {
unresolvedTable
}
table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
}

/**
 * Determines whether the inline table can be eagerly evaluated. Eager evaluation is not allowed
 * when a session-level collation is set and the expressions contain string literals, because the
 * result may depend on that collation.
 */
private def canEagerlyEvaluateInlineTable(rows: Seq[Seq[Expression]]): Boolean = {
val configSet = conf.getConf(SQLConf.EAGER_EVAL_OF_UNRESOLVED_INLINE_TABLE_ENABLED)
val sessionCollationUnchanged = DefaultStringType() == StringType

configSet &&
(sessionCollationUnchanged || !rows.exists(_.exists(containsStringLiteral)))
}

private def containsStringLiteral(expression: Expression): Boolean = {
def inner(expr: Expression): Boolean = expr match {
case Literal(_, dataType) =>
dataType.existsRecursively(_.isInstanceOf[StringType])
case _ =>
expr.children.exists(inner)
}

expression.resolved && inner(expression)
}
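For example (a hedged sketch; containsStringLiteral is private to AstBuilder, so this only illustrates its semantics):

import org.apache.spark.sql.catalyst.expressions.{Literal, Upper}

containsStringLiteral(Upper(Literal("abc")))   // true: a string literal is nested inside
containsStringLiteral(Literal(42))             // false: no string literal involved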

/**
* Create an alias (SubqueryAlias) for a join relation. This is practically the same as
* visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different
@@ -3290,7 +3315,7 @@ class AstBuilder extends DataTypeAstBuilder
* Create a String literal expression.
*/
override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
-Literal.create(createString(ctx), conf.defaultStringType)
+Literal.create(createString(ctx), DefaultStringType())
}

/**
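A small sketch of the consequence (internal API, names as in the diff): a parsed string literal now carries DefaultStringType rather than the materialized session default, so it behaves like the session collation during analysis while remaining rewritable in DDL positions.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.DefaultStringType

// What the parser now produces for a literal such as 'abc':
val lit = Literal.create("abc", DefaultStringType())
lit.dataType.isInstanceOf[DefaultStringType]   // true: ReplaceDefaultStringType can rewrite it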
@@ -88,6 +88,12 @@ trait V2PartitionCommand extends UnaryCommand {
override def child: LogicalPlan = table
}

/**
 * Marker trait for v1 DDL commands, so that these plan nodes can still be matched
 * by catalyst analyzer rules.
 */
trait V1DDLCommand

/**
* Append data to an existing table.
*/
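The marker-trait pattern in isolation, as a generic sketch with hypothetical names (not Spark code):

// A marker trait has no members; it only makes instances matchable by type.
trait MarkerDDLCommand

case class CreateSomethingV1(name: String) extends MarkerDDLCommand
case class UnrelatedCommand(name: String)

def matchedByRule(cmd: Any): Boolean = cmd match {
  case _: MarkerDDLCommand => true   // what ReplaceDefaultStringType does per plan node
  case _ => false
}

assert(matchedByRule(CreateSomethingV1("t")))
assert(!matchedByRule(UnrelatedCommand("t")))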
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DefaultStringType, MapType, StringType, StructField, StructType}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkSchemaUtils

@@ -304,6 +304,13 @@ private[spark] object SchemaUtils {
}
}

/**
 * Checks whether the given data type contains a default string type, recursively.
 */
def hasDefaultStringType(dataType: DataType): Boolean = {
dataType.existsRecursively(_.isInstanceOf[DefaultStringType])
}

/**
* Replaces any collated string type with non collated StringType
* recursively in the given data type.
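For instance (Spark-internal API, sketch only):

import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils

SchemaUtils.hasDefaultStringType(ArrayType(DefaultStringType()))     // true: found in the element type
SchemaUtils.hasDefaultStringType(MapType(StringType, IntegerType))   // false: plain StringType only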
@@ -715,7 +715,7 @@ class DataTypeSuite extends SparkFunSuite {
checkEqualsIgnoreCompatibleCollation(StringType, StringType("UTF8_LCASE"),
expected = true)
checkEqualsIgnoreCompatibleCollation(
StringType("UTF8_BINARY"), StringType("UTF8_LCASE"), expected = true)
StringType("UTF8_LCASE"), StringType("UTF8_BINARY"), expected = true)
// Complex types.
checkEqualsIgnoreCompatibleCollation(
ArrayType(StringType),
@@ -23,7 +23,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, V1DDLCommand}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.command.{DDLUtils, LeafRunnableCommand}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types._
case class CreateTable(
tableDesc: CatalogTable,
mode: SaveMode,
-query: Option[LogicalPlan]) extends LogicalPlan {
+query: Option[LogicalPlan]) extends LogicalPlan with V1DDLCommand {
assert(tableDesc.provider.isDefined, "The table to be created must have a provider.")

if (query.isEmpty) {
@@ -79,7 +79,7 @@ case class CreateTempViewUsing(
replace: Boolean,
global: Boolean,
provider: String,
-options: Map[String, String]) extends LeafRunnableCommand {
+options: Map[String, String]) extends LeafRunnableCommand with V1DDLCommand {

if (tableIdent.database.isDefined) {
throw QueryCompilationErrors.cannotSpecifyDatabaseForTempViewError(tableIdent)
@@ -112,7 +112,11 @@ class CollationSQLRegexpSuite
val tableNameLcase = "T_LCASE"
withTable(tableNameLcase) {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") {
sql(s"CREATE TABLE IF NOT EXISTS $tableNameLcase(c STRING) using PARQUET")
sql(s"""
|CREATE TABLE IF NOT EXISTS $tableNameLcase(
| c STRING COLLATE UTF8_LCASE
|) using PARQUET
|""".stripMargin)
sql(s"INSERT INTO $tableNameLcase(c) VALUES('ABC')")
checkAnswer(sql(s"select c like 'ab%' FROM $tableNameLcase"), Row(true))
checkAnswer(sql(s"select c like '%bc' FROM $tableNameLcase"), Row(true))
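The COLLATE clause added above is the point of the change: with this PR the session-level collation no longer leaks into DDL, so the column must ask for UTF8_LCASE explicitly. Roughly, using the suite's own helpers (sketch, table cleanup omitted):

withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") {
  sql("CREATE TABLE t(c STRING) USING PARQUET")
  // c is created as UTF8_BINARY, not UTF8_LCASE; only an explicit
  // "c STRING COLLATE UTF8_LCASE" yields the case-insensitive column.
}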
23 changes: 0 additions & 23 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -1022,29 +1022,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("SPARK-47431: Default collation set to UNICODE, column type test") {
withTable("t") {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
sql(s"CREATE TABLE t(c1 STRING) USING PARQUET")
sql(s"INSERT INTO t VALUES ('a')")
checkAnswer(sql(s"SELECT collation(c1) FROM t"), Seq(Row("UNICODE")))
}
}
}

test("SPARK-47431: Create table with UTF8_BINARY, make sure collation persists on read") {
withTable("t") {
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_BINARY") {
sql("CREATE TABLE t(c1 STRING) USING PARQUET")
sql("INSERT INTO t VALUES ('a')")
checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY")))
}
withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
checkAnswer(sql("SELECT collation(c1) FROM t"), Seq(Row("UTF8_BINARY")))
}
}
}

test("Create dataframe with non utf8 binary collation") {
val schema = StructType(Seq(StructField("Name", StringType("UNICODE_CI"))))
val data = Seq(Row("Alice"), Row("Bob"), Row("bob"))