Skip to content

Commit

Permalink
[SPARK-49966][SQL] Use Invoke to implement JsonToStructs(`from_js…
Browse files Browse the repository at this point in the history
…on`)

### What changes were proposed in this pull request?
The pr aims to use `Invoke` to implement `JsonToStructs`(`from_json`).

### Why are the changes needed?
Based on cloud-fan's suggestion, I believe that implementing `JsonToStructs`(`from_json`) with `Invoke` can greatly simplify the code.
#48466 (comment)

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Update existed UT.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48509 from panbingkun/SPARK-49966_FOLLOWUP.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
panbingkun authored and MaxGekk committed Oct 18, 2024
1 parent ff47dd9 commit a1fc7e6
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ object JsonExpressionEvalUtils {
}
}

class JsonToStructsEvaluator(
case class JsonToStructsEvaluator(
options: Map[String, String],
nullableSchema: DataType,
nameOfCorruptRecord: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -637,9 +637,9 @@ case class JsonToStructs(
timeZoneId: Option[String] = None,
variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS))
extends UnaryExpression
with TimeZoneAwareExpression
with RuntimeReplaceable
with ExpectsInputTypes
with NullIntolerant
with TimeZoneAwareExpression
with QueryErrorsBase {

// The JSON input data might be missing certain fields. We force the nullability
Expand All @@ -649,7 +649,7 @@ case class JsonToStructs(

override def nullable: Boolean = true

final override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT)
override def nodePatternsInternal(): Seq[TreePattern] = Seq(JSON_TO_STRUCT, RUNTIME_REPLACEABLE)

// Used in `FunctionRegistry`
def this(child: Expression, schema: Expression, options: Map[String, String]) =
Expand Down Expand Up @@ -683,32 +683,6 @@ case class JsonToStructs(
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))

@transient
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

@transient
private lazy val evaluator = new JsonToStructsEvaluator(
options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)

override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String])

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
val eval = child.genCode(ctx)
val resultType = CodeGenerator.boxedType(dataType)
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
|${eval.code}
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value});
|boolean ${ev.isNull} = $resultTerm == null;
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
|""".stripMargin)
}

override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil

override def sql: String = schema match {
Expand All @@ -718,6 +692,21 @@ case class JsonToStructs(

override def prettyName: String = "from_json"

@transient
private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)

@transient
lazy val evaluator: JsonToStructsEvaluator = JsonToStructsEvaluator(
options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys)

override def replacement: Expression = Invoke(
Literal.create(evaluator, ObjectType(classOf[JsonToStructsEvaluator])),
"evaluate",
dataType,
Seq(child),
Seq(child.dataType)
)

override protected def withNewChildInternal(newChild: Expression): JsonToStructs =
copy(child = newChild)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("from_json escaping") {
val schema = StructType(StructField("\"quote", IntegerType) :: Nil)
GenerateUnsafeProjection.generate(
JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT) :: Nil)
JsonToStructs(schema, Map.empty, Literal("\"quote"), UTC_OPT).replacement :: Nil)
}

test("from_json") {
Expand Down Expand Up @@ -729,7 +729,8 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
test("from/to json - interval support") {
val schema = StructType(StructField("i", CalendarIntervalType) :: Nil)
checkEvaluation(
JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType)),
JsonToStructs(schema, Map.empty, Literal.create("""{"i":"1 year 1 day"}""", StringType),
UTC_OPT),
InternalRow(new CalendarInterval(12, 1, 0)))

Seq(MapType(CalendarIntervalType, IntegerType), MapType(IntegerType, CalendarIntervalType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {

Seq("""{"a":1, "b":2, "c": 123, "d": "test"}""", null).foreach(v => {
val row = create_row(v)
checkEvaluation(e1, e2.eval(row), row)
checkEvaluation(e1, replace(e2).eval(row), row)
})
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0]
Project [invoke(JsonToStructsEvaluator(Map(),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),false).evaluate(g#0)) AS from_json(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ trait BaseScriptTransformationExec extends UnaryExecNode {
val complexTypeFactory = JsonToStructs(attr.dataType,
ioschema.outputSerdeProps.toMap, Literal(null), Some(conf.sessionLocalTimeZone))
wrapperConvertException(data =>
complexTypeFactory.nullSafeEval(UTF8String.fromString(data)), any => any)
complexTypeFactory.evaluator.evaluate(UTF8String.fromString(data)), any => any)
case udt: UserDefinedType[_] =>
wrapperConvertException(data => udt.deserialize(data), converter)
case dt =>
Expand Down

0 comments on commit a1fc7e6

Please sign in to comment.