/**
 * Pulls several simple top-level fields out of one JSON document with a single parser pass.
 *
 * Slot `i` of the row produced by [[evaluate]] carries the rendered value of `fieldNames(i)`.
 * A slot stays null when the key is absent, when every occurrence of the key is a JSON null,
 * or when the first non-null occurrence could not be rendered.
 *
 * @param fieldNames distinct, non-empty list of top-level keys to extract
 * @param fallbackPaths one path per entry of `fieldNames`, evaluated by a
 *                      [[GetJsonObjectEvaluator]] when the shared pass has to be abandoned
 */
case class MultiGetJsonObjectEvaluator(
    fieldNames: Seq[String],
    fallbackPaths: Seq[UTF8String]) {
  import SharedFactory._

  require(
    fieldNames.nonEmpty &&
      fieldNames.distinct.length == fieldNames.length &&
      fallbackPaths.length == fieldNames.length)

  // Output slot of every requested key.
  @transient
  private lazy val ordinalByField: Map[String, Int] = fieldNames.zipWithIndex.toMap

  // Shared result for documents that yield no value for any requested field.
  @transient
  private lazy val allNullRow: InternalRow =
    new GenericInternalRow(new Array[Any](fieldNames.length))

  // One legacy evaluator per requested field, built from the matching fallback path.
  @transient
  private lazy val legacyEvaluators: Seq[GetJsonObjectEvaluator] =
    fallbackPaths.map(path => new GetJsonObjectEvaluator(path))

  // Scratch buffer reused by every rendered value; reset before each use.
  @transient
  private lazy val outputBuffer = new ByteArrayOutputStream()

  /** Evaluates each field independently with its own legacy evaluator. */
  private def fallback(json: UTF8String): InternalRow = {
    val perField = legacyEvaluators.map { legacy =>
      legacy.setJson(json)
      legacy.evaluate()
    }
    new GenericInternalRow(perField.toArray[Any])
  }

  /**
   * Extracts all requested fields from `json`.
   *
   * @return null when `json` is null; the all-null row when the root is not a JSON object,
   *         the scan does not end on the root's closing token, or parsing fails; otherwise a
   *         row with one slot per requested field
   */
  def evaluate(json: UTF8String): InternalRow = {
    if (json == null) {
      null
    } else {
      val extracted = new Array[Any](fieldNames.length)
      try {
        if (scanRootObject(json, extracted)) new GenericInternalRow(extracted) else allNullRow
      } catch {
        // Every simple top-level legacy extraction scans through the root object's closing
        // token, so a syntax failure nulls every sibling and no per-path reparse is needed.
        case _: JsonParseException => allNullRow
        // A parser-side rendering failure can leave the shared token stream unusable, so each
        // path is re-evaluated by the legacy evaluator to keep sibling results independent.
        case _: JsonProcessingException => fallback(json)
      }
    }
  }

  /**
   * Walks the root object once and stores the first non-null occurrence of every requested key
   * into `extracted`.
   *
   * @return true only when the root token is START_OBJECT and the walk stops on END_OBJECT
   */
  private def scanRootObject(json: UTF8String, extracted: Array[Any]): Boolean = {
    // Marks the slots whose value has already been taken from an earlier occurrence.
    val claimed = new Array[Boolean](fieldNames.length)
    Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
      parser.nextToken() == JsonToken.START_OBJECT && {
        var current = parser.nextToken()
        while (current != null && current != JsonToken.END_OBJECT) {
          if (current == JsonToken.FIELD_NAME) {
            val pending = ordinalByField.get(parser.currentName).filterNot(claimed(_))
            val valueToken = parser.nextToken()
            pending match {
              case Some(slot) if valueToken != JsonToken.VALUE_NULL =>
                claimed(slot) = true
                copyCurrentStructure(parser).foreach(rendered => extracted(slot) = rendered)
              case _ =>
                // Unrequested key, already claimed key, or a JSON null: step over the value.
                parser.skipChildren()
            }
          } else {
            parser.skipChildren()
          }
          current = parser.nextToken()
        }
        current == JsonToken.END_OBJECT
      }
    }
  }

  /**
   * Renders the value the parser is positioned on into a UTF-8 string, consuming the whole
   * value from the parser.
   *
   * @return None when the generator rejected any part of the value
   */
  private def copyCurrentStructure(parser: JsonParser): Option[UTF8String] = {
    outputBuffer.reset()
    var failed = false

    // Runs one generator write unless an earlier write already failed. A generator-side
    // failure does not invalidate the parser's token stream, so the caller keeps consuming
    // the value and other requested fields remain independent.
    def attempt(emit: => Unit): Unit = {
      if (!failed) {
        try emit catch {
          case _: JsonGenerationException => failed = true
        }
      }
    }

    // String values are written raw, i.e. without surrounding quotes.
    def copyString(generator: JsonGenerator): Unit = attempt {
      if (parser.hasTextCharacters) {
        generator.writeRaw(parser.getTextCharacters, parser.getTextOffset, parser.getTextLength)
      } else {
        generator.writeRaw(parser.getText)
      }
    }

    // Iterative copy, so a value near the configured nesting limit does not consume one JVM
    // frame per level.
    def copyTokens(generator: JsonGenerator): Unit = {
      var depth = 0
      var balanced = false
      while (!balanced && parser.currentToken != null) {
        val delta = parser.currentToken match {
          case JsonToken.START_OBJECT =>
            attempt(generator.writeStartObject())
            1
          case JsonToken.START_ARRAY =>
            attempt(generator.writeStartArray())
            1
          case JsonToken.END_OBJECT =>
            attempt(generator.writeEndObject())
            -1
          case JsonToken.END_ARRAY =>
            attempt(generator.writeEndArray())
            -1
          case _ =>
            attempt(generator.copyCurrentEvent(parser))
            0
        }
        depth += delta
        balanced = depth == 0
        if (!balanced) {
          parser.nextToken()
        }
      }
    }

    try {
      Utils.tryWithResource(
        jsonFactory.createGenerator(outputBuffer, JsonEncoding.UTF8)) { generator =>
        if (parser.currentToken == JsonToken.VALUE_STRING) {
          copyString(generator)
        } else {
          copyTokens(generator)
        }
      }
    } catch {
      case _: JsonGenerationException => failed = true
    }

    if (!failed) Some(UTF8String.fromBytes(outputBuffer.toByteArray)) else None
  }
}
object GetJsonObject {
  /**
   * Classifies a JSON path for the shared-parsing optimization.
   *
   * @param path the literal path argument of a get_json_object call; may be null
   * @return the selected key when `path` parses to exactly one named top-level field, None for
   *         a null path, a path the parser rejects, or any other path shape
   */
  private[sql] def simpleTopLevelField(path: UTF8String): Option[String] = {
    try {
      if (path == null) {
        None
      } else {
        JsonPathParser.parse(path.toString) match {
          case Some(List(PathInstruction.Key, PathInstruction.Named(name))) => Some(name)
          case _ => None
        }
      }
    } catch {
      // Numeric subscripts are parsed as Long and can overflow before the parser returns None.
      case _: NumberFormatException => None
    }
  }
}

/**
 * Internal expression that extracts several simple top-level fields from a JSON string with a
 * single parse, so that sibling [[GetJsonObject]] expressions can share the work. The result is
 * a struct with one nullable string field per requested key, named `_0`, `_1`, ... in the order
 * of `fieldNames`. Unsupported JSON paths remain independent GetJsonObject expressions.
 *
 * @param json the JSON string input
 * @param fieldNames distinct, non-empty top-level keys to extract
 * @param fallbackPaths one original path per entry of `fieldNames`, handed to the evaluator
 */
case class MultiGetJsonObject(
    json: Expression,
    fieldNames: Seq[String],
    fallbackPaths: Seq[String])
  extends UnaryExpression
  with ExpectsInputTypes {

  require(
    fieldNames.nonEmpty &&
      fieldNames.distinct.length == fieldNames.length &&
      fallbackPaths.length == fieldNames.length)

  override def child: Expression = json

  override def inputTypes: Seq[AbstractDataType] =
    Seq(StringTypeWithCollation(supportsTrimCollation = true))

  override lazy val dataType: DataType = {
    val slots = Seq.tabulate(fieldNames.length) { ordinal =>
      StructField(s"_$ordinal", StringType, nullable = true)
    }
    StructType(slots)
  }

  override def nullable: Boolean = true

  // This internal unary expression always returns null when its JSON child is null.
  override def nullIntolerant: Boolean = true

  override def prettyName: String = "multi_get_json_object"

  final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT)

  @transient
  private lazy val evaluator = {
    val utf8Paths = fallbackPaths.map(UTF8String.fromString)
    MultiGetJsonObjectEvaluator(fieldNames, utf8Paths)
  }

  override def eval(input: InternalRow): Any = {
    val document = json.eval(input).asInstanceOf[UTF8String]
    evaluator.evaluate(document)
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val evaluatorRef = ctx.addReferenceObj("evaluator", evaluator)
    val childCode = json.genCode(ctx)
    val javaType = CodeGenerator.javaType(dataType)
    // The generated code re-checks the evaluator's result for null after the call.
    val generated = code"""
       |${childCode.code}
       |boolean ${ev.isNull} = ${childCode.isNull};
       |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
       |if (!${ev.isNull}) {
       |  ${ev.value} = ($javaType) $evaluatorRef.evaluate(${childCode.value});
       |  ${ev.isNull} = ${ev.value} == null;
       |}
       |""".stripMargin
    ev.copy(code = generated)
  }

  override protected def withNewChildInternal(newChild: Expression): MultiGetJsonObject =
    copy(json = newChild)
}
All the input parameters and output column types are string.", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala index 04cc230f99b44..29d4073916f56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.catalyst.optimizer +import scala.collection.mutable + import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{CREATE_NAMED_STRUCT, EXTRACT_VALUE, - JSON_TO_STRUCT} -import org.apache.spark.sql.types.{ArrayType, StructType} + GET_JSON_OBJECT, JSON_TO_STRUCT} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, StringType, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * Simplify redundant csv/json related expressions. @@ -34,16 +38,38 @@ import org.apache.spark.sql.types.{ArrayType, StructType} * If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json))) * if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema` * contains all accessed fields in original CreateNamedStruct. - * 4. Prune unnecessary columns from GetStructField + CsvToStructs. + * 4. Share one MultiGetJsonObject when a Project extracts multiple simple paths from a JSON. + * 5. Prune unnecessary columns from GetStructField + CsvToStructs. 
*/ object OptimizeCsvJsonExprs extends Rule[LogicalPlan] { private def nameOfCorruptRecord = conf.columnNameOfCorruptRecord + private case class SharedJsonFields( + json: Expression, + fieldNames: Seq[String], + alias: Alias) { + val ordinalMapping: Map[String, Int] = fieldNames.zipWithIndex.toMap + } + + private def evaluatesLeftFirst(binary: BinaryArithmetic): Boolean = binary match { + case _: Add | _: Subtract | _: Multiply | _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true + case _ => false + } + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsAnyPattern(CREATE_NAMED_STRUCT, EXTRACT_VALUE, JSON_TO_STRUCT), ruleId) { + _.containsAnyPattern(CREATE_NAMED_STRUCT, EXTRACT_VALUE, GET_JSON_OBJECT, JSON_TO_STRUCT), + ruleId) { case p => val optimized = if (conf.jsonExpressionOptimization) { - p.transformExpressionsWithPruning( + val withSharedJsonPaths = p match { + case project: Project + if conf.getJsonObjectSharedParsingEnabled && + !conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) && + project.projectList.exists(_.exists(_.isInstanceOf[GetJsonObject])) => + shareGetJsonObjects(project) + case _ => p + } + withSharedJsonPaths.transformExpressionsWithPruning( _.containsAnyPattern(CREATE_NAMED_STRUCT, EXTRACT_VALUE, JSON_TO_STRUCT) )(jsonOptimization) } else { @@ -58,6 +84,122 @@ object OptimizeCsvJsonExprs extends Rule[LogicalPlan] { } } + /** + * Share simple top-level GetJsonObject paths without changing the Hive-compatible semantics of + * nested paths, wildcards, or array subscripts. [[MultiGetJsonObject]] preserves the first + * non-null duplicate-key match used by GetJsonObject, unlike JsonTuple. 
/**
 * Rewrites a Project so that simple top-level GetJsonObject calls over the same input read
 * from one shared [[MultiGetJsonObject]] computed in a new Project directly below it. Nested
 * paths, wildcards and array subscripts are left untouched. An input is only shared when at
 * least two distinct field names are requested from it; the first path seen for a field name
 * becomes its fallback path.
 */
private def shareGetJsonObjects(project: Project): Project = {
  type FieldPaths = mutable.LinkedHashMap[String, String]
  // Groups in first-seen order; the hash buckets only speed up the semantic-equality lookup.
  val orderedGroups = mutable.ArrayBuffer.empty[(Expression, FieldPaths)]
  val bucketsByHash =
    mutable.HashMap.empty[Int, mutable.ArrayBuffer[(Expression, FieldPaths)]]

  project.projectList.foreach { projected =>
    collectGetJsonObjectFields(projected).foreach { case (getJsonObject, fieldName, path) =>
      val input = getJsonObject.json
      val bucket =
        bucketsByHash.getOrElseUpdate(input.semanticHash(), mutable.ArrayBuffer.empty)
      val fieldPaths = bucket.collectFirst {
        case (seen, paths) if seen.semanticEquals(input) => paths
      }.getOrElse {
        val created = input -> mutable.LinkedHashMap.empty[String, String]
        bucket += created
        orderedGroups += created
        created._2
      }
      // Keep the first path recorded for a field name.
      fieldPaths.getOrElseUpdate(fieldName, path)
    }
  }

  val sharedFields = orderedGroups.toSeq.collect {
    case (input, fieldPaths) if fieldPaths.size > 1 =>
      val names = fieldPaths.keys.toSeq
      val shared = MultiGetJsonObject(input, names, fieldPaths.values.toSeq)
      SharedJsonFields(input, names, Alias(shared, "_shared_json_paths")())
  }

  if (sharedFields.nonEmpty) {
    val sharedByHash = sharedFields.groupBy(_.json.semanticHash())
    val outerList = project.projectList.map { projected =>
      rewriteGetJsonObjectFields(projected, sharedByHash).asInstanceOf[NamedExpression]
    }
    val innerList = project.child.output ++ sharedFields.map(_.alias)
    Project(outerList, Project(innerList, project.child))
  } else {
    project
  }
}

/**
 * Collects the shareable GetJsonObject calls reachable from `expression`: deterministic calls
 * whose input is an attribute and whose literal path selects one top-level field. Each result
 * is (call, selected field name, original path text).
 */
private def collectGetJsonObjectFields(
    expression: Expression): Seq[(GetJsonObject, String, String)] = {
  expression match {
    case candidate @ GetJsonObject(_: Attribute, Literal(path: UTF8String, StringType))
        if candidate.deterministic =>
      GetJsonObject.simpleTopLevelField(path) match {
        case Some(fieldName) => Seq((candidate, fieldName, path.toString))
        case None => Seq.empty
      }

    // Any other GetJsonObject is neither shared nor searched.
    case _: GetJsonObject =>
      Seq.empty

    case other =>
      getJsonObjectTraversalChild(other) match {
        case Some(next) => collectGetJsonObjectFields(next)
        case None => Seq.empty
      }
  }
}

/**
 * Replaces every shareable GetJsonObject reachable from `expression` with a GetStructField on
 * the matching shared alias, following the same traversal as the collection step. Calls with
 * no matching shared entry are returned unchanged.
 */
private def rewriteGetJsonObjectFields(
    expression: Expression,
    sharedFieldsByHash: Map[Int, Seq[SharedJsonFields]]): Expression = {
  expression match {
    case original @ GetJsonObject(json, Literal(path: UTF8String, StringType)) =>
      GetJsonObject.simpleTopLevelField(path).flatMap { fieldName =>
        sharedFieldsByHash.getOrElse(json.semanticHash(), Nil).collectFirst {
          case shared
              if shared.json.semanticEquals(json) &&
                shared.ordinalMapping.contains(fieldName) =>
            GetStructField(shared.alias.toAttribute, shared.ordinalMapping(fieldName))
        }
      }.getOrElse(original)

    case _: GetJsonObject =>
      expression

    case other =>
      getJsonObjectTraversalChild(other) match {
        case Some(next) =>
          val rewritten = rewriteGetJsonObjectFields(next, sharedFieldsByHash)
          other.withNewChildren(rewritten +: other.children.tail)
        case None =>
          other
      }
  }
}

/**
 * Returns the single child through which the search for GetJsonObject calls may continue, or
 * None when the search must stop at `expression`.
 */
private def getJsonObjectTraversalChild(expression: Expression): Option[Expression] = {
  // These node types are listed explicitly and always end the traversal.
  val stopsTraversal = expression match {
    case _: ConditionalExpression | _: And | _: Or | _: In | _: TryEval |
        _: LambdaFunction | _: CreateNamedStruct =>
      true
    case _ =>
      false
  }

  if (stopsTraversal) {
    None
  } else {
    expression match {
      case alias: Alias => Some(alias.child)
      case getStructField: GetStructField => Some(getStructField.child)
      case cast: Cast => Some(cast.child)
      // Only the left operand of arithmetic accepted by evaluatesLeftFirst is followed.
      case binary: BinaryArithmetic if evaluatesLeftFirst(binary) => Some(binary.left)
      case _ => None
    }
  }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index f676750a054a4..d0d912ae4cc6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -58,6 +58,7 @@ object TreePattern extends Enumeration { val FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION: Value = Value val GENERATOR: Value = Value val GROUPING_ANALYTICS: Value = Value + val GET_JSON_OBJECT: Value = Value val HIGH_ORDER_FUNCTION: Value = Value val IF: Value = Value val IN: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f76dec323ef3..31e3e976b2e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3833,6 +3833,17 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GET_JSON_OBJECT_SHARED_PARSING_ENABLED = + buildConf("spark.sql.optimizer.getJsonObjectSharedParsing.enabled") + .internal() + .doc(s"When true and '${JSON_EXPRESSION_OPTIMIZATION.key}' is also true, the optimizer " + + "replaces repeated simple top-level get_json_object expressions over the same input " + + "with one shared parse.") + .version("4.3.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .booleanConf + .createWithDefault(false) + val CSV_EXPRESSION_OPTIMIZATION = buildConf("spark.sql.optimizer.enableCsvExpressionOptimization") .doc("Whether to optimize CSV expressions in SQL optimizer. 
It includes pruning " + @@ -8259,6 +8270,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def jsonExpressionOptimization: Boolean = getConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION) + def getJsonObjectSharedParsingEnabled: Boolean = + getConf(SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED) + def csvExpressionOptimization: Boolean = getConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION) def parallelFileListingInStatsComputation: Boolean = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index eed06da609f8e..85d621d2bc798 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.DateTimeUtils.getZoneId import org.apache.spark.sql.internal.SQLConf @@ -30,18 +30,40 @@ import org.apache.spark.sql.types._ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { private var jsonExpressionOptimizeEnabled: Boolean = _ + private var getJsonObjectSharedParsingEnabled: Boolean = _ protected override def beforeAll(): Unit = { jsonExpressionOptimizeEnabled = SQLConf.get.jsonExpressionOptimization + getJsonObjectSharedParsingEnabled = SQLConf.get.getJsonObjectSharedParsingEnabled + SQLConf.get.setConf(SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED, true) } 
protected override def afterAll(): Unit = { SQLConf.get.setConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION, jsonExpressionOptimizeEnabled) + SQLConf.get.setConf( + SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED, getJsonObjectSharedParsingEnabled) } object Optimizer extends RuleExecutor[LogicalPlan] { val batches = Batch("Json optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil } + object OptimizerWithCollapseProject extends RuleExecutor[LogicalPlan] { + val batches = Batch( + "Json optimization with project collapse", + FixedPoint(10), + CollapseProject, + OptimizeCsvJsonExprs) :: Nil + } + + object OptimizerWithColumnPruning extends RuleExecutor[LogicalPlan] { + val batches = Batch( + "Json optimization with column pruning", + FixedPoint(10), + ColumnPruning, + CollapseProject, + OptimizeCsvJsonExprs) :: Nil + } + val schema = StructType.fromDDL("a int, b int") private val structAtt = $"struct".struct(schema).notNull @@ -175,6 +197,130 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { comparePlans(optimized2, expected2) } + test("SPARK-47670: share simple top-level get_json_object paths") { + val query = testRelation2.select( + Cast(GetJsonObject($"json", Literal("$.b")), LongType).as("b"), + GetJsonObject($"json", Literal("$['a']")).as("a")) + val optimized = Optimizer.execute(query.analyze) + + optimized match { + case Project(projectList, Project(innerProjectList, _: LocalRelation)) => + val sharedAlias = innerProjectList.collectFirst { + case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + }.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) + val shared = sharedAlias.child.asInstanceOf[MultiGetJsonObject] + assert(shared.fieldNames == Seq("b", "a")) + assert(shared.fallbackPaths == Seq("$.b", "$['a']")) + + val sharedAttr = sharedAlias.toAttribute + val extractedFields = projectList.flatMap(_.collect { + case getStructField: GetStructField + if getStructField.child.semanticEquals(sharedAttr) 
=> getStructField + }) + assert(extractedFields.map(_.ordinal) == Seq(0, 1)) + + case _ => + fail(s"Expected shared JSON paths below the project, but found:\n$optimized") + } + } + + test("SPARK-47670: shared get_json_object parsing is disabled by default") { + assert(!new SQLConf().getJsonObjectSharedParsingEnabled) + val query = testRelation2.select( + GetJsonObject($"json", Literal("$.a")).as("a"), + GetJsonObject($"json", Literal("$.b")).as("b")) + + withSQLConf(SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> "false") { + comparePlans(Optimizer.execute(query.analyze), query.analyze) + } + } + + test("SPARK-47670: do not fail optimization for unparseable get_json_object paths") { + val oversizedIndex = "$[" + "9" * 100 + "]" + val query = testRelation2.select( + GetJsonObject($"json", Literal(oversizedIndex)).as("value")) + + comparePlans(Optimizer.execute(query.analyze), query.analyze) + } + + test("SPARK-47670: do not combine nested get_json_object paths") { + val nested = GetJsonObject($"json", Literal("$.a.b")) + val query = testRelation2.select( + GetJsonObject($"json", Literal("$.a")).as("a"), + nested.as("nested"), + GetJsonObject($"json", Literal("$.c")).as("c")) + val optimized = Optimizer.execute(query.analyze) + + val shared = optimized.collect { + case Project(projectList, _) => projectList.collectFirst { + case alias @ Alias(_: MultiGetJsonObject, "_shared_json_paths") => alias + } + }.flatten.headOption.getOrElse(fail(s"Missing shared JSON paths in plan:\n$optimized")) + assert(shared.child.asInstanceOf[MultiGetJsonObject].fieldNames == Seq("a", "c")) + assert(optimized.expressions.exists(_.exists { + case GetJsonObject(_, Literal(path, StringType)) => path.toString == "$.a.b" + case _ => false + })) + } + + test("SPARK-47670: shared get_json_object paths survive project collapsing") { + val query = testRelation2.select( + GetJsonObject($"json", Literal("$.a")).as("a"), + GetJsonObject($"json", Literal("$.b")).as("b")) + + val optimized = 
OptimizerWithCollapseProject.execute(query.analyze) + + assert(optimized.exists { plan => + plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + assert(optimized.collect { case _: Project => true }.length == 2) + } + + test("SPARK-47670: do not share get_json_object paths that are guarded or pruned") { + val guardedQuery = testRelation2.select( + If( + IsNull($"json"), + Literal(null, StringType), + GetJsonObject($"json", Literal("$.a"))).as("a"), + GetJsonObject($"json", Literal("$.b")).as("b")) + assert(!Optimizer.execute(guardedQuery.analyze).exists { plan => + plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + + val lowerProject = testRelation2.select( + GetJsonObject($"json", Literal("$.a")).as("a"), + GetJsonObject($"json", Literal("$.b")).as("b")) + val prunedQuery = lowerProject.select(lowerProject.output.head) + assert(!OptimizerWithColumnPruning.execute(prunedQuery.analyze).exists { plan => + plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + } + + test("SPARK-47670: do not share separately projected from_json fields") { + val schema = StructType.fromDDL("a int, b struct") + val fromJson = JsonToStructs(schema, Map.empty, $"json") + val query = testRelation2.select( + GetStructField(fromJson, 0).as("a"), + GetStructField(fromJson, 1).as("b")) + + val optimized = Optimizer.execute(query.analyze) + val parsedSchemas = optimized.expressions.flatMap(_.collect { + case jsonToStructs: JsonToStructs => jsonToStructs.schema + }) + assert(parsedSchemas == Seq( + StructType.fromDDL("a int"), + StructType.fromDDL("b struct"))) + } + + test("SPARK-47670: do not share get_json_object below right-first arithmetic") { + val query = testRelation2.select( + Pmod(Cast(GetJsonObject($"json", Literal("$.a")), IntegerType), Literal(0)).as("a"), + GetJsonObject($"json", Literal("$.b")).as("b")) + assert(!Optimizer.execute(query.analyze).exists { plan => + 
plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + } + test("SPARK-32958: prune unnecessary columns from GetArrayStructFields + from_json") { val options = Map.empty[String, String] val schema1 = ArrayType(StructType.fromDDL("a int, b int"), containsNull = true) diff --git a/sql/core/benchmarks/SharedJsonParseBenchmark-jdk21-results.txt b/sql/core/benchmarks/SharedJsonParseBenchmark-jdk21-results.txt new file mode 100644 index 0000000000000..dcda87f622937 --- /dev/null +++ b/sql/core/benchmarks/SharedJsonParseBenchmark-jdk21-results.txt @@ -0,0 +1,33 @@ +================================================================================================ +Benchmark for sharing repeated get_json_object parsing +================================================================================================ + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 2 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 1686 1712 25 0.1 8429.2 1.0X +shared parsing on 979 1004 23 0.2 4892.7 1.7X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 4 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 3273 3279 6 0.1 16365.1 1.0X +shared parsing on 1242 1250 14 0.2 6209.0 2.6X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 8 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------- +shared parsing off 6449 6465 16 0.0 32247.0 1.0X +shared parsing on 1317 1331 12 0.2 6586.7 4.9X + +OpenJDK 64-Bit Server VM 21.0.11+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 16 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------- +shared parsing off 13118 13134 15 0.0 65592.1 1.0X +shared parsing on 1823 1830 9 0.1 9113.0 7.2X + + diff --git a/sql/core/benchmarks/SharedJsonParseBenchmark-jdk25-results.txt b/sql/core/benchmarks/SharedJsonParseBenchmark-jdk25-results.txt new file mode 100644 index 0000000000000..b09050a476f3b --- /dev/null +++ b/sql/core/benchmarks/SharedJsonParseBenchmark-jdk25-results.txt @@ -0,0 +1,33 @@ +================================================================================================ +Benchmark for sharing repeated get_json_object parsing +================================================================================================ + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 2 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 1663 1680 26 0.1 8315.1 1.0X +shared parsing on 964 980 14 0.2 4818.8 1.7X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 4 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 
3200 3215 14 0.1 15998.6 1.0X +shared parsing on 1096 1111 14 0.2 5478.2 2.9X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 8 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 6233 6248 13 0.0 31163.9 1.0X +shared parsing on 1304 1319 14 0.2 6519.4 4.8X + +OpenJDK 64-Bit Server VM 25.0.3+9-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 16 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------- +shared parsing off 12453 12474 19 0.0 62264.1 1.0X +shared parsing on 1776 1785 9 0.1 8879.8 7.0X + + diff --git a/sql/core/benchmarks/SharedJsonParseBenchmark-results.txt b/sql/core/benchmarks/SharedJsonParseBenchmark-results.txt new file mode 100644 index 0000000000000..fdee387321846 --- /dev/null +++ b/sql/core/benchmarks/SharedJsonParseBenchmark-results.txt @@ -0,0 +1,33 @@ +================================================================================================ +Benchmark for sharing repeated get_json_object parsing +================================================================================================ + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 2 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 1526 1558 35 0.1 7629.3 1.0X +shared parsing on 902 907 6 0.2 4508.0 1.7X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 
7763 64-Core Processor +get_json_object extracting 4 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 2978 2999 25 0.1 14890.0 1.0X +shared parsing on 1027 1029 1 0.2 5136.9 2.9X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 8 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +shared parsing off 5871 5889 16 0.0 29355.9 1.0X +shared parsing on 1245 1251 6 0.2 6226.5 4.7X + +OpenJDK 64-Bit Server VM 17.0.19+10-LTS on Linux 6.17.0-1018-azure +AMD EPYC 7763 64-Core Processor +get_json_object extracting 16 of 32 fields: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------- +shared parsing off 11620 11679 54 0.0 58098.9 1.0X +shared parsing on 1704 1707 3 0.1 8519.4 6.8X + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 8857a9977b573..50d5b4be24eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -23,13 +23,15 @@ import java.util.Locale import scala.jdk.CollectionConverters._ +import com.fasterxml.jackson.core.StreamReadConstraints + import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{JsonToStructs, Literal, MultiGetJsonObject} import 
org.apache.spark.sql.catalyst.expressions.Cast._ -import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.TimestampNanosTestUtils.foreachNanosPrecision import org.apache.spark.sql.errors.DataTypeErrors.toSQLType -import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.{InputAdapter, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -83,6 +85,174 @@ class JsonFunctionsSuite extends SharedSparkSession { expected) } + test("SPARK-47670: share simple top-level get_json_object paths") { + val input = Seq[String]( + """{"a":"one","a.b":"dotted","b":2,"obj":{"x":1},"arr":[1,2]}""", + """{"a":"first","a":"second","b":3}""", + """{"a":null,"a":"after_null","b":4}""", + """{"a":"before_error","b":"""", + """{'a':'single','b':5}""", + """[1,2,3]""", + """{"a":"trailing","b":6} trailing text""", + """{}""", + null) + + def result(jsonOptimization: Boolean, sharedParsing: Boolean): Seq[Row] = { + var rows = Seq.empty[Row] + withSQLConf( + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> jsonOptimization.toString, + SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> sharedParsing.toString) { + val df = input.toDF("json") + rows = df.select( + get_json_object($"json", "$.a"), + get_json_object($"json", "$['a.b']"), + get_json_object($"json", "$.b"), + get_json_object($"json", "$.b").cast(IntegerType), + get_json_object($"json", "$.obj"), + get_json_object($"json", "$.arr"), + get_json_object($"json", "$.missing")).collect().toSeq + } + rows + } + + val legacy = result(jsonOptimization = false, sharedParsing = false) + assert(result(jsonOptimization = true, sharedParsing = false) == legacy) + assert(result(jsonOptimization = true, sharedParsing = true) == legacy) + } + + test("SPARK-47670: shared get_json_object isolates value rendering failures") { + val invalidSurrogate = 
"\\" + "uD800" + val input = Seq( + s"""{"a":"before","b":"$invalidSurrogate","c":"after"}""", + s"""{"a":"before","b":{"nested":"$invalidSurrogate"},"c":"after"}""", + s"""{"a":"before","b":"$invalidSurrogate","b":"valid","c":"after"}""") + + def result(jsonOptimization: Boolean): Seq[Row] = { + var rows = Seq.empty[Row] + withSQLConf( + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> jsonOptimization.toString, + SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> "true") { + rows = input.toDF("json").select( + get_json_object($"json", "$.a"), + get_json_object($"json", "$.b"), + get_json_object($"json", "$.c")).collect().toSeq + } + rows + } + + assert(result(jsonOptimization = true) == result(jsonOptimization = false)) + assert(result(jsonOptimization = true) == Seq( + Row("before", null, "after"), + Row("before", s"""{"nested":"$invalidSurrogate"}""", "after"), + Row("before", null, "after"))) + } + + test("SPARK-47670: shared get_json_object falls back after parser-side rendering failure") { + val oversized = "x" * (StreamReadConstraints.DEFAULT_MAX_STRING_LEN + 1) + + def result(sharedParsingEnabled: Boolean): Seq[Row] = { + var rows = Seq.empty[Row] + withSQLConf( + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "true", + SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> sharedParsingEnabled.toString) { + val query = Seq(s"""{"a":"$oversized","b.c":2}""").toDF("json").select( + get_json_object($"json", "$.a"), + get_json_object($"json", "$['b.c']")) + + if (sharedParsingEnabled) { + assert(query.queryExecution.optimizedPlan.exists { plan => + plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + } + rows = query.collect().toSeq + } + rows + } + + val sharedResult = result(sharedParsingEnabled = true) + assert(sharedResult == result(sharedParsingEnabled = false)) + assert(sharedResult == Seq(Row(null, "2"))) + } + + test("SPARK-47670: shared get_json_object does not return partial malformed results") { + val malformed = 
"""{"a":1,"b":"\q}"}""" + + def result(jsonOptimization: Boolean): Seq[Row] = { + var rows = Seq.empty[Row] + withSQLConf( + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> jsonOptimization.toString, + SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> "true") { + val query = Seq(malformed).toDF("json").select( + get_json_object($"json", "$.a"), + get_json_object($"json", "$.b")) + if (jsonOptimization) { + assert(query.queryExecution.optimizedPlan.exists { plan => + plan.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) + }) + } + rows = query.collect().toSeq + } + rows + } + + val shared = result(jsonOptimization = true) + val legacy = result(jsonOptimization = false) + assert(shared == legacy) + assert(shared == Seq(Row(null, null))) + } + + test("SPARK-47670: shared get_json_object handles deeply nested values on a small stack") { + val depth = 999 + val nested = "[" * depth + "1" + "]" * depth + val expression = MultiGetJsonObject( + Literal(s"""{"a":$nested,"b":2}"""), Seq("a", "b"), Seq("$.a", "$.b")) + val result = Array.ofDim[Any](1) + val thread = new Thread( + null, + new Runnable { + override def run(): Unit = { + try { + result(0) = expression.eval() + } catch { + case error: Throwable => result(0) = error + } + } + }, + "deep-json-copy", + 256 * 1024) + + thread.start() + thread.join(30000) + assert(!thread.isAlive, "Deep JSON extraction did not finish") + result(0) match { + case error: Throwable => fail("Deep JSON extraction failed", error) + case row: InternalRow => + assert(row.getUTF8String(0).numBytes() == 2 * depth + 1) + assert(row.getUTF8String(1).toString == "2") + case other => fail(s"Unexpected deep JSON extraction result: $other") + } + } + + test("SPARK-47670: shared get_json_object supports project code generation") { + withSQLConf(SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> "true") { + val df = Seq("""{"a":1,"b":2}""").toDF("json").select( + get_json_object($"json", "$.a"), + get_json_object($"json", "$.b")) + + 
checkAnswer(df, Row("1", "2")) + def containsSharedExtraction(plan: SparkPlan): Boolean = plan match { + case _: InputAdapter => false + case other + if other.expressions.exists(_.exists(_.isInstanceOf[MultiGetJsonObject])) => true + case other => other.children.exists(containsSharedExtraction) + } + assert(df.queryExecution.executedPlan.exists { + case stage: WholeStageCodegenExec => containsSharedExtraction(stage.child) + case _ => false + }, s"Shared get_json_object project was outside whole-stage codegen:\n${df.queryExecution}") + } + } + test("SPARK-42782: Hive compatibility check for get_json_object") { val book0 = "{\"author\":\"Nigel Rees\",\"title\":\"Sayings of the Century\"" + ",\"category\":\"reference\",\"price\":8.95}" @@ -1365,6 +1535,27 @@ class JsonFunctionsSuite extends SharedSparkSession { } } + test("SPARK-47670: separately projected from_json fields preserve malformed-input semantics") { + withSQLConf( + SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "true", + SQLConf.JSON_ENABLE_PARTIAL_RESULTS.key -> "true") { + // `struct` needs a field list to be valid DDL. NOTE(review): `x int` is inferred from the + // "x" key in the truncated input below; confirm the intended field type. + val schema = StructType.fromDDL("a int, b struct<x int>") + val input = Seq("""{"a":1,"b":{"x":""").toDF("json") + val parsed = from_json($"json", schema) + val query = input.select(parsed.getField("a"), parsed.getField("b")) + + val parsedSchemas = query.queryExecution.optimizedPlan.expressions.flatMap(_.collect { + case jsonToStructs: JsonToStructs => jsonToStructs.schema + }) + assert(parsedSchemas == Seq( + StructType.fromDDL("a int"), + StructType.fromDDL("b struct<x int>"))) + checkAnswer(query, Row(null, null)) + } + } + test("SPARK-35982: from_json/to_json for map types where value types are year-month intervals") { val ymDF = Seq(Period.of(1, 2, 0)).toDF() Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/SharedJsonParseBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/SharedJsonParseBenchmark.scala new file mode 100644 index 0000000000000..7d2f74b43a0c8 --- 
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/SharedJsonParseBenchmark.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.json + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StringType + +/** + * Benchmarks sharing repeated simple top-level get_json_object parsing. + * + * To run this benchmark: + * {{{ + * 1. build/sbt "sql/Test/runMain <this class>" + * 2. Generate the result file: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>" + * Results will be written to + * "sql/core/benchmarks/SharedJsonParseBenchmark-results.txt" on JDK 17. 
+ * }}}
+ */
+object SharedJsonParseBenchmark extends SqlBasedBenchmark {
+  import spark.implicits._
+
+  /**
+   * Times the same projection of top-level fields with shared parsing disabled and enabled,
+   * for 2, 4, 8 and 16 selected fields out of 32.
+   */
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("Benchmark for sharing repeated get_json_object parsing") {
+      val numRows = 200000
+      val totalFields = 32
+      // Every field carries the same id-derived payload padded with 64 'x' characters.
+      val payload = concat(lit("value-"), $"id".cast(StringType), lit("-" + "x" * 64))
+      val fieldColumns = Seq.tabulate(totalFields)(i => payload.as(s"field_$i"))
+      val jsonRows = spark.range(0, numRows, 1, 4)
+        .select(to_json(struct(fieldColumns: _*)).as("json"))
+        .cache()
+      // Trigger an action on the cached dataset before any benchmark case is timed.
+      jsonRows.count()
+
+      for (selected <- Seq(2, 4, 8, 16)) {
+        val title = s"get_json_object extracting $selected of $totalFields fields"
+        val benchmark = new Benchmark(title, numRows, output = output)
+        def projection = Seq.tabulate(selected)(i => get_json_object($"json", s"$$.field_$i"))
+
+        // Same projection in both cases; only the shared-parsing flag differs.
+        Seq("shared parsing off" -> false, "shared parsing on" -> true).foreach {
+          case (caseName, sharedParsing) =>
+            benchmark.addCase(caseName, 3) { _ =>
+              withSQLConf(
+                SQLConf.JSON_EXPRESSION_OPTIMIZATION.key -> "true",
+                SQLConf.GET_JSON_OBJECT_SHARED_PARSING_ENABLED.key -> sharedParsing.toString) {
+                jsonRows.select(projection: _*).noop()
+              }
+            }
+        }
+        benchmark.run()
+      }
+
+      // Drop the cached input once every field count has been measured.
+      jsonRows.unpersist()
+    }
+  }
+}