diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala index 6291e62304a38..efa5c930b73da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -57,7 +57,7 @@ object JsonExpressionEvalUtils { } } -class JsonToStructsEvaluator( +case class JsonToStructsEvaluator( options: Map[String, String], nullableSchema: DataType, nameOfCorruptRecord: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index a553336015b88..d884e76f5256d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator} -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke} import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} +import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf @@ -637,9 +637,9 @@ case class JsonToStructs( timeZoneId: Option[String] = None, variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression - with TimeZoneAwareExpression + with RuntimeReplaceable with ExpectsInputTypes - with NullIntolerant + with TimeZoneAwareExpression with QueryErrorsBase { // The JSON input data might be missing certain fields. We force the nullability @@ -649,7 +649,7 @@ case class JsonToStructs( override def nullable: Boolean = true - final override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT) + override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT, RUNTIME_REPLACEABLE) // Used in `FunctionRegistry` def this(child: Expression, schema: Expression, options: Map[String, String]) = @@ -683,32 +683,6 @@ case class JsonToStructs( override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - @transient - private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - - @transient - private lazy val evaluator = new JsonToStructsEvaluator( - options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) - - override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String]) - - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) - val eval = child.genCode(ctx) - val resultType = CodeGenerator.boxedType(dataType) - val resultTerm = ctx.freshName("result") - ev.copy(code = - code""" - |${eval.code} - |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value}); - |boolean ${ev.isNull} = $resultTerm == null; - |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - |if (!${ev.isNull}) { - | ${ev.value} = $resultTerm; - |} - |""".stripMargin) - } - override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil override def sql: String = schema match { @@ -718,6 +692,21 @@ case class JsonToStructs( override def prettyName: String = "from_json" + @transient + private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + + @transient + lazy val evaluator: JsonToStructsEvaluator = JsonToStructsEvaluator( + options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) + + override def replacement: Expression = Invoke( + Literal.create(evaluator, ObjectType(classOf[JsonToStructsEvaluator])), + "evaluate", + dataType, + Seq(child), + Seq(child.dataType) + ) + override protected def withNewChildInternal(newChild: Expression): JsonToStructs = copy(child = newChild) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 0afaf4ec097c8..edb7b93ecdf68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -420,7 +420,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from_json escaping") { val schema = StructType(StructField("\"quote", IntegerType) :: Nil) GenerateUnsafeProjection.generate( - JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT) :: Nil) + JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT).replacement :: Nil) } test("from_json") { @@ -729,7 +729,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with test("from/to json - interval support") { val schema = StructType(StructField("i", CalendarIntervalType) :: Nil) checkEvaluation( - JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType)), + JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType), + UTC_OPT), InternalRow(new CalendarInterval(12, 1, 0))) Seq(MapType(CalendarIntervalType, IntegerType), MapType(IntegerType, CalendarIntervalType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index eed06da609f8e..7af2be2db01d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -292,7 +292,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""", null).foreach(v => { val row = create_row(v) - checkEvaluation(e1, e2.eval(row), row) + checkEvaluation(e1, replace(e2).eval(row), row) }) } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain index 8d1d122d156ff..9bc33b3b97d2c 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala index af3a8d67e3c29..2a1554d287a8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BaseScriptTransformationExec.scala @@ -239,7 +239,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode { val complexTypeFactory = JsonToStructs(attr.dataType, ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone)) wrapperConvertException(data => - complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any) + complexTypeFactory.evaluator.evaluate(UTF8String.fromString(data)), any => any) case udt: UserDefinedType[_] => wrapperConvertException(data => udt.deserialize(data), converter) case dt =>