diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 0754a90459d54..db48b250a7dae 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -179,10 +179,18 @@ private[spark] class DAGScheduler( private[spark] val jobIdToQueryExecutionId = new ConcurrentHashMap[Int, java.lang.Long]() - // For INJECT_SHUFFLE_FETCH_FAILURES: per-shuffleId, the stage attempt whose partition-0 task - // we corrupted. Read to (a) avoid re-corrupting that partition on recompute, and (b) decide - // when to fire INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the - // task whose stageAttemptId is not the recorded one. + // The three injectShuffleFetchFailures* maps below are test-only state for + // INJECT_SHUFFLE_FETCH_FAILURES, keyed by the globally-unique (never-reused) shuffleId and + // only allocated under Utils.isTesting. They are intentionally never evicted per-stage: under + // AQE each Exchange is materialized as its own map-stage job, and removing the entry when that + // job finishes (e.g. in cleanupStateForJobAndIndependentStages) would drop the pending + // corruption before the consuming stage is ever submitted. Bounded by the number of shuffles + // in a test SparkContext, so leaving them for its lifetime is harmless. + + // Per-shuffleId, the stage attempt whose partition-0 task we corrupted. Read to (a) avoid + // re-corrupting that partition on recompute, and (b) decide when to fire + // INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE - the recompute is the task whose + // stageAttemptId is not the recorded one. private val injectShuffleFetchFailuresCorruptedAttempt: ConcurrentHashMap[Int, Int] = if (Utils.isTesting) new ConcurrentHashMap[Int, Int]() else null @@ -967,11 +975,6 @@ private[spark] class DAGScheduler( } for ((k, v) <- shuffleIdToMapStage.find(_._2 == stage)) { shuffleIdToMapStage.remove(k) - if (Utils.isTesting) { - injectShuffleFetchFailuresCorruptedAttempt.remove(k) - injectShuffleFetchFailuresPendingDelayedCorruption.remove(k) - injectShuffleFetchFailuresDownstreamSuccessCount.remove(k) - } } if (waitingStages.contains(stage)) { logDebug("Removing stage %d from waiting set.".format(stageId)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 126b84b507caf..8cb0727314f7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -2703,12 +2703,13 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase } test("metric values are stable across stage retries") { - // The join in the MERGE plan introduces a shuffle (with broadcast disabled), and the - // DAGScheduler corrupts the first attempt of every upstream shuffle map stage. Note: - // the current fetch-failure injection does not retry the MergeRowsExec/writer stage, - // so this test passes equally well with plain SQLMetric — it only exercises the - // SLAM-aware read path. Follow-up #55738 will add infra to actually retry the writer - // stage and exercise the SLAM behavior end-to-end for MERGE. + // INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful + // attempt of every shuffle map stage, so a downstream stage FetchFails and the producer + // re-runs. For the metadata variants of MERGE - where the writer's + // `RequiresDistributionAndOrdering` forces a re-shuffle between MergeRowsExec and the + // writer - MergeRowsExec sits in a non-leaf shuffle map stage and therefore re-runs with + // the same metric instances, double-counting the per-row increments. SQLLastAttemptMetric + // reports only the last attempt, so `MergeSummary` is still correct. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", @@ -2720,9 +2721,9 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase val sourceDF = Seq(1, 2, 10).toDF("pk") sourceDF.createOrReplaceTempView("source") - withSparkContextConf( + val mergeExec = withSparkContextConf( config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") { - sql( + findMergeExec { s"""MERGE INTO $tableNameAsString t |USING source s |ON t.pk = s.pk @@ -2730,7 +2731,8 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase | UPDATE SET salary = salary + 100 |WHEN NOT MATCHED THEN | INSERT (pk, salary, dep) VALUES (s.pk, 999, 'unknown') - |""".stripMargin) + |""".stripMargin + } } val mergeSummary = getMergeSummary() @@ -2743,6 +2745,19 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase assert(mergeSummary.numTargetRowsNotMatchedBySourceUpdated === 0L) assert(mergeSummary.numTargetRowsNotMatchedBySourceDeleted === 0L) + // For metadata variants, MergeRowsExec lives in a non-leaf shuffle map stage that the + // fetch-failure injection forces to re-run, so the raw per-MergeRowsExec accumulator + // (`metric.value`) overcounts. SLAM-aware `MergeSummary` (asserted above) is correct. + if (!noMetadata) { + val rawUpdated = mergeExec.metrics("numTargetRowsUpdated").value + assert(rawUpdated > 2L, + s"Expected MergeRowsExec.numTargetRowsUpdated to overcount under fetch-failure " + + s"injection (got $rawUpdated)") + val rawMatchedUpdated = mergeExec.metrics("numTargetRowsMatchedUpdated").value + assert(rawMatchedUpdated > 2L, + s"Expected numTargetRowsMatchedUpdated to overcount (got $rawMatchedUpdated)") + } + checkAnswer( sql(s"SELECT pk, salary FROM $tableNameAsString ORDER BY pk"), Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala index 0c465969e347c..199b9ecbe0a07 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/RowLevelOperationSuiteBase.scala @@ -89,6 +89,12 @@ abstract class RowLevelOperationSuiteBase Collections.emptyMap[String, String] } + /** True for the *NoMetadata* test variants - the writer doesn't request any required + * distribution / ordering and so MergeRowsExec / writer can run in the same stage as the + * preceding join. */ + protected def noMetadata: Boolean = + extraTableProps.getOrDefault("no-metadata", "false") == "true" + protected def catalog: InMemoryRowLevelOperationTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("cat") catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala index 6e9afe7abc97e..23d73cfa57cee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateTableSuiteBase.scala @@ -342,12 +342,12 @@ abstract class UpdateTableSuiteBase extends RowLevelOperationSuiteBase { } test("metric values are stable across stage retries") { - // Force a shuffle in the UPDATE plan via an IN-subquery (with broadcast disabled), then - // have the DAGScheduler corrupt the first attempt of every upstream shuffle map stage. - // Note: the current fetch-failure injection does not retry the writer stage, so this - // test passes equally well with plain SQLMetric — it only exercises the SLAM-aware - // read path. Follow-up #55738 will add infra to actually retry the writer stage and - // exercise the SLAM behavior end-to-end for UPDATE. + // INJECT_SHUFFLE_FETCH_FAILURES corrupts the partition-0 task of the first successful + // attempt of every shuffle map stage, so a downstream stage FetchFails and the producer + // re-runs. UPDATE writer-side metrics live on the result stage (`metric.add(N)` at + // end-of-task in WritingSparkTask), and ResultStage.findMissingPartitions only re-runs + // partitions that haven't successfully completed, so the writer accumulator single-counts; + // this test is regression coverage that retries don't break the SLAM-aware `UpdateSummary`. withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { withTempView("source") { createAndInitTable("pk INT NOT NULL, salary INT, dep STRING", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala index 6fc784f33815f..93b14f21c3308 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/MetricsFailureInjectionSuite.scala @@ -420,6 +420,58 @@ class MetricsFailureInjectionSuite } } + test("Three stage metrics block failure injection with AQE") { + // Same as the previous test but with AQE enabled. Under AQE each Exchange is materialized + // as its own map-stage job, which exercises a different DAGScheduler path than the + // AQE-disabled variant: the injection's deferred corruption must survive across those + // per-shuffle jobs for the downstream FetchFailed (and thus the producer re-run) to fire. + val stage1Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 1 counter") + val stage2Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 2 counter") + val stage3Metric = SQLMetrics.createMetric(spark.sparkContext, "stage 3 counter") + val stage1SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 1 SLAM") + val stage2SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 2 SLAM") + val stage3SLAMetric = + SQLLastAttemptMetrics.createMetric(spark.sparkContext, "stage 3 SLAM") + + withTable("primary_table", "secondary_table") { + setUpTestTable("primary_table") + setUpTestTable("secondary_table") + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + withSparkContextConf( + config.Tests.INJECT_SHUFFLE_FETCH_FAILURES.key -> "true") { + val stage1MetricsExpr = incrementMetrics(Seq(stage1Metric, stage1SLAMetric)) + val stage1 = spark.read.table("primary_table") + .filter(Column(stage1MetricsExpr)) + val stage2MetricsExpr = incrementMetrics(Seq(stage2Metric, stage2SLAMetric)) + val stage2 = stage1.join( + spark.read.table("secondary_table"), + usingColumn = "id", + joinType = "fullOuter") + .filter(Column(stage2MetricsExpr)) + val stage3MetricsExpr = incrementMetrics(Seq(stage3Metric, stage3SLAMetric)) + val stage3 = stage2 + .groupBy("primary_table.low_cardinality_col") + .count() + .filter(Column(stage3MetricsExpr)) + val finalDf = stage3.as[(Int, Long)] + val result = finalDf.collect() + assert(result.toMap === (0 until 5).map(v => (v, 300 / 5)).toMap) + + // The non-leaf stage 2 gets its first successful attempt corrupted and re-runs, so its + // raw counter overcounts. SLAM reports only the last successful attempt per RDD. + assert(stage1Metric.value > 300, s"stage1Metric=${stage1Metric.value}") + assert(stage2Metric.value > 300, s"stage2Metric=${stage2Metric.value}") + + assert(stage1SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage2SLAMetric.lastAttemptValueForHighestRDDId() === Some(300)) + assert(stage3SLAMetric.lastAttemptValueForHighestRDDId() === Some(5)) + } + } + } + } + test("Three stage metrics force-checksum-mismatch on recompute") { // INJECT_SHUFFLE_FORCE_CHECKSUM_MISMATCH_ON_RECOMPUTE additionally flags the recompute of the // partition-0 task as a checksum mismatch. The DAGScheduler then runs