Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,18 @@ private[spark] class DAGScheduler(

private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]()

// For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task
// we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide
// when to fire INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the
// task whose stageAttemptId is not the recorded one.
// The three injectShuffleFetchFailures* maps below are test-only state for
// INJECT_SHUFFLE_FETCH_FAILURES, keyed by the globally-unique (never-reused) shuffleId and
// only allocated under Utils.isTesting. They are intentionally never evicted per-stage: under
// AQE each Exchange is materialized as its own map-stage job, and removing the entry when that
// job finishes (e.g. in cleanupStateForJobAndIndependentStages) would drop the pending
// corruption before the consuming stage is ever submitted. Bounded by the number of shuffles
// in a test SparkContext, so leaving them for its lifetime is harmless.

// Per-shuffleId, the stage attempt whose partition-0 task we corrupted. Read to (a) avoid
// re-corrupting that partition on recompute, and (b) decide when to fire
// INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the task whose
// stageAttemptId is not the recorded one.
private val injectShuffleFetchFailuresCorruptedAttempt: ConcurrentHashMap[Int, Int] = {
  // Test-only state: the map is created only when Utils.isTesting holds and is left null
  // otherwise, so it must not be dereferenced outside of a Utils.isTesting check.
  if (!Utils.isTesting) null else new ConcurrentHashMap[Int, Int]()
}

Expand Down Expand Up @@ -967,11 +975,6 @@ private[spark] class DAGScheduler(
}
for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) {
shuffleIdToMapStage.remove(k)
if (Utils.isTesting) {
injectShuffleFetchFailuresCorruptedAttempt.remove(k)
injectShuffleFetchFailuresPendingDelayedCorruption.remove(k)
injectShuffleFetchFailuresDownstreamSuccessCount.remove(k)
}
}
if (waitingStages.contains(stage)) {
logDebug("Removing stage %d from waiting set.".format(stageId))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2703,12 +2703,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
}

test("metric values are stable across stage retries") {
// The join in the MERGE plan introduces a shuffle (with broadcast disabled), and the
// DAGScheduler corrupts the first attempt of every upstream shuffle map stage. Note:
// the current fetch-failure injection does not retry the MergeRowsExec/writer stage,
// so this test passes equally well with plain SQLMetric — it only exercises the
// SLAM-aware read path. Follow-up #55738 will add infra to actually retry the writer
// stage and exercise the SLAM behavior end-to-end for MERGE.
// INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful
// attempt of every shuffle map stage, so a downstream stage FetchFails and the producer
// re-runs. For the metadata variants of MERGE - where the writer's
// `RequiresDistributionAndOrdering` forces a re-shuffle between MergeRowsExec and the
// writer - MergeRowsExec sits in a non-leaf shuffle map stage and therefore re-runs with
// the same metric instances, double-counting the per-row increments. SQLLastAttemptMetric
// reports only the last attempt, so `MergeSummary` is still correct.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
Expand All @@ -2720,17 +2721,18 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
val sourceDF = Seq(1, 2, 10).toDF("pk")
sourceDF.createOrReplaceTempView("source")

withSparkContextConf(
val mergeExec = withSparkContextConf(
config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
sql(
findMergeExec {
s"""MERGE INTO $tableNameAsString t
|USING source s
|ON t.pk = s.pk
|WHEN MATCHED THEN
| UPDATE SET salary = salary + 100
|WHEN NOT MATCHED THEN
| INSERT (pk, salary, dep) VALUES (s.pk, 999, 'unknown')
|""".stripMargin)
|""".stripMargin
}
}

val mergeSummary = getMergeSummary()
Expand All @@ -2743,6 +2745,19 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase
assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L)
assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L)

// For metadata variants, MergeRowsExec lives in a non-leaf shuffle map stage that the
// fetch-failure injection forces to re-run, so the raw per-MergeRowsExec accumulator
// (`metric.value`) overcounts. SLAM-aware `MergeSummary` (asserted above) is correct.
if (!noMetadata) {
val rawUpdated = mergeExec.metrics("numTargetRowsUpdated").value
assert(rawUpdated > 2L,
s"Expected MergeRowsExec.numTargetRowsUpdated to overcount under fetch-failure " +
s"injection (got $rawUpdated)")
val rawMatchedUpdated = mergeExec.metrics("numTargetRowsMatchedUpdated").value
assert(rawMatchedUpdated > 2L,
s"Expected numTargetRowsMatchedUpdated to overcount (got $rawMatchedUpdated)")
}

checkAnswer(
sql(s"SELECT pk, salary FROM $tableNameAsString ORDER BY pk"),
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ abstract class RowLevelOperationSuiteBase
Collections.emptyMap[String, String]
}

/**
 * Whether this suite is one of the *NoMetadata* test variants, i.e. whether `extraTableProps`
 * maps "no-metadata" to "true" (a missing key counts as "false"). In those variants the writer
 * doesn't request any required distribution / ordering, and so MergeRowsExec / writer can run
 * in the same stage as the preceding join.
 */
protected def noMetadata: Boolean = {
  val flag = extraTableProps.getOrDefault("no-metadata", "false")
  flag == "true"
}

protected def catalog: InMemoryRowLevelOperationTableCatalog = {
val catalog = spark.sessionState.catalogManager.catalog("cat")
catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase {
}

test("metric values are stable across stage retries") {
// Force a shuffle in the UPDATE plan via an IN-subquery (with broadcast disabled), then
// have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage.
// Note: the current fetch-failure injection does not retry the writer stage, so this
// test passes equally well with plain SQLMetric — it only exercises the SLAM-aware
// read path. Follow-up #55738 will add infra to actually retry the writer stage and
// exercise the SLAM behavior end-to-end for UPDATE.
// INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful
// attempt of every shuffle map stage, so a downstream stage FetchFails and the producer
// re-runs. UPDATE writer-side metrics live on the result stage (`metric.add(N)` at
// end-of-task in WritingSparkTask), and ResultStage.findMissingPartitions only re-runs
// partitions that haven't successfully completed, so the writer accumulator single-counts;
// this test is regression coverage that retries don't break the SLAM-aware `UpdateSummary`.
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
withTempView("source") {
createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,58 @@ class MetricsFailureInjectionSuite
}
}

test("Three stage metrics block failure injection with AQE") {
  // AQE-enabled twin of the previous test. With AQE every Exchange is materialized as its own
  // map-stage job, so a different DAGScheduler path is taken than in the AQE-disabled variant:
  // the deferred corruption set up by the injection has to survive across those per-shuffle
  // jobs, otherwise the downstream FetchFailed (and hence the producer re-run) never fires.
  val sc = spark.sparkContext

  // Plain accumulators: these count every attempt, including the ones that get re-run.
  val rawStage1 = SQLMetrics.createMetric(sc, "stage 1 counter")
  val rawStage2 = SQLMetrics.createMetric(sc, "stage 2 counter")
  val rawStage3 = SQLMetrics.createMetric(sc, "stage 3 counter")

  // Last-attempt-aware counterparts of the three counters above.
  val slamStage1 = SQLLastAttemptMetrics.createMetric(sc, "stage 1 SLAM")
  val slamStage2 = SQLLastAttemptMetrics.createMetric(sc, "stage 2 SLAM")
  val slamStage3 = SQLLastAttemptMetrics.createMetric(sc, "stage 3 SLAM")

  withTable("primary_table", "secondary_table") {
    setUpTestTable("primary_table")
    setUpTestTable("secondary_table")

    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
      withSparkContextConf(config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") {
        // Stage 1: scan of the primary table, counted per row.
        val countStage1 = incrementMetrics(Seq(rawStage1, slamStage1))
        val scanned = spark.read.table("primary_table").filter(Column(countStage1))

        // Stage 2: full outer join against the secondary table, counted per row.
        val countStage2 = incrementMetrics(Seq(rawStage2, slamStage2))
        val joined = scanned
          .join(spark.read.table("secondary_table"), usingColumn = "id", joinType = "fullOuter")
          .filter(Column(countStage2))

        // Stage 3: aggregation over the low-cardinality column, counted per group.
        val countStage3 = incrementMetrics(Seq(rawStage3, slamStage3))
        val aggregated = joined
          .groupBy("primary_table.low_cardinality_col")
          .count()
          .filter(Column(countStage3))

        val collected = aggregated.as[(Int, Long)].collect()
        val expectedCounts = (0 until 5).map(key => (key, 300 / 5)).toMap
        assert(collected.toMap === expectedCounts)

        // The first successful attempt of the non-leaf stage 2 is corrupted and re-run, which
        // makes the raw counters overcount; SLAM only reports the last successful attempt per
        // RDD.
        assert(rawStage1.value > 300, s"stage1Metric=${rawStage1.value}")
        assert(rawStage2.value > 300, s"stage2Metric=${rawStage2.value}")

        assert(slamStage1.lastAttemptValueForHighestRDDId() === Some(300))
        assert(slamStage2.lastAttemptValueForHighestRDDId() === Some(300))
        assert(slamStage3.lastAttemptValueForHighestRDDId() === Some(5))
      }
    }
  }
}

test("Three stage metrics force-checksum-mismatch on recompute") {
// INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE additionally flags the recompute of the
// partition-0 task as a checksum mismatch. The DAGScheduler then runs
Expand Down