Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,21 @@ abstract class JsonDataSource extends Serializable with Logging {
* Streams a tar archive (`.tar`/`.tar.gz`/`.tgz`) entry by entry through the JSON parser without
* unpacking it to disk. The whole archive is a single split (see `JsonFileFormat.isSplitable`);
* each entry's bytes are parsed exactly like a standalone JSON file via [[readStream]], so this
* is mode-agnostic (line-delimited and multi-line both flow through `readStream`) and a single
* `parser` serves every entry -- unlike CSV there is no per-entry header to rebuild. Kept apart
* from [[readFile]] because only the V1 `JsonFileFormat` read path supports archives; the V2 data
* source calls [[readFile]] directly and is intentionally left untouched.
* is mode-agnostic (line-delimited and multi-line both flow through `readStream`). Each entry is
* parsed with its own parser -- matching the per-file parser of a non-archive read -- and unlike
* CSV there is no per-entry header to rebuild. Kept apart from [[readFile]] because only the V1
* `JsonFileFormat` read path supports archives; the V2 data source calls [[readFile]] directly
* and is intentionally left untouched.
*
* @param parser builds a fresh JSON parser for each entry.
*/
def readArchive(
conf: Configuration,
file: PartitionedFile,
parser: JacksonParser,
parser: () => JacksonParser,
schema: StructType): Iterator[InternalRow] =
// Each archive entry's stream is handed to readStream; the callback's first argument is unused.
ArchiveReader(file.toPath).readEntries(conf) { (_, in) =>
readStream(in, parser, schema)
// `parser()` invokes the factory, so this entry is parsed with a parser of its own.
readStream(in, parser(), schema)
}

final def inferSchema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,19 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
}

(file: PartitionedFile) => {
val parser = new JacksonParser(
def parser() = new JacksonParser(
actualSchema,
parsedOptions,
allowArrayAsStructs = true,
filters)
if (parsedOptions.archiveFormatEnabled && ArchiveReader.isArchivePath(file.toPath)) {
JsonDataSource(parsedOptions).readArchive(
broadcastedHadoopConf.value.value, file, parser, requiredSchema)
broadcastedHadoopConf.value.value, file, () => parser(), requiredSchema)
} else {
JsonDataSource(parsedOptions).readFile(
broadcastedHadoopConf.value.value,
file,
parser,
parser(),
requiredSchema)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.datasources.xml

import java.io.{FileNotFoundException, IOException}
import java.io.{ByteArrayInputStream, FileNotFoundException, InputStream, IOException}
import java.nio.charset.{Charset, StandardCharsets}

import scala.util.control.NonFatal
Expand Down Expand Up @@ -59,6 +59,38 @@ abstract class XmlDataSource extends Serializable with Logging {
parser: StaxXmlParser,
schema: StructType): Iterator[InternalRow]

/**
 * Parse a single already-open [[InputStream]] -- one decompressed archive entry -- into 0 or more
 * [[InternalRow]] instances, the same way this mode reads a standalone file: line by line for
 * [[TextInputXmlDataSource]], as one whole document for [[MultiLineXmlDataSource]]. Used only by
 * [[readArchive]]; the stream is not closed here.
 *
 * @param in the open stream of one archive entry; implementations read it but do not close it.
 * @param parser the XML parser to use for this entry ([[readArchive]] builds one per entry).
 * @param schema the schema handed through from [[readArchive]] to the mode-specific parsing.
 * @return the rows parsed from this entry, possibly none.
 */
protected def readStream(
in: InputStream,
parser: StaxXmlParser,
schema: StructType): Iterator[InternalRow]

/**
* Streams a tar archive (`.tar`/`.tar.gz`/`.tgz`) entry by entry through the XML parser without
* unpacking it to disk. The whole archive is a single split (see `XmlFileFormat.isSplitable`);
* each entry's bytes are parsed exactly like a standalone XML file via [[readStream]], which each
* mode overrides (single-line splits into lines, multi-line parses the whole entry). Each entry
* is parsed with its own parser -- matching the per-file parser of a non-archive read.
*
* Kept separate from [[readFile]] (rather than dispatched inside it) because only the V1

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: the JSON archive support you're porting keeps readArchive concrete in the base — it calls an abstract readStream(in, parser(), schema) that each mode overrides, so the ArchiveReader.readEntries wiring lives in one place. Here readArchive is left abstract and both modes (TextInputXmlDataSource :239, MultiLineXmlDataSource :332) re-implement the ArchiveReader(file.toPath).readEntries(conf) { ... } wrapper, differing only in the per-entry body.

Mirroring JSON — a concrete base readArchive over an abstract readStream (single-line readStream = lines + FailureSafeParser; multi-line readStream = the legacy/optimized branch) — would centralize that wiring and match the peer. Behavior is identical either way, so this is purely a maintainability/consistency call.

* `XmlFileFormat` read path supports archives; XML has no DSv2 reader.
*
* @param parser builds a fresh XML parser for each entry.
*/
def readArchive(
    conf: Configuration,
    file: PartitionedFile,
    parser: () => StaxXmlParser,
    schema: StructType): Iterator[InternalRow] = {
  val archive = ArchiveReader(file.toPath)
  // The factory is invoked once per entry, so no parser instance is shared between entries. The
  // callback's first argument is not needed here.
  archive.readEntries(conf) { (_, entryStream) =>
    readStream(entryStream, parser(), schema)
  }
}

/**
* Infers the schema from `inputPaths` files.
*/
Expand All @@ -69,7 +101,15 @@ abstract class XmlDataSource extends Serializable with Logging {
parsedOptions.singleVariantColumn match {
case Some(columnName) => Some(StructType(Array(StructField(columnName, VariantType))))
case None =>
if (inputPaths.nonEmpty) {
// When any input is a tar archive, infer over all inputs in a single pass -- archive
// entries are streamed (never unpacked to disk) and tokenized as XML records alongside any
// loose files -- so the result matches a directory read of the same files. XML has no DSv2
// reader, so this archive scan is always V1.
val hasArchive = parsedOptions.archiveFormatEnabled &&
inputPaths.exists(f => ArchiveReader.isArchivePath(f.getPath))
if (hasArchive) {
Some(inferWithArchives(sparkSession, inputPaths, parsedOptions))
} else if (inputPaths.nonEmpty) {
Some(infer(sparkSession, inputPaths, parsedOptions))
} else {
None
Expand All @@ -81,6 +121,90 @@ abstract class XmlDataSource extends Serializable with Logging {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: XmlOptions): StructType

/**
 * Schema inference for the case where at least one input is a tar archive
 * (`.tar`/`.tar.gz`/`.tgz`). Archive entries are streamed through `ArchiveReader` (never unpacked
 * to disk) and loose files are opened directly; both are tokenized into record strings and handed
 * to one [[XmlInferSchema]] pass, so the result is what a directory of the same files would
 * infer. Records are cut the way this mode's scan cuts them: `rowTag`-delimited records over the
 * whole stream in multi-line mode, one record per line in single-line mode (mirroring
 * [[readFile]] and JSON's `inferWithArchives`).
 */
private def inferWithArchives(
    sparkSession: SparkSession,
    inputPaths: Seq[FileStatus],
    parsedOptions: XmlOptions): StructType = {
  val inputs = createBaseRdd(sparkSession, inputPaths, parsedOptions)
  val skipCorrupt = parsedOptions.ignoreCorruptFiles
  val skipMissing = parsedOptions.ignoreMissingFiles

  // How one stream (an archive entry or a loose file) becomes record strings. This follows the
  // mode's scan so the inferred schema matches a directory read: multi-line splits the whole
  // stream into rowTag-delimited records, single-line treats every line as a record (mirroring
  // TextInputXmlDataSource.readFile).
  val toRecords: InputStream => Iterator[String] =
    if (parsedOptions.multiLine) {
      (entry: InputStream) => StaxXmlParser.tokenizeStream(entry, parsedOptions)
    } else {
      val charsetName = parsedOptions.charset
      (entry: InputStream) =>
        ArchiveReader.lineIterator(entry, None).map { text =>
          new String(text.getBytes, 0, text.getLength, charsetName)
        }
    }

  // Tokenizes every input: a loose file directly, an archive entry by entry (streamed, so only
  // one entry's bytes are in flight at a time). A whole input is skipped when it is
  // corrupt/missing and the matching ignore flag is set.
  val tokens: RDD[String] = inputs.flatMap { input =>
    val location = new Path(input.getPath())
    try {
      if (!ArchiveReader.isArchivePath(location)) {
        toRecords(
          CodecStreams.createInputStreamWithCloseResource(input.getConfiguration, location))
      } else {
        ArchiveReader(location).readEntries(input.getConfiguration) { (_, entry) =>
          toRecords(entry)
        }
      }
    } catch {
      case missing: FileNotFoundException if skipMissing =>
        logWarning("Skipped missing file", missing)
        Iterator.empty[String]
      case NonFatal(failure) =>
        Utils.getRootCause(failure) match {
          // Access-control and missing-block failures are always rethrown, never skipped.
          case fatalCause @ (_: AccessControlException | _: BlockMissingException) =>
            throw fatalCause
          case _: RuntimeException | _: IOException if skipCorrupt =>
            logWarning("Skipped the rest of the content in the corrupted file", failure)
            Iterator.empty[String]
          case unhandled => throw unhandled
        }
    }
  }

  SQLExecution.withSQLConfPropagated(sparkSession) {
    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
    new XmlInferSchema(parsedOptions, caseSensitive).infer(tokens)
  }
}

protected def createBaseRdd(
    sparkSession: SparkSession,
    inputPaths: Seq[FileStatus],
    options: XmlOptions): RDD[PortableDataStream] = {
  val inputLocations = inputPaths.map(_.getPath)
  val rddLabel = inputLocations.mkString(",")

  // The Hadoop job carries the data source options and the input paths to the RDD's conf.
  val hadoopJob = Job.getInstance(
    sparkSession.sessionState.newHadoopConfWithOptions(options.parameters))
  FileInputFormat.setInputPaths(hadoopJob, inputLocations: _*)
  val hadoopConf = hadoopJob.getConfiguration

  val binaryFiles = new BinaryFileRDD(
    sparkSession.sparkContext,
    classOf[StreamInputFormat],
    classOf[String],
    classOf[PortableDataStream],
    hadoopConf,
    sparkSession.sparkContext.defaultMinPartitions)

  // Drop the path keys: callers only get the `PortableDataStream` values.
  binaryFiles.setName(s"XMLFile: $rddLabel").values
}
}

object XmlDataSource extends Logging {
Expand Down Expand Up @@ -120,6 +244,27 @@ object TextInputXmlDataSource extends XmlDataSource {
lines.flatMap(safeParser.parse)
}

/**
 * Single-line counterpart of [[readFile]] for one archive entry. The entry is cut into lines and
 * every line goes through a [[FailureSafeParser]], giving an archive entry the same per-record
 * corrupt-record handling as a non-archive single-line read. (Parsing the stream as a whole, the
 * way the multi-line override does, would skip that handling for single-line input.)
 */
override protected def readStream(
    in: InputStream,
    parser: StaxXmlParser,
    schema: StructType): Iterator[InternalRow] = {
  val xmlOptions = parser.options
  // Decode each line with the configured charset, using only the valid prefix of its buffer.
  val records = ArchiveReader
    .lineIterator(in, None)
    .map(raw => new String(raw.getBytes, 0, raw.getLength, xmlOptions.charset))
  val recordParser = new FailureSafeParser[String](
    record => parser.parse(record),
    xmlOptions.parseMode,
    schema,
    xmlOptions.columnNameOfCorruptRecord)
  records.flatMap(recordParser.parse)
}

override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
Expand Down Expand Up @@ -185,6 +330,27 @@ object MultiLineXmlDataSource extends XmlDataSource {
}
}

/**
 * Multi-line counterpart of [[readFile]] for one archive entry: the entry is parsed as a single
 * XML document. The optimized parser re-reads its input (to echo the corrupt-record text on a
 * parse failure), which a single-use entry stream cannot offer, so the entry is first buffered
 * and the parser is given a supplier that re-opens the buffered bytes. The legacy parser consumes
 * the entry stream directly. Holding one whole entry in memory is a deliberate trade-off: a very
 * large XML document packed in an archive is materialized in full, whereas a non-archive read
 * streams from and re-opens the file. Entries are still processed one at a time, so the archive's
 * total size does not have to fit in memory.
 */
override protected def readStream(
    in: InputStream,
    parser: StaxXmlParser,
    schema: StructType): Iterator[InternalRow] = {
  if (!parser.options.useLegacyXMLParser) {
    // Buffer once; every call of the supplier yields a fresh stream over the same bytes.
    val entryBytes = in.readAllBytes()
    parser.parseStreamOptimized(() => new ByteArrayInputStream(entryBytes), schema)
  } else {
    parser.parseStream(in, schema)
  }
}

override def infer(
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
Expand Down Expand Up @@ -250,27 +416,4 @@ object MultiLineXmlDataSource extends XmlDataSource {
schema
}
}

private def createBaseRdd(
    sparkSession: SparkSession,
    inputPaths: Seq[FileStatus],
    options: XmlOptions): RDD[PortableDataStream] = {
  val filePaths = inputPaths.map(_.getPath)
  val joinedPaths = filePaths.mkString(",")

  // Register the input paths on a Hadoop job built from the session conf plus the options.
  val inputJob = Job.getInstance(
    sparkSession.sessionState.newHadoopConfWithOptions(options.parameters))
  FileInputFormat.setInputPaths(inputJob, filePaths: _*)
  val jobConf = inputJob.getConfiguration

  val streamsByPath = new BinaryFileRDD(
    sparkSession.sparkContext,
    classOf[StreamInputFormat],
    classOf[String],
    classOf[PortableDataStream],
    jobConf,
    sparkSession.sparkContext.defaultMinPartitions)

  // Keep only the `PortableDataStream` values; the path keys are discarded.
  streamsByPath.setName(s"XMLFile: $joinedPaths").values
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
path: Path): Boolean = {
val xmlOptions = getXmlOptions(sparkSession, options)
if (xmlOptions.archiveFormatEnabled && ArchiveReader.isArchivePath(path)) {
// A tar archive is read as one sequential stream (entry by entry), so it is never split.
return false
}
XmlDataSource(xmlOptions).isSplitable && super.isSplitable(sparkSession, options, path)
}

Expand Down Expand Up @@ -116,14 +120,25 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
}

(file: PartitionedFile) => {
val parser = new StaxXmlParser(
def parser() = new StaxXmlParser(
actualRequiredSchema,
xmlOptions)
XmlDataSource(xmlOptions).readFile(
broadcastedHadoopConf.value.value,
file,
parser,
requiredSchema)
// A tar archive (always a single split, see `isSplitable`) is streamed entry by entry when
// archive reads are enabled; otherwise the file is parsed directly. XML has no DSv2 reader,
// so this dispatch lives here rather than inside the shared `readFile`.
if (xmlOptions.archiveFormatEnabled && ArchiveReader.isArchivePath(file.toPath)) {
XmlDataSource(xmlOptions).readArchive(
broadcastedHadoopConf.value.value,
file,
() => parser(),
requiredSchema)
} else {
XmlDataSource(xmlOptions).readFile(
broadcastedHadoopConf.value.value,
file,
parser(),
requiredSchema)
}
}
}

Expand Down
Loading