diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 5aef82b64ed32..939e647b27b67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.IntegerType
 import org.apache.spark.util.collection.Utils
 
@@ -223,8 +224,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }
 
-  def rewrite(a: Aggregate): Aggregate = {
+  def rewrite(origAgg: Aggregate): Aggregate = {
+    val a = normalizeCountDistinctConditional(origAgg)
 
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
@@ -419,6 +421,67 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Canonicalizes COUNT(DISTINCT IF(cond, base, NULL)) and
+   * COUNT(DISTINCT CASE WHEN cond THEN base END) to COUNT(DISTINCT base) FILTER (WHERE cond).
+   * This reduces the number of distinct groups: multiple conditional counts on the same base
+   * column collapse into one group, shrinking the Expand fan-out from Nx to 1x.
+   *
+   * Only distinct, single-argument COUNTs that carry no FILTER yet are rewritten, and only
+   * when `base` passes [[isSafeToHoist]]. The aggregate is returned untouched when
+   * `spark.sql.optimizer.rewriteCountDistinctConditional.enabled` is off.
+   */
+  private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
+    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) {
+      a
+    } else {
+      a.transformExpressionsUp {
+        case ae @ AggregateExpression(count @ Count(Seq(child)), _, true, None, _) =>
+          extractCondAndBase(child)
+            .filter { case (_, base) => isSafeToHoist(base) }
+            .fold(ae) { case (cond, base) =>
+              ae.copy(aggregateFunction = count.copy(children = Seq(base)), filter = Some(cond))
+            }
+      }
+    }
+  }
+
+  /**
+   * Matches IF(cond, base, null), CASE WHEN cond THEN base END, and
+   * CASE WHEN cond THEN base ELSE NULL END (including null wrapped in Cast).
+   *
+   * @return the `(cond, base)` pair of the single guarded branch, or None for anything else.
+   */
+  private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] =
+    expr match {
+      case If(cond, base, elseValue) if isNullExpr(elseValue) => Some((cond, base))
+      case CaseWhen(Seq((cond, base)), elseValue) if elseValue.forall(isNullExpr) =>
+        Some((cond, base))
+      case _ => None
+    }
+
+  /**
+   * Whether `base` may be evaluated for rows the original condition excluded.
+   *
+   * IF/CASE WHEN guard the evaluation of `base` with `cond`; after the rewrite `base` becomes
+   * a plain argument of the distinct aggregate. Only attribute references and literals are
+   * accepted, since evaluating them cannot fail or add cost.
+   * NOTE(review): presumably the distinct children are projected for every input row before
+   * the FILTER applies - confirm against the Expand built in `rewrite` before widening this.
+   */
+  private def isSafeToHoist(base: Expression): Boolean = base match {
+    case _: Attribute | _: Literal => true
+    case _ => false
+  }
+
+  /** True for a NULL literal, possibly wrapped in any number of casts. */
+  @scala.annotation.tailrec
+  private def isNullExpr(e: Expression): Boolean = e match {
+    case Literal(null, _) => true
+    case Cast(child, _, _, _) => isNullExpr(child)
+    case _ => false
+  }
+
   private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = {
     // Collect all aggregate expressions.
     a.aggregateExpressions.flatMap { _.collect {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5ed831f20f394..2c16a5c4e876a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1366,6 +1366,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED =
+    buildConf("spark.sql.optimizer.rewriteCountDistinctConditional.enabled")
+      .doc("When true, rewrites COUNT(DISTINCT IF(cond, base, NULL)) and " +
+        "COUNT(DISTINCT CASE WHEN cond THEN base END) into " +
+        "COUNT(DISTINCT base) FILTER (WHERE cond) when base is a column or a literal. " +
+        "This reduces the Expand factor in RewriteDistinctAggregates from Nx to 1x when " +
+        "multiple conditional distinct counts share the same base column.")
+      .version("4.2.0")
+      .withBindingPolicy(ConfigBindingPolicy.SESSION)
+      .booleanConf
+      .createWithDefault(false)
+
   val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
     .internal()
     .doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -8473,6 +8485,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def decorrelateInnerQueryEnabledForExistsIn: Boolean =
     !getConf(SQLConf.DECORRELATE_EXISTS_IN_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)
 
+  def rewriteCountDistinctConditionalEnabled: Boolean =
+    getConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED)
+
   def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)
 
   def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index
08dd4011f04d6..a2a538ec677d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
-import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
+import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, If, Literal, Round}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count, Sum}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StringType}
 
 class RewriteDistinctAggregatesSuite extends PlanTest {
@@ -125,4 +126,248 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
         fail(s"Plan is not as expected:\n$rewrite")
     }
   }
+
+  // ---------------------------------------------------------------------------
+  // COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898)
+  // ---------------------------------------------------------------------------
+
+  val conditionalTestRelation = LocalRelation(
+    Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string)
+
+  private def countDistinctIf(cond: Expression, base: Expression): Expression = {
+    Count(If(cond, base, Literal(null))).toAggregateExpression(isDistinct = true)
+  }
+
+  private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      None)
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = {
+    val caseWhen = CaseWhen(
+      Seq((cond, base)),
+      Some(Literal(null)))
+    Count(caseWhen).toAggregateExpression(isDistinct = true)
+  }
+
+  /**
+   * Asserts that the first Expand node of the optimized plan has a single projection,
+   * that the projection contains `baseColName` as a plain attribute, and that it
+   * contains no expression of `removedWrapperType` (the IF/CaseWhen that was stripped).
+   * Fails with a descriptive message when the plan has no Expand or no such column.
+   */
+  private def assertSingleDistinctGroupExpand(
+      optimized: LogicalPlan,
+      baseColName: String,
+      removedWrapperType: Class[_]): Unit = {
+    val expand = optimized.collectFirst { case e: Expand => e }
+      .getOrElse(fail(s"expected an Expand in the optimized plan:\n$optimized"))
+    assert(expand.projections.size == 1,
+      s"expected 1 distinct group but got ${expand.projections.size}")
+    val baseAttr = conditionalTestRelation.output.find(_.name == baseColName)
+      .getOrElse(fail(s"no column named $baseColName in conditionalTestRelation"))
+    assert(expand.projections.head.exists(_.semanticEquals(baseAttr)),
+      s"expected base column $baseColName in Expand projection")
+    assert(!expand.projections.head.exists(e => removedWrapperType.isInstance(e)),
+      s"${removedWrapperType.getSimpleName} wrapper should have been removed " +
+        "from Expand projection")
+  }
+
+  /** Asserts that the first Expand node of `plan` carries `expected` projections. */
+  private def assertDistinctGroupCount(plan: LogicalPlan, expected: Int, clue: String): Unit = {
+    val expand = plan.collectFirst { case e: Expand => e }
+      .getOrElse(fail(s"expected an Expand in the optimized plan:\n$plan"))
+    assert(expand.projections.size == expected, clue)
+  }
+
+  test("conditional: disabled by default") {
+    // A lone conditional count never reaches rewrite(), whatever the flag says, so two
+    // counts are needed to observe that they stay separate distinct groups by default.
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+    val optimized = RewriteDistinctAggregates(input)
+    checkRewrite(optimized)
+    assertDistinctGroupCount(optimized, 2,
+      "conditional counts must stay separate distinct groups while the flag is off")
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[If])
+    }
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to " +
+    "COUNT(DISTINCT col) FILTER") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctCaseWhen(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen])
+    }
+  }
+
+  test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctCaseWhenElseNull(Symbol("b") > 2, Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[CaseWhen])
+    }
+  }
+
+  test("conditional: multiple conditional distinct counts collapse to single distinct group") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
+        countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"),
+        countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      // All three counts share the same base column c, collapsed to 1 distinct group.
+      assertSingleDistinctGroupExpand(optimized, "c", classOf[If])
+    }
+  }
+
+  test("conditional: single conditional distinct count does not produce Expand") {
+    // A lone COUNT(DISTINCT IF(...)) must NOT be pushed onto the Expand path.
+    // The canonicalization runs inside rewrite(), which is only called when
+    // mayNeedtoRewrite returns true. A single conditional distinct count has no filter
+    // and forms only one distinct group, so mayNeedtoRewrite returns false, rewrite()
+    // is never called, and no Expand is produced.
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      val expands = optimized.collect { case e: Expand => e }
+      assert(expands.isEmpty, "single conditional distinct count should not produce an Expand")
+    }
+  }
+
+  test("conditional: do not rewrite IF with non-null else branch") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(0, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("cnt1"),
+        Count(If(Symbol("b") > 2, Symbol("c"), Literal(0, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      // Plan is rewritten (2 distinct groups) but not canonicalized
+      checkRewrite(optimized)
+      assertDistinctGroupCount(optimized, 2,
+        "non-null else branch should not be collapsed to 1 distinct group")
+    }
+  }
+
+  test("conditional: do not rewrite non-distinct COUNT") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = false)
+          .as("cnt1"),
+        countDistinct(Symbol("c")).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      // Still a single distinct group (cnt2), no canonicalization of the non-distinct agg
+      comparePlans(optimized, input)
+    }
+  }
+
+  test("conditional: do not rewrite when FILTER already exists") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "x"))
+          .as("cnt1"),
+        Count(If(Symbol("b") > 2, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === "y"))
+          .as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      checkRewrite(optimized)
+      // 2 groups because existing FILTERs prevent canonicalization
+      assertDistinctGroupCount(optimized, 2,
+        "existing FILTERs should prevent canonicalization")
+    }
+  }
+
+  test("conditional: do not rewrite multi-branch CASE WHEN") {
+    val caseWhen1 = new CaseWhen(
+      Seq(
+        (Symbol("b") > Literal(1), Symbol("c")),
+        (Symbol("b") > Literal(2), Symbol("a"))),
+      Some(Literal(null)))
+    val caseWhen2 = new CaseWhen(
+      Seq(
+        (Symbol("b") > Literal(3), Symbol("c")),
+        (Symbol("b") > Literal(4), Symbol("a"))),
+      Some(Literal(null)))
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        Count(caseWhen1).toAggregateExpression(isDistinct = true).as("cnt1"),
+        Count(caseWhen2).toAggregateExpression(isDistinct = true).as("cnt2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      checkRewrite(optimized)
+      // 2 groups - multi-branch CASE was not canonicalized
+      assertDistinctGroupCount(optimized, 2,
+        "multi-branch CASE WHEN should not be canonicalized")
+    }
+  }
+
+  test("conditional: do not rewrite SUM(DISTINCT IF(...))") {
+    val input = conditionalTestRelation
+      .groupBy(Symbol("a"))(
+        Sum(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("sum1"),
+        Sum(If(Symbol("b") > 2,
Symbol("c"), Literal(null, IntegerType)))
+          .toAggregateExpression(isDistinct = true)
+          .as("sum2"))
+      .analyze
+
+    withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+      val optimized = RewriteDistinctAggregates(input)
+      checkRewrite(optimized)
+      val expands = optimized.collect { case e: Expand => e }
+      // 2 groups - SUM(DISTINCT IF) is not canonicalized
+      assert(expands.head.projections.size == 2)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala
new file mode 100644
index 0000000000000..c4a531dfca1b1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RewriteDistinctAggregatesConditionalQuerySuite.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+class RewriteDistinctAggregatesConditionalQuerySuite extends QueryTest with SharedSparkSession {
+
+  /**
+   * Runs `conditionalSql` with the rewrite enabled and disabled and checks that both runs
+   * return the same rows as `filterSql`, the hand-written FILTER form of the same query.
+   * Rows are compared with `checkAnswer`, so the comparison does not depend on the order
+   * in which the groups are returned. The caller registers and cleans up temp view `t`.
+   */
+  private def checkRewriteAndResult(conditionalSql: String, filterSql: String): Unit = {
+    val expected = spark.sql(filterSql).collect().toSeq
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> enabled) {
+        checkAnswer(spark.sql(conditionalSql), expected)
+      }
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT IF(cond, col, NULL)) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as int) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) correctness") {
+    withTempView("t") {
+      spark.range(7)
+        .selectExpr(
+          "cast(id % 3 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 100 as string) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 END) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END) correctness") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 4 = 0 then null else cast(id * 1.0 as double) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        """SELECT key, COUNT(DISTINCT CASE WHEN col1 > 10 THEN col2 ELSE NULL END)
+          |FROM t GROUP BY key""".stripMargin,
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite with no GROUP BY") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t",
+        "SELECT COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t")
+    }
+  }
+
+  test("rewrite with all NULLs in conditional branch") {
+    withTempView("t") {
+      spark.range(3)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 5 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("rewrite with duplicates in base column") {
+    withTempView("t") {
+      spark.range(6)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then 100 when id % 3 = 1 then 100 else 200 end as col2")
+        .createOrReplaceTempView("t")
+
+      checkRewriteAndResult(
+        "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, NULL)) FROM t GROUP BY key",
+        "SELECT key, COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) FROM t GROUP BY key")
+    }
+  }
+
+  test("multiple conditional distinct counts collapse and produce correct results") {
+    withTempView("t") {
+      spark.range(5)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "case when id % 3 = 0 then null else cast(id * 100 as int) end as col2",
+          "case when id % 4 = 0 then null else cast(id * 10 as string) end as col3")
+        .createOrReplaceTempView("t")
+
+      val conditionalSql =
+        """SELECT key,
+          |  COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1,
+          |  COUNT(DISTINCT IF(col1 > 5, col3, NULL)) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      val filterSql =
+        """SELECT key,
+          |  COUNT(DISTINCT col2) FILTER (WHERE col1 > 10) as cnt1,
+          |  COUNT(DISTINCT col3) FILTER (WHERE col1 > 5) as cnt2
+          |FROM t GROUP BY key""".stripMargin
+
+      checkRewriteAndResult(conditionalSql, filterSql)
+    }
+  }
+
+  test("rewrite does not affect COUNT(DISTINCT IF(cond, col, non_null))") {
+    withTempView("t") {
+      spark.range(3)
+        .selectExpr(
+          "cast(id % 2 + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      val sqlText = "SELECT key, COUNT(DISTINCT IF(col1 > 10, col2, 0)) FROM t GROUP BY key"
+
+      // Non-null ELSE branch should not be rewritten: both settings must agree.
+      val withoutRewrite = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "false") {
+        spark.sql(sqlText).collect().toSeq
+      }
+      withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        checkAnswer(spark.sql(sqlText), withoutRewrite)
+      }
+    }
+  }
+
+  test("rewrite is present in optimized plan") {
+    withTempView("t") {
+      spark.range(2)
+        .selectExpr(
+          "cast(id + 1 as int) as key",
+          "cast(id * 10 as int) as col1",
+          "cast(id * 100 as int) as col2")
+        .createOrReplaceTempView("t")
+
+      // Two conditional distinct counts on the same base column trigger canonicalization,
+      // so the optimized plan must contain FILTER clauses.
+      val planStr = withSQLConf(
+        SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
+        val df = spark.sql(
+          """SELECT key,
+            |  COUNT(DISTINCT IF(col1 > 10, col2, NULL)) as cnt1,
+            |  COUNT(DISTINCT IF(col1 > 5, col2, NULL)) as cnt2
+            |FROM t GROUP BY key""".stripMargin)
+        df.queryExecution.optimizedPlan.toString
+      }
+
+      assert(planStr.contains("FILTER"),
+        s"Optimized plan should contain FILTER clause. Plan:\n$planStr")
+    }
+  }
+}