From 028cf1ba7b89b1449a7e981ddb7d27986a6ba7e5 Mon Sep 17 00:00:00 2001 From: James Xu Date: Sun, 17 May 2026 08:19:06 +0800 Subject: [PATCH 1/2] [SPARK-56898][SQL] Rewrite COUNT(DISTINCT IF(...)) to COUNT(DISTINCT) FILTER for Expand reduction Adds RewriteCountDistinctConditional optimizer rule that canonicalizes: COUNT(DISTINCT IF(cond, base, NULL)) COUNT(DISTINCT CASE WHEN cond THEN base END) into: COUNT(DISTINCT base) FILTER (WHERE cond) This reduces the number of distinct groups seen by RewriteDistinctAggregates from N (one per unique conditional expression) down to 1 (all share the same base column), collapsing the Expand factor from Nx to 1x. Gated by spark.sql.optimizer.rewriteCountDistinctConditional.enabled (default: false). --- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../RewriteCountDistinctConditional.scala | 90 ++++++++ .../apache/spark/sql/internal/SQLConf.scala | 15 ++ ...RewriteCountDistinctConditionalSuite.scala | 216 ++++++++++++++++++ ...teCountDistinctConditionalQuerySuite.scala | 208 +++++++++++++++++ 5 files changed, 530 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ddfe80443d561..66d6045f54e43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -252,6 +252,7 @@ abstract class Optimizer(catalogManager: CatalogManager) // This batch must run after "Decimal Optimizations", as that one may change the // aggregate distinct column Batch("Distinct Aggregate Rewrite", Once, + RewriteCountDistinctConditional, RewriteDistinctAggregates, OptimizeExpand), Batch("Object Expressions Optimization", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala new file mode 100644 index 0000000000000..18861ec0fb419 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE +import org.apache.spark.sql.internal.SQLConf + +/** + * Rewrites COUNT(DISTINCT IF(cond, base, NULL)) and + * COUNT(DISTINCT CASE WHEN cond THEN base END) into + * COUNT(DISTINCT base) FILTER (WHERE cond). + * + * This canonicalization reduces the number of distinct groups seen by + * RewriteDistinctAggregates from N (one per unique conditional expression) down to 1 + * (all share the same base column), collapsing the Expand factor from Nx to 1x. + * + * Correctness: COUNT DISTINCT ignores NULLs, so nulling out rows where !cond + * is semantically identical to filtering those rows out entirely. + */ +object RewriteCountDistinctConditional extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) { + return plan + } + plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) { + case agg: Aggregate => agg.transformExpressionsUp { + case ae @ AggregateExpression( + count: Count, + _, + true, // isDistinct + None, // no existing FILTER + _) + if count.children.size == 1 => + extractCondAndBase(count.children.head) match { + case Some((cond, base)) => + ae.copy( + aggregateFunction = count.withNewChildren(Seq(base)).asInstanceOf[Count], + filter = Some(cond)) + case None => ae + } + } + } + } + + /** + * Matches: + * IF(cond, base, null) + * CASE WHEN cond THEN base END + * CASE WHEN cond THEN base ELSE NULL END + * + * The analyzer may wrap the null branch in a Cast for type alignment, so the + * null check is done after unwrapping any surrounding Casts. + * + * Returns None for anything else, including IF(cond, base, fallback) where + * fallback is not null -- those change semantics and must not be rewritten. + */ + private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] = + expr match { + case If(cond, base, e) if isNullExpr(e) => Some((cond, base)) + case CaseWhen(Seq((cond, base)), None) => Some((cond, base)) + case CaseWhen(Seq((cond, base)), Some(e)) if isNullExpr(e) => Some((cond, base)) + case _ => None + } + + private def isNullExpr(e: Expression): Boolean = e match { + case Literal(null, _) => true + case Cast(child, _, _, _) => isNullExpr(child) + case _ => false + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5ed831f20f394..2c16a5c4e876a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1366,6 +1366,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED = + buildConf("spark.sql.optimizer.rewriteCountDistinctConditional.enabled") + .doc("When true, rewrites COUNT(DISTINCT IF(cond, base, NULL)) and " + + "COUNT(DISTINCT CASE WHEN cond THEN base END) into " + + "COUNT(DISTINCT base) FILTER (WHERE cond). This reduces the Expand factor " + + "in RewriteDistinctAggregates from Nx to 1x when multiple conditional distinct " + + "counts share the same base column.") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.SESSION) + .booleanConf + .createWithDefault(false) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -8473,6 +8485,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def decorrelateInnerQueryEnabledForExistsIn: Boolean = !getConf(SQLConf.DECORRELATE_EXISTS_IN_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED) + def rewriteCountDistinctConditionalEnabled: Boolean = + getConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED) + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala new file mode 100644 index 0000000000000..88520a472e45b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.IntegerType + +class RewriteCountDistinctConditionalSuite extends PlanTest { + + val testRelation = LocalRelation( + Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string) + + private def countDistinctIf(cond: Expression, base: Expression): Expression = { + Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = true) + } + + private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = { + val caseWhen = org.apache.spark.sql.catalyst.expressions.CaseWhen( + Seq((cond, base)), + None) + Count(caseWhen).toAggregateExpression(isDistinct = true) + } + + private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = { + val caseWhen = org.apache.spark.sql.catalyst.expressions.CaseWhen( + Seq((cond, base)), + Some(Literal(null))) + Count(caseWhen).toAggregateExpression(isDistinct = true) + } + + test("disabled by default") { + val input = testRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + + test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") { + val input = testRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + + val expected = testRelation + .groupBy(Symbol("a"))( + countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + comparePlans(optimized, expected) + } + } + + test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to COUNT(DISTINCT col) FILTER") { + val input = testRelation + .groupBy(Symbol("a"))( + countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + + val expected = testRelation + .groupBy(Symbol("a"))( + countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + comparePlans(optimized, expected) + } + } + + test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") { + val input = testRelation + .groupBy(Symbol("a"))( + countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + + val expected = testRelation + .groupBy(Symbol("a"))( + countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + comparePlans(optimized, expected) + } + } + + test("multiple conditional distinct counts collapse to single distinct group") { + val input = testRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"), + countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + + val expected = testRelation + .groupBy(Symbol("a"))( + countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctWithFilter(Symbol("b") > 2, Symbol("c")).as("cnt2"), + countDistinctWithFilter(Symbol("b") > 3, Symbol("c")).as("cnt3")) + .analyze + + comparePlans(optimized, expected) + + // Verify RewriteDistinctAggregates sees only 1 distinct group + val rewritten = RewriteDistinctAggregates(optimized) + // Should be rewritten (not same as input) because there are now multiple + // distinct expressions with the same base column + assert(rewritten != optimized) + } + } + + test("do not rewrite IF with non-null else branch") { + val input = testRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(0, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + } + + test("do not rewrite non-distinct COUNT") { + val input = testRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = false) + .as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + } + + test("do not rewrite when FILTER already exists") { + val input = testRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "x")) + .as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + } + + test("do not rewrite multi-branch CASE WHEN") { + val caseWhen = new org.apache.spark.sql.catalyst.expressions.CaseWhen( + Seq( + (Symbol("b") > Literal(1), Symbol("c")), + (Symbol("b") > Literal(2), Symbol("a"))), + Some(Literal(null))) + val input = testRelation + .groupBy(Symbol("a"))( + Count(caseWhen).toAggregateExpression(isDistinct = true).as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + } + + test("do not rewrite SUM(DISTINCT IF(...))") { + val input = testRelation + .groupBy(Symbol("a"))( + org.apache.spark.sql.catalyst.expressions.aggregate.Sum( + If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("sum1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteCountDistinctConditional(input) + comparePlans(optimized, input) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala new file mode 100644 index 0000000000000..e83ce5b380c3a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +class RewriteCountDistinctConditionalQuerySuite extends QueryTest with SharedSparkSession { + + private def checkRewriteAndResult( + conditionalSql: String, + filterSql: String): Unit = { + withTempView("t") { + // Verify the rewrite produces the same result as the explicit FILTER form. + val withRewrite = withSQLConf( + SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + spark.sql(conditionalSql).collect() + } + val withoutRewrite = withSQLConf( + SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") { + spark.sql(conditionalSql).collect() + } + val explicitFilter = spark.sql(filterSql).collect() + + assert(withRewrite.sameElements(explicitFilter), + "Rewritten query should match explicit FILTER query") + assert(withoutRewrite.sameElements(explicitFilter), + "Non-rewritten query should also match explicit FILTER query") + } + } + + test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) correctness") { + withTempView("t") { + spark.range(7) + .selectExpr( + "cast(id % 3 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "case when id % 4 = 0 then null else cast(id * 100 as int) end as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key", + "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key") + } + } + + test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) correctness") { + withTempView("t") { + spark.range(7) + .selectExpr( + "cast(id % 3 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "case when id % 4 = 0 then null else cast(id * 100 as string) end as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + "SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 END) FROM t GROUP BY key", + "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key") + } + } + + test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END) correctness") { + withTempView("t") { + spark.range(6) + .selectExpr( + "cast(id % 2 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "case when id % 4 = 0 then null else cast(id * 1.0 as double) end as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + """SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 ELSE NULL END) + |FROM t GROUP BY key""".stripMargin, + "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key") + } + } + + test("rewrite with no GROUP BY") { + withTempView("t") { + spark.range(5) + .selectExpr( + "cast(id * 10 as int) as col1", + "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + "SELECT COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t", + "SELECT COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t") + } + } + + test("rewrite with all NULLs in conditional branch") { + withTempView("t") { + spark.range(3) + .selectExpr( + "cast(id % 2 + 1 as int) as key", + "cast(id * 5 as int) as col1", + "cast(id * 100 as int) as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key", + "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key") + } + } + + test("rewrite with duplicates in base column") { + withTempView("t") { + spark.range(6) + .selectExpr( + "cast(id % 2 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "case when id % 3 = 0 then 100 when id % 3 = 1 then 100 else 200 end as col2") + .createOrReplaceTempView("t") + + checkRewriteAndResult( + "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key", + "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key") + } + } + + test("multiple conditional distinct counts collapse and produce correct results") { + withTempView("t") { + spark.range(5) + .selectExpr( + "cast(id % 2 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2", + "case when id % 4 = 0 then null else cast(id * 10 as string) end as col3") + .createOrReplaceTempView("t") + + val conditionalSql = + """SELECT key, + | COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1, + | COUNT(DISTINCT IF(col1 > 5, col3, NULL)) as cnt2 + |FROM t GROUP BY key""".stripMargin + + val filterSql = + """SELECT key, + | COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) as cnt1, + | COUNT(DISTINCT col3) FILTER (WHERE col1 > 5) as cnt2 + |FROM t GROUP BY key""".stripMargin + + checkRewriteAndResult(conditionalSql, filterSql) + } + } + + test("rewrite does not affect COUNT(DISTINCT IF(cond, col, non_null))") { + withTempView("t") { + spark.range(3) + .selectExpr( + "cast(id % 2 + 1 as int) as key", + "cast(id * 10 as int) as col1", + "cast(id * 100 as int) as col2") + .createOrReplaceTempView("t") + + val sqlText = "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, 0)) FROM t GROUP BY key" + + val withRewrite = withSQLConf( + SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + spark.sql(sqlText).collect() + } + val withoutRewrite = withSQLConf( + SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") { + spark.sql(sqlText).collect() + } + + assert(withRewrite.sameElements(withoutRewrite), + "Non-null ELSE branch should not be rewritten") + } + } + + test("rewrite is present in optimized plan") { + withTempView("t") { + spark.range(2) + .selectExpr( + "cast(id + 1 as int) as key", + "cast(id * 10 as int) as col1", + "cast(id * 100 as int) as col2") + .createOrReplaceTempView("t") + + val planStr = withSQLConf( + SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val df = spark.sql( + "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key") + df.queryExecution.optimizedPlan.toString + } + + assert(planStr.contains("FILTER"), + s"Optimized plan should contain FILTER clause. Plan:\n$planStr") + } + } +} From 409c5e870a8d6a034c7156c811c67124763e6ed6 Mon Sep 17 00:00:00 2001 From: James Xu Date: Thu, 18 Jun 2026 21:49:10 +0800 Subject: [PATCH 2/2] [SPARK-56898][SQL] Fold RewriteCountDistinctConditional into RewriteDistinctAggregates The standalone RewriteCountDistinctConditional rule had two problems: 1. Its ordering guarantee rested purely on batch position in Optimizer.scala 2. It blindly rewrote a lone COUNT(DISTINCT IF(...)), forcing the Expand path and regressing performance vs the planAggregateWithOneDistinct fast path Move the canonicalization logic into RewriteDistinctAggregates.rewrite() via a new normalizeCountDistinctConditional helper. Since rewrite() is only called when mayNeedtoRewrite returns true, a lone conditional distinct count never enters the path and the regression is eliminated. The SQLConf flag and semantics are unchanged. Delete RewriteCountDistinctConditional.scala and migrate all tests into RewriteDistinctAggregatesSuite; rename the integration suite to RewriteDistinctAggregatesConditionalQuerySuite. --- .../sql/catalyst/optimizer/Optimizer.scala | 1 - .../RewriteCountDistinctConditional.scala | 90 ------- .../optimizer/RewriteDistinctAggregates.scala | 44 +++- ...RewriteCountDistinctConditionalSuite.scala | 216 ---------------- .../RewriteDistinctAggregatesSuite.scala | 235 +++++++++++++++++- ...inctAggregatesConditionalQuerySuite.scala} | 9 +- 6 files changed, 283 insertions(+), 312 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala rename sql/core/src/test/scala/org/apache/spark/sql/{RewriteCountDistinctConditionalQuerySuite.scala => RewriteDistinctAggregatesConditionalQuerySuite.scala} (94%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 66d6045f54e43..ddfe80443d561 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -252,7 +252,6 @@ abstract class Optimizer(catalogManager: CatalogManager) // This batch must run after "Decimal Optimizations", as that one may change the // aggregate distinct column Batch("Distinct Aggregate Rewrite", Once, - RewriteCountDistinctConditional, RewriteDistinctAggregates, OptimizeExpand), Batch("Object Expressions Optimization", fixedPoint, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala deleted file mode 100644 index 18861ec0fb419..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditional.scala +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Count} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE -import org.apache.spark.sql.internal.SQLConf - -/** - * Rewrites COUNT(DISTINCT IF(cond, base, NULL)) and - * COUNT(DISTINCT CASE WHEN cond THEN base END) into - * COUNT(DISTINCT base) FILTER (WHERE cond). - * - * This canonicalization reduces the number of distinct groups seen by - * RewriteDistinctAggregates from N (one per unique conditional expression) down to 1 - * (all share the same base column), collapsing the Expand factor from Nx to 1x. - * - * Correctness: COUNT DISTINCT ignores NULLs, so nulling out rows where !cond - * is semantically identical to filtering those rows out entirely. - */ -object RewriteCountDistinctConditional extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) { - return plan - } - plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) { - case agg: Aggregate => agg.transformExpressionsUp { - case ae @ AggregateExpression( - count: Count, - _, - true, // isDistinct - None, // no existing FILTER - _) - if count.children.size == 1 => - extractCondAndBase(count.children.head) match { - case Some((cond, base)) => - ae.copy( - aggregateFunction = count.withNewChildren(Seq(base)).asInstanceOf[Count], - filter = Some(cond)) - case None => ae - } - } - } - } - - /** - * Matches: - * IF(cond, base, null) - * CASE WHEN cond THEN base END - * CASE WHEN cond THEN base ELSE NULL END - * - * The analyzer may wrap the null branch in a Cast for type alignment, so the - * null check is done after unwrapping any surrounding Casts. - * - * Returns None for anything else, including IF(cond, base, fallback) where - * fallback is not null -- those change semantics and must not be rewritten. - */ - private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] = - expr match { - case If(cond, base, e) if isNullExpr(e) => Some((cond, base)) - case CaseWhen(Seq((cond, base)), None) => Some((cond, base)) - case CaseWhen(Seq((cond, base)), Some(e)) if isNullExpr(e) => Some((cond, base)) - case _ => None - } - - private def isNullExpr(e: Expression): Boolean = e match { - case Literal(null, _) => true - case Cast(child, _, _, _) => isNullExpr(child) - case _ => false - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index 5aef82b64ed32..939e647b27b67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.collection.Utils @@ -223,8 +224,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a) } - def rewrite(a: Aggregate): Aggregate = { + def rewrite(origAgg: Aggregate): Aggregate = { + val a = normalizeCountDistinctConditional(origAgg) val aggExpressions = collectAggregateExprs(a) val distinctAggs = aggExpressions.filter(_.isDistinct) @@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { } } + /** + * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and + * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) FILTER (WHERE cond). + * This reduces the number of distinct groups: multiple conditional counts on the same base + * column collapse into one group, shrinking the Expand fan-out from Nx to 1x. + */ + private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = { + if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) return a + a.transformExpressionsUp { + case ae @ AggregateExpression(count: Count, _, true, None, _) + if count.children.size == 1 => + extractCondAndBase(count.children.head) match { + case Some((cond, base)) => + ae.copy( + aggregateFunction = count.withNewChildren(Seq(base)).asInstanceOf[Count], + filter = Some(cond)) + case None => ae + } + }.asInstanceOf[Aggregate] + } + + /** + * Matches IF(cond, base, null), CASE WHEN cond THEN base END, and + * CASE WHEN cond THEN base ELSE NULL END (including null wrapped in Cast). + * Returns None for anything else. + */ + private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] = + expr match { + case If(cond, base, e) if isNullExpr(e) => Some((cond, base)) + case CaseWhen(Seq((cond, base)), None) => Some((cond, base)) + case CaseWhen(Seq((cond, base)), Some(e)) if isNullExpr(e) => Some((cond, base)) + case _ => None + } + + private def isNullExpr(e: Expression): Boolean = e match { + case Literal(null, _) => true + case Cast(child, _, _, _) => isNullExpr(child) + case _ => false + } + private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = { // Collect all aggregate expressions. a.aggregateExpressions.flatMap { _.collect { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala deleted file mode 100644 index 88520a472e45b..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteCountDistinctConditionalSuite.scala +++ /dev/null @@ -1,216 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal} -import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType - -class RewriteCountDistinctConditionalSuite extends PlanTest { - - val testRelation = LocalRelation( - Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string) - - private def countDistinctIf(cond: Expression, base: Expression): Expression = { - Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = true) - } - - private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = { - val caseWhen = org.apache.spark.sql.catalyst.expressions.CaseWhen( - Seq((cond, base)), - None) - Count(caseWhen).toAggregateExpression(isDistinct = true) - } - - private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = { - val caseWhen = org.apache.spark.sql.catalyst.expressions.CaseWhen( - Seq((cond, base)), - Some(Literal(null))) - Count(caseWhen).toAggregateExpression(isDistinct = true) - } - - test("disabled by default") { - val input = testRelation - .groupBy(Symbol("a"))( - countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - - test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") { - val input = testRelation - .groupBy(Symbol("a"))( - countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - - val expected = testRelation - .groupBy(Symbol("a"))( - countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - comparePlans(optimized, expected) - } - } - - test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to COUNT(DISTINCT col) FILTER") { - val input = testRelation - .groupBy(Symbol("a"))( - countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - - val expected = testRelation - .groupBy(Symbol("a"))( - countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - comparePlans(optimized, expected) - } - } - - test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") { - val input = testRelation - .groupBy(Symbol("a"))( - countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - - val expected = testRelation - .groupBy(Symbol("a"))( - countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1")) - .analyze - - comparePlans(optimized, expected) - } - } - - test("multiple conditional distinct counts collapse to single distinct group") { - val input = testRelation - .groupBy(Symbol("a"))( - countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"), - countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"), - countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - - val expected = testRelation - .groupBy(Symbol("a"))( - countDistinctWithFilter(Symbol("b") > 1, Symbol("c")).as("cnt1"), - countDistinctWithFilter(Symbol("b") > 2, Symbol("c")).as("cnt2"), - countDistinctWithFilter(Symbol("b") > 3, Symbol("c")).as("cnt3")) - .analyze - - comparePlans(optimized, expected) - - // Verify RewriteDistinctAggregates sees only 1 distinct group - val rewritten = RewriteDistinctAggregates(optimized) - // Should be rewritten (not same as input) because there are now multiple - // distinct expressions with the same base column - assert(rewritten != optimized) - } - } - - test("do not rewrite IF with non-null else branch") { - val input = testRelation - .groupBy(Symbol("a"))( - Count(If(Symbol("b") > 1, Symbol("c"), Literal(0, IntegerType))) - .toAggregateExpression(isDistinct = true) - .as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - } - - test("do not rewrite non-distinct COUNT") { - val input = testRelation - .groupBy(Symbol("a"))( - Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) - .toAggregateExpression(isDistinct = false) - .as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - } - - test("do not rewrite when FILTER already exists") { - val input = testRelation - .groupBy(Symbol("a"))( - Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) - .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "x")) - .as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - } - - test("do not rewrite multi-branch CASE WHEN") { - val caseWhen = new org.apache.spark.sql.catalyst.expressions.CaseWhen( - Seq( - (Symbol("b") > Literal(1), Symbol("c")), - (Symbol("b") > Literal(2), Symbol("a"))), - Some(Literal(null))) - val input = testRelation - .groupBy(Symbol("a"))( - Count(caseWhen).toAggregateExpression(isDistinct = true).as("cnt1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - } - - test("do not rewrite SUM(DISTINCT IF(...))") { - val input = testRelation - .groupBy(Symbol("a"))( - org.apache.spark.sql.catalyst.expressions.aggregate.Sum( - If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) - .toAggregateExpression(isDistinct = true) - .as("sum1")) - .analyze - - withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { - val optimized = RewriteCountDistinctConditional(input) - comparePlans(optimized, input) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 08dd4011f04d6..a2a538ec677d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Literal, Round} -import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, If, Literal, Round} +import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count, Sum} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { @@ -125,4 +126,234 @@ class RewriteDistinctAggregatesSuite extends PlanTest { fail(s"Plan is not as expected:\n$rewrite") } } + + // --------------------------------------------------------------------------- + // COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898) + // --------------------------------------------------------------------------- + + val conditionalTestRelation = LocalRelation( + Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string) + + private def countDistinctIf(cond: Expression, base: Expression): Expression = { + Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = true) + } + + private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = { + val caseWhen = CaseWhen( + Seq((cond, base)), + None) + Count(caseWhen).toAggregateExpression(isDistinct = true) + } + + private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = { + val caseWhen = CaseWhen( + Seq((cond, base)), + Some(Literal(null))) + Count(caseWhen).toAggregateExpression(isDistinct = true) + } + + /** + * Asserts that the optimized plan has exactly one Expand node with one projection, + * that the projection contains `baseColName` as a plain attribute, and that it + * contains no expression of `removedWrapperType` (the IF/CaseWhen that was stripped). + */ + private def assertSingleDistinctGroupExpand( + optimized: LogicalPlan, + baseColName: String, + removedWrapperType: Class[_]): Unit = { + val expand = optimized.collectFirst { case e: Expand => e }.get + assert(expand.projections.size == 1, + s"expected 1 distinct group but got ${expand.projections.size}") + val baseAttr = conditionalTestRelation.output.find(_.name == baseColName).get + assert(expand.projections.head.exists(_.semanticEquals(baseAttr)), + s"expected base column $baseColName in Expand projection") + assert(!expand.projections.head.exists(e => removedWrapperType.isInstance(e)), + s"${removedWrapperType.getSimpleName} wrapper should have been removed " + + "from Expand projection") + } + + test("conditional: disabled by default") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + val optimized = RewriteDistinctAggregates(input) + comparePlans(optimized, input) + } + + test("conditional: rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + assertSingleDistinctGroupExpand(optimized, "c", classOf[If]) + } + } + + test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to " + + "COUNT(DISTINCT col) FILTER") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctCaseWhen(Symbol("b") > 2, Symbol("c")).as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen]) + } + } + + test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctCaseWhenElseNull(Symbol("b") > 2, Symbol("c")).as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen]) + } + } + + test("conditional: multiple conditional distinct counts collapse to single distinct group") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"), + countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"), + countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + // All three counts share the same base column c, collapsed to 1 distinct group. + assertSingleDistinctGroupExpand(optimized, "c", classOf[If]) + } + } + + test("conditional: single conditional distinct count does not produce Expand") { + // A lone COUNT(DISTINCT IF(...)) must NOT be pushed onto the Expand path. + // The canonicalization runs inside rewrite(), which is only called when + // mayNeedtoRewrite returns true. A single conditional distinct count has no filter + // and forms only one distinct group, so mayNeedtoRewrite returns false, rewrite() + // is never called, and no Expand is produced. + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + val expands = optimized.collect { case e: Expand => e } + assert(expands.isEmpty, "single conditional distinct count should not produce an Expand") + } + } + + test("conditional: do not rewrite IF with non-null else branch") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(0, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("cnt1"), + Count(If(Symbol("b") > 2, Symbol("c"), Literal(0, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + // Plan is rewritten (2 distinct groups) but not canonicalized + checkRewrite(optimized) + val expands = optimized.collect { case e: Expand => e } + assert(expands.head.projections.size == 2, + "non-null else branch should not be collapsed to 1 distinct group") + } + } + + test("conditional: do not rewrite non-distinct COUNT") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = false) + .as("cnt1"), + countDistinct(Symbol("c")).as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + // Still a single distinct group (cnt2), no canonicalization of the non-distinct agg + comparePlans(optimized, input) + } + } + + test("conditional: do not rewrite when FILTER already exists") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "x")) + .as("cnt1"), + Count(If(Symbol("b") > 2, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "y")) + .as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + checkRewrite(optimized) + val expands = optimized.collect { case e: Expand => e } + // 2 groups because existing FILTERs prevent canonicalization + assert(expands.head.projections.size == 2) + } + } + + test("conditional: do not rewrite multi-branch CASE WHEN") { + val caseWhen1 = new CaseWhen( + Seq( + (Symbol("b") > Literal(1), Symbol("c")), + (Symbol("b") > Literal(2), Symbol("a"))), + Some(Literal(null))) + val caseWhen2 = new CaseWhen( + Seq( + (Symbol("b") > Literal(3), Symbol("c")), + (Symbol("b") > Literal(4), Symbol("a"))), + Some(Literal(null))) + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + Count(caseWhen1).toAggregateExpression(isDistinct = true).as("cnt1"), + Count(caseWhen2).toAggregateExpression(isDistinct = true).as("cnt2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + checkRewrite(optimized) + val expands = optimized.collect { case e: Expand => e } + // 2 groups - multi-branch CASE was not canonicalized + assert(expands.head.projections.size == 2) + } + } + + test("conditional: do not rewrite SUM(DISTINCT IF(...))") { + val input = conditionalTestRelation + .groupBy(Symbol("a"))( + Sum(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("sum1"), + Sum(If(Symbol("b") > 2, Symbol("c"), Literal(null, IntegerType))) + .toAggregateExpression(isDistinct = true) + .as("sum2")) + .analyze + + withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { + val optimized = RewriteDistinctAggregates(input) + checkRewrite(optimized) + val expands = optimized.collect { case e: Expand => e } + // 2 groups - SUM(DISTINCT IF) is not canonicalized + assert(expands.head.projections.size == 2) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala similarity index 94% rename from sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala index e83ce5b380c3a..c4a531dfca1b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RewriteCountDistinctConditionalQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -class RewriteCountDistinctConditionalQuerySuite extends QueryTest with SharedSparkSession { +class RewriteDistinctAggregatesConditionalQuerySuite extends QueryTest with SharedSparkSession { private def checkRewriteAndResult( conditionalSql: String, @@ -194,10 +194,15 @@ class RewriteCountDistinctConditionalQuerySuite extends QueryTest with SharedSpa "cast(id * 100 as int) as col2") .createOrReplaceTempView("t") + // Two conditional distinct counts on the same base column trigger canonicalization, + // so the optimized plan must contain FILTER clauses. val planStr = withSQLConf( SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") { val df = spark.sql( - "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key") + """SELECT key, + | COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1, + | COUNT(DISTINCT IF(col1 > 5, col2, NULL)) as cnt2 + |FROM t GROUP BY key""".stripMargin) df.queryExecution.optimizedPlan.toString }