Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,156 @@ case class GetJsonObjectEvaluator(cachedPath: UTF8String) {
}
}
}

/**
 * Extracts several simple top-level JSON fields from one document with a single parser pass.
 *
 * Construction fails (via `require`) when `fieldNames` is empty, contains duplicates, or when
 * `fallbackPaths` does not have exactly one entry per field name.
 *
 * @param fieldNames top-level field names to extract; the field at position `i` fills output
 *                   column `i`
 * @param fallbackPaths one path per field name, used to build a [[GetJsonObjectEvaluator]] for
 *                      per-path re-evaluation when the shared pass is abandoned
 */
case class MultiGetJsonObjectEvaluator(
    fieldNames: Seq[String],
    fallbackPaths: Seq[UTF8String]) {
  import SharedFactory._

  require(fieldNames.nonEmpty)
  require(fieldNames.distinct.length == fieldNames.length)
  require(fallbackPaths.length == fieldNames.length)

  // Requested field name -> output column position.
  @transient
  private lazy val fieldToOrdinal: Map[String, Int] = fieldNames.zipWithIndex.toMap

  // Shared row whose columns are all null; returned whenever the document as a whole is rejected.
  @transient
  private lazy val nullRow: InternalRow =
    new GenericInternalRow(new Array[Any](fieldNames.length))

  // One single-path evaluator per fallback path, in the same order as the field names.
  @transient
  private lazy val fallbackEvaluators: Seq[GetJsonObjectEvaluator] =
    fallbackPaths.map(path => new GetJsonObjectEvaluator(path))

  // Scratch buffer reused for every rendered value; reset at the start of each copy.
  @transient
  private lazy val outputBuffer = new ByteArrayOutputStream()

  /** Evaluates every fallback path independently against `json` and packs the results. */
  private def fallback(json: UTF8String): InternalRow = {
    val reparsed: Array[Any] = fallbackEvaluators.map { single =>
      single.setJson(json)
      single.evaluate()
    }.toArray
    new GenericInternalRow(reparsed)
  }

  /**
   * Parses `json` once and returns a row with one column per requested field.
   *
   * @param json the JSON document; may be null
   * @return null when `json` is null; the shared all-null row when the root token is not an
   *         object start, the scan does not end on the root object's closing token, or a
   *         [[JsonParseException]] is raised; the per-path fallback result for any other
   *         [[JsonProcessingException]]; otherwise a row holding, for each field, the rendering
   *         of its first non-null occurrence (or null when absent or not renderable)
   */
  def evaluate(json: UTF8String): InternalRow = {
    if (json == null) {
      null
    } else {
      val width = fieldNames.length
      val extracted = Array.ofDim[Any](width)
      val seen = Array.ofDim[Boolean](width)

      try {
        val closedCleanly = Utils.tryWithResource(
          CreateJacksonParser.utf8String(jsonFactory, json)) { parser =>
          if (parser.nextToken() == JsonToken.START_OBJECT) {
            var current = parser.nextToken()
            while (current != null && current != JsonToken.END_OBJECT) {
              if (current == JsonToken.FIELD_NAME) {
                // Resolve the slot before advancing: only fields not yet matched qualify.
                val wanted = fieldToOrdinal.get(parser.currentName).filter(slot => !seen(slot))
                val valueToken = parser.nextToken()
                wanted match {
                  case Some(slot) if valueToken != JsonToken.VALUE_NULL =>
                    // Mark the slot as taken even if rendering fails, so a later duplicate key
                    // cannot fill it.
                    seen(slot) = true
                    copyCurrentStructure(parser).foreach(rendered => extracted(slot) = rendered)
                  case _ =>
                    // Unrequested field, already-filled slot, or JSON null: step over the value.
                    parser.skipChildren()
                }
              } else {
                parser.skipChildren()
              }
              current = parser.nextToken()
            }
            current == JsonToken.END_OBJECT
          } else {
            false
          }
        }
        if (closedCleanly) new GenericInternalRow(extracted) else nullRow
      } catch {
        // A syntax error anywhere in the root object nulls every sibling: each simple top-level
        // legacy extraction reads through the root's closing token, so no per-path reparse is
        // needed.
        case _: JsonParseException => nullRow
        // Other processing failures (for example a parser-side rendering problem) may leave the
        // shared token stream unusable, so each path is re-evaluated on its own to keep one bad
        // value from wiping out its siblings.
        case _: JsonProcessingException => fallback(json)
      }
    }
  }

  /**
   * Renders the value the parser is currently positioned on into the scratch buffer.
   *
   * A string value is written raw (without quoting); any other value is copied event by event
   * until its nesting depth returns to zero.
   *
   * @param parser parser positioned on the first token of the value to copy
   * @return the rendered bytes, or None when the generator raised a [[JsonGenerationException]]
   */
  private def copyCurrentStructure(parser: JsonParser): Option[UTF8String] = {
    outputBuffer.reset()
    var failed = false

    // Runs one generator write unless an earlier write already failed. A generator-side failure
    // leaves the parser's token stream intact, so the caller keeps consuming the value and the
    // other requested fields stay independent.
    def guarded(emit: => Unit): Unit = {
      if (!failed) {
        try {
          emit
        } catch {
          case _: JsonGenerationException => failed = true
        }
      }
    }

    def copyValue(generator: JsonGenerator): Unit = {
      if (parser.currentToken == JsonToken.VALUE_STRING) {
        guarded {
          if (parser.hasTextCharacters) {
            generator.writeRaw(
              parser.getTextCharacters,
              parser.getTextOffset,
              parser.getTextLength)
          } else {
            generator.writeRaw(parser.getText)
          }
        }
      } else {
        // Deliberately iterative: a value close to the configured nesting limit must not cost
        // one JVM stack frame per level.
        var depth = 0
        var finished = false
        while (!finished && parser.currentToken != null) {
          parser.currentToken match {
            case JsonToken.START_OBJECT =>
              guarded(generator.writeStartObject())
              depth += 1
            case JsonToken.START_ARRAY =>
              guarded(generator.writeStartArray())
              depth += 1
            case JsonToken.END_OBJECT =>
              guarded(generator.writeEndObject())
              depth -= 1
            case JsonToken.END_ARRAY =>
              guarded(generator.writeEndArray())
              depth -= 1
            case _ =>
              guarded(generator.copyCurrentEvent(parser))
          }
          // Depth zero after handling a token means the whole value has been consumed.
          finished = depth == 0
          if (!finished) {
            parser.nextToken()
          }
        }
      }
    }

    try {
      Utils.tryWithResource(
        jsonFactory.createGenerator(outputBuffer, JsonEncoding.UTF8)) { generator =>
        copyValue(generator)
      }
    } catch {
      case _: JsonGenerationException => failed = true
    }

    if (!failed) Some(UTF8String.fromBytes(outputBuffer.toByteArray)) else None
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.json.{GetJsonObjectEvaluator, JsonExpressionUtils, JsonToStructsEvaluator, JsonTupleEvaluator, SchemaOfJsonEvaluator, StructsToJsonEvaluator}
import org.apache.spark.sql.catalyst.expressions.json.{GetJsonObjectEvaluator, JsonExpressionUtils,
JsonPathParser, JsonToStructsEvaluator, JsonTupleEvaluator, MultiGetJsonObjectEvaluator,
PathInstruction, SchemaOfJsonEvaluator, StructsToJsonEvaluator}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.catalyst.trees.TreePattern.{GET_JSON_OBJECT, JSON_TO_STRUCT,
RUNTIME_REPLACEABLE, TreePattern}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
Expand Down Expand Up @@ -63,6 +66,8 @@ case class GetJsonObject(json: Expression, path: Expression)
override def nullable: Boolean = true
override def prettyName: String = "get_json_object"

final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT)

@transient
private lazy val evaluator = if (path.foldable) {
new GetJsonObjectEvaluator(path.eval().asInstanceOf[UTF8String])
Expand Down Expand Up @@ -136,6 +141,82 @@ case class GetJsonObject(json: Expression, path: Expression)
copy(json = newLeft, path = newRight)
}

object GetJsonObject {
  /**
   * Recognizes a path that addresses exactly one named top-level field.
   *
   * @param path the JSON path to inspect; may be null
   * @return the field name when the parsed path is exactly a key instruction followed by a
   *         named instruction; None for a null path, a path the parser rejects, any other
   *         instruction sequence, or a path whose parsing throws [[NumberFormatException]]
   */
  private[sql] def simpleTopLevelField(path: UTF8String): Option[String] = {
    if (path == null) {
      None
    } else {
      try {
        JsonPathParser.parse(path.toString) match {
          case Some(List(PathInstruction.Key, PathInstruction.Named(fieldName))) =>
            Some(fieldName)
          case _ =>
            None
        }
      } catch {
        // Numeric subscripts are read as Long, so an oversized one overflows before the parser
        // gets the chance to return None.
        case _: NumberFormatException => None
      }
    }
  }
}

/**
* Extracts multiple simple top-level fields from a JSON string in one parse. This is an internal
* expression used to share sibling [[GetJsonObject]] expressions; unsupported JSON paths remain
* as independent GetJsonObject expressions.
*/
case class MultiGetJsonObject(

@dongjoon-hyun dongjoon-hyun Jun 16, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like other classes, shall we define prettyName to be more complete? multi_get_json_object?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiGetJsonObject only wraps MultiGetJsonObjectEvaluator with a hand-written eval/doGenCode. Since the evaluator already holds all the logic, the framework can generate that for you via the Invoke(Literal.create(evaluator, ObjectType(...)), "evaluate", structType, Seq(json), Seq(json.dataType)) idiom that StructsToJson/SchemaOfJson use in this same file.

Best of all would be to make MultiGetJsonObject a RuntimeReplaceable with that Invoke as its replacement: you'd drop the hand-written doGenCode and keep the readable multi_get_json_object node in the plan. That's historically blocked for an optimizer-inserted node (it's created after ReplaceExpressions, so a RuntimeReplaceable would never be replaced), but SPARK-57512 / #56575 lifts that by letting an evaluable RuntimeReplaceable survive the optimizer and be materialized just before codegen — so this becomes possible if you coordinate with that change. If you'd rather not depend on it, emitting the Invoke directly and dropping the class works today (the only cost is EXPLAIN showing invoke(...)).

Behavior is preserved either way: Invoke's default propagateNull matches the current nullIntolerant = true, the nested Project still survives CollapseProject (an Invoke isn't "cheap" and the alias is referenced more than once), and the rule stays idempotent.

Minor and independent: the inner transformExpressionsWithPruning predicate at L72 adds GET_JSON_OBJECT, but the jsonOptimization partial function it drives has no GetJsonObject/MultiGetJsonObject case — only the outer transformWithPruning at L60 needs the pattern, so it can be dropped from the inner one. Non-blocking.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed suggestion. I removed GET_JSON_OBJECT from the inner pruning predicate in the latest commit.

For the larger codegen change, I agree that RuntimeReplaceable would be cleaner once #56575 is available. Since that PR is still open/WIP, and emitting Invoke directly today would make the optimized plan less readable, I would prefer to retain the current internal expression in this PR and handle that refactor as a follow-up. The current doGenCode contains no JSON-processing logic; it only evaluates the child, handles nullability, and delegates to MultiGetJsonObjectEvaluator.

json: Expression,
fieldNames: Seq[String],
fallbackPaths: Seq[String])
extends UnaryExpression
with ExpectsInputTypes {

// Constructor invariants, checked in order: at least one field, no duplicate field names, and
// exactly one fallback path per field.
require(fieldNames.nonEmpty)
require(fieldNames.distinct.length == fieldNames.length)
require(fallbackPaths.length == fieldNames.length)

// The JSON input is the single child of this unary expression.
override def child: Expression = json

override def inputTypes: Seq[AbstractDataType] = {
  Seq(StringTypeWithCollation(supportsTrimCollation = true))
}

// One nullable string column per requested field, named by position: `_0`, `_1`, ...
override lazy val dataType: DataType = {
  val columns = fieldNames.indices.map { position =>
    StructField(s"_$position", StringType, nullable = true)
  }
  StructType(columns)
}

override def nullable: Boolean = true

// A null JSON child always yields a null result for this internal unary expression.
override def nullIntolerant: Boolean = true

override def prettyName: String = "multi_get_json_object"

// Tagged with the same tree pattern that GetJsonObject uses.
final override val nodePatterns: Seq[TreePattern] = Seq(GET_JSON_OBJECT)

// Created lazily and marked transient; the string fallback paths are converted to UTF8String
// when the evaluator is built.
@transient
private lazy val evaluator = {
  val utf8Paths = fallbackPaths.map(path => UTF8String.fromString(path))
  MultiGetJsonObjectEvaluator(fieldNames, utf8Paths)
}

/**
 * Interpreted evaluation: evaluates the JSON child against `input` and passes the result,
 * cast to UTF8String, to the evaluator.
 */
override def eval(input: InternalRow): Any = {
  // Resolve the evaluator before evaluating the child, matching receiver-first ordering.
  val delegate = evaluator
  val jsonValue = json.eval(input).asInstanceOf[UTF8String]
  delegate.evaluate(jsonValue)
}

/**
 * Generates Java code that evaluates the JSON child and delegates to the evaluator.
 *
 * The emitted code contains no JSON-processing logic of its own: it copies the child's null
 * flag, and only when the child is non-null calls `evaluate` on the referenced evaluator,
 * then marks the result null if that call returned null.
 *
 * @param ctx codegen context used to register the evaluator as a reference object
 * @param ev  the isNull/value variable names to populate
 * @return `ev` with its code replaced by the generated block
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Make the evaluator instance reachable from the generated code.
val refEvaluator = ctx.addReferenceObj("evaluator", evaluator)
val jsonEval = json.genCode(ctx)
// Java type of the struct result; also used for the cast of the evaluator's return value.
val resultType = CodeGenerator.javaType(dataType)
// The result starts at the type's default value and is overwritten only for a non-null child.
ev.copy(code = code"""
|${jsonEval.code}
|boolean ${ev.isNull} = ${jsonEval.isNull};
|$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = ($resultType) $refEvaluator.evaluate(${jsonEval.value});
| ${ev.isNull} = ${ev.value} == null;
|}
|""".stripMargin)
}

/** Rebuilds this node around a replacement JSON child; the field lists are carried over. */
override protected def withNewChildInternal(newChild: Expression): MultiGetJsonObject = {
  copy(json = newChild)
}
}

// scalastyle:off line.size.limit line.contains.tab
@ExpressionDescription(
usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.",
Expand Down
Loading