Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.collection.Utils

Expand Down Expand Up @@ -223,8 +224,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
}

def rewrite(a: Aggregate): Aggregate = {
def rewrite(origAgg: Aggregate): Aggregate = {

val a = normalizeCountDistinctConditional(origAgg)
val aggExpressions = collectAggregateExprs(a)
val distinctAggs = aggExpressions.filter(_.isDistinct)

Expand Down Expand Up @@ -419,6 +421,46 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
}
}

/**
 * Turns every unfiltered, single-argument `COUNT(DISTINCT ...)` whose argument is a
 * null-defaulting conditional -- `IF(cond, base, NULL)` or a one-branch `CASE WHEN` -- into
 * `COUNT(DISTINCT base) FILTER (WHERE cond)`.
 *
 * Once the wrapper is stripped, conditional counts over the same base column belong to the
 * same distinct group, so the Expand fan-out drops from Nx to 1x.
 *
 * The aggregate is handed back unchanged when
 * `SQLConf.get.rewriteCountDistinctConditionalEnabled` is false.
 */
private def normalizeCountDistinctConditional(a: Aggregate): Aggregate = {
  if (SQLConf.get.rewriteCountDistinctConditionalEnabled) {
    val normalized = a.transformExpressionsUp {
      // Only DISTINCT counts (isDistinct = true) that carry no FILTER yet (filter = None).
      case aggExpr @ AggregateExpression(countFn: Count, _, true, None, _)
          if countFn.children.size == 1 =>
        extractCondAndBase(countFn.children.head)
          .map { case (predicate, target) =>
            aggExpr.copy(
              aggregateFunction = countFn.withNewChildren(Seq(target)).asInstanceOf[Count],
              filter = Some(predicate))
          }
          .getOrElse(aggExpr)
    }
    normalized.asInstanceOf[Aggregate]
  } else {
    a
  }
}

/**
 * Splits a null-defaulting conditional into its `(condition, value)` pair.
 *
 * Recognized shapes are `IF(cond, base, null)`, `CASE WHEN cond THEN base END` and
 * `CASE WHEN cond THEN base ELSE NULL END`; the null may be wrapped in Cast. Every other
 * expression, including a CASE with more than one branch, yields None.
 */
private def extractCondAndBase(expr: Expression): Option[(Expression, Expression)] =
  Some(expr).collect {
    case If(predicate, value, fallback) if isNullExpr(fallback) =>
      (predicate, value)
    // A missing ELSE and an ELSE that is a null expression are treated alike.
    case CaseWhen(Seq((predicate, value)), elseBranch) if elseBranch.forall(isNullExpr) =>
      (predicate, value)
  }

/**
 * Returns true when `e` is a null literal, possibly nested inside any number of Cast
 * wrappers; false for every other expression.
 */
private def isNullExpr(e: Expression): Boolean = e match {
  case wrapped: Cast => isNullExpr(wrapped.child)
  case literal: Literal => literal.value == null
  case _ => false
}

private def collectAggregateExprs(a: Aggregate): Seq[AggregateExpression] = {
// Collect all aggregate expressions.
a.aggregateExpressions.flatMap { _.collect {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,18 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

// Boolean switch for the COUNT(DISTINCT <conditional>) canonicalization described in the
// doc text below. It is created with a default of false, so the rewrite is opt-in.
val REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED =
buildConf("spark.sql.optimizer.rewriteCountDistinctConditional.enabled")
.doc("When true, rewrites COUNT(DISTINCT IF(cond, base, NULL)) and " +
"COUNT(DISTINCT CASE WHEN cond THEN base END) into " +
"COUNT(DISTINCT base) FILTER (WHERE cond). This reduces the Expand factor " +
"in RewriteDistinctAggregates from Nx to 1x when multiple conditional distinct " +
"counts share the same base column.")
.version("4.2.0")
.withBindingPolicy(ConfigBindingPolicy.SESSION)
.booleanConf
.createWithDefault(false)

val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
Expand Down Expand Up @@ -8473,6 +8485,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
/**
 * Negation of the conf value
 * `DECORRELATE_EXISTS_IN_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED`: true unless that
 * legacy flag is set.
 */
def decorrelateInnerQueryEnabledForExistsIn: Boolean =
!getConf(SQLConf.DECORRELATE_EXISTS_IN_SUBQUERY_LEGACY_INCORRECT_COUNT_HANDLING_ENABLED)

/**
 * Current value of `REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED`, i.e. whether
 * COUNT(DISTINCT IF/CASE ...) may be rewritten to COUNT(DISTINCT ...) FILTER (WHERE ...).
 */
def rewriteCountDistinctConditionalEnabled: Boolean =
getConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED)

/** Current value of the `MAX_CONCURRENT_OUTPUT_FILE_WRITERS` conf entry. */
def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS)

/** Current value of the `PLANNED_WRITE_ENABLED` conf entry. */
def plannedWriteEnabled: Boolean = getConf(SQLConf.PLANNED_WRITE_ENABLED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Expression, If, Literal, Round}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count, Sum}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
Expand Down Expand Up @@ -125,4 +126,234 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
fail(s"Plan is not as expected:\n$rewrite")
}
}

// ---------------------------------------------------------------------------
// COUNT(DISTINCT IF/CASE) canonicalization (SPARK-56898)
// ---------------------------------------------------------------------------

// Input relation shared by the "conditional:" tests: int columns a, b, c and string column d.
val conditionalTestRelation = LocalRelation(
  Symbol("a").int,
  Symbol("b").int,
  Symbol("c").int,
  Symbol("d").string)

/** Builds `COUNT(DISTINCT IF(cond, base, NULL))` as a distinct aggregate expression. */
private def countDistinctIf(cond: Expression, base: Expression): Expression = {
  val conditional = If(cond, base, Literal(null))
  Count(conditional).toAggregateExpression(isDistinct = true)
}

/** Builds `COUNT(DISTINCT CASE WHEN cond THEN base END)`, i.e. a CASE without an ELSE. */
private def countDistinctCaseWhen(cond: Expression, base: Expression): Expression = {
  val singleBranch = Seq((cond, base))
  Count(CaseWhen(singleBranch, None)).toAggregateExpression(isDistinct = true)
}

/** Builds `COUNT(DISTINCT CASE WHEN cond THEN base ELSE NULL END)` with an explicit ELSE. */
private def countDistinctCaseWhenElseNull(cond: Expression, base: Expression): Expression = {
  val singleBranch = Seq((cond, base))
  val explicitNullElse = Some(Literal(null))
  Count(CaseWhen(singleBranch, explicitNullElse)).toAggregateExpression(isDistinct = true)
}

/**
 * Asserts that the first Expand node found in `optimized` has exactly one projection, that
 * this projection contains the `conditionalTestRelation` column named `baseColName` as a
 * plain attribute, and that none of the projection's top-level entries is an instance of
 * `removedWrapperType` (the IF/CaseWhen wrapper that should have been stripped).
 *
 * The test fails with a descriptive message, rather than an opaque NoSuchElementException,
 * when the plan has no Expand node or `baseColName` is not a column of the relation.
 *
 * @param optimized plan returned by the rule under test
 * @param baseColName name of the column expected to appear bare in the projection
 * @param removedWrapperType expression class that must not appear at the projection's top level
 */
private def assertSingleDistinctGroupExpand(
    optimized: LogicalPlan,
    baseColName: String,
    removedWrapperType: Class[_]): Unit = {
  val expand = optimized.collectFirst { case e: Expand => e }
    .getOrElse(fail(s"expected an Expand node in the optimized plan:\n$optimized"))
  assert(expand.projections.size == 1,
    s"expected 1 distinct group but got ${expand.projections.size}")
  val baseAttr = conditionalTestRelation.output.find(_.name == baseColName)
    .getOrElse(fail(s"column $baseColName does not exist in conditionalTestRelation"))
  // Safe: the size assertion above guarantees exactly one projection.
  val projection = expand.projections.head
  assert(projection.exists(_.semanticEquals(baseAttr)),
    s"expected base column $baseColName in Expand projection")
  assert(!projection.exists(e => removedWrapperType.isInstance(e)),
    s"${removedWrapperType.getSimpleName} wrapper should have been removed " +
      "from Expand projection")
}

test("conditional: disabled by default") {
  // Two conditional distinct counts over the same base column are needed here. A lone,
  // unfiltered distinct aggregate is never rewritten by this rule, so it would leave the plan
  // untouched whether or not the canonicalization ran and could not detect a wrong default.
  val input = conditionalTestRelation
    .groupBy(Symbol("a"))(
      countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
      countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"))
    .analyze

  // Deliberately no withSQLConf: the default value of the flag is what is under test.
  val optimized = RewriteDistinctAggregates(input)
  checkRewrite(optimized)
  val expands = optimized.collect { case e: Expand => e }
  // The IF wrappers differ in their condition, so without canonicalization they form two
  // distinct groups; a single projection would mean the rewrite ran despite the default.
  assert(expands.head.projections.size == 2,
    "conditional distinct counts must stay in separate distinct groups while the flag is off")
}

test("conditional: rewrite COUNT(DISTINCT IF(cond, col, NULL)) to COUNT(DISTINCT col) FILTER") {
  // Two IF-wrapped distinct counts over column c that differ only in their condition.
  val aggregates = Seq(
    countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
    countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"))
  val query = conditionalTestRelation.groupBy(Symbol("a"))(aggregates: _*).analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    // Both counts must end up in one distinct group on bare c, with no IF left behind.
    assertSingleDistinctGroupExpand(rewritten, "c", classOf[If])
  }
}

test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col END) to " +
  "COUNT(DISTINCT col) FILTER") {
  // Same shape as the IF test, but the wrapper is a one-branch CASE without ELSE.
  val cnt1 = countDistinctCaseWhen(Symbol("b") > 1, Symbol("c")).as("cnt1")
  val cnt2 = countDistinctCaseWhen(Symbol("b") > 2, Symbol("c")).as("cnt2")
  val query = conditionalTestRelation.groupBy(Symbol("a"))(cnt1, cnt2).analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    assertSingleDistinctGroupExpand(rewritten, "c", classOf[CaseWhen])
  }
}

test("conditional: rewrite COUNT(DISTINCT CASE WHEN cond THEN col ELSE NULL END)") {
  // One-branch CASE whose ELSE is an explicit NULL literal.
  val cnt1 = countDistinctCaseWhenElseNull(Symbol("b") > 1, Symbol("c")).as("cnt1")
  val cnt2 = countDistinctCaseWhenElseNull(Symbol("b") > 2, Symbol("c")).as("cnt2")
  val query = conditionalTestRelation.groupBy(Symbol("a"))(cnt1, cnt2).analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    assertSingleDistinctGroupExpand(rewritten, "c", classOf[CaseWhen])
  }
}

test("conditional: multiple conditional distinct counts collapse to single distinct group") {
  // Three IF-wrapped distinct counts, all over column c, each with its own condition.
  val aggregates = Seq(
    countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1"),
    countDistinctIf(Symbol("b") > 2, Symbol("c")).as("cnt2"),
    countDistinctIf(Symbol("b") > 3, Symbol("c")).as("cnt3"))
  val query = conditionalTestRelation.groupBy(Symbol("a"))(aggregates: _*).analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    // Sharing base column c, the three counts are expected to form one distinct group.
    assertSingleDistinctGroupExpand(rewritten, "c", classOf[If])
  }
}

test("conditional: single conditional distinct count does not produce Expand") {
  // The canonicalization lives inside rewrite(), and rewrite() only runs when
  // mayNeedtoRewrite is true. One unfiltered COUNT(DISTINCT IF(...)) is a single distinct
  // group, so mayNeedtoRewrite is false, rewrite() is skipped and the plan must not gain an
  // Expand even though the flag is on.
  val onlyCount = countDistinctIf(Symbol("b") > 1, Symbol("c")).as("cnt1")
  val query = conditionalTestRelation.groupBy(Symbol("a"))(onlyCount).analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    val expandNodes = rewritten.collect { case e: Expand => e }
    assert(expandNodes.isEmpty, "single conditional distinct count should not produce an Expand")
  }
}

test("conditional: do not rewrite IF with non-null else branch") {
  // COUNT(DISTINCT IF(b > threshold, c, 0)): the else branch is 0, not NULL.
  def countDistinctIfElseZero(threshold: Int): Expression =
    Count(If(Symbol("b") > threshold, Symbol("c"), Literal(0, IntegerType)))
      .toAggregateExpression(isDistinct = true)

  val query = conditionalTestRelation
    .groupBy(Symbol("a"))(
      countDistinctIfElseZero(1).as("cnt1"),
      countDistinctIfElseZero(2).as("cnt2"))
    .analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    // The rule still rewrites the plan (two distinct groups), just without canonicalizing.
    checkRewrite(rewritten)
    val expandNodes = rewritten.collect { case e: Expand => e }
    assert(expandNodes.head.projections.size == 2,
      "non-null else branch should not be collapsed to 1 distinct group")
  }
}

test("conditional: do not rewrite non-distinct COUNT") {
  // A plain (non-distinct) COUNT over the IF wrapper, next to one ordinary COUNT(DISTINCT c).
  val plainConditionalCount =
    Count(If(Symbol("b") > 1, Symbol("c"), Literal(null, IntegerType)))
      .toAggregateExpression(isDistinct = false)
  val query = conditionalTestRelation
    .groupBy(Symbol("a"))(
      plainConditionalCount.as("cnt1"),
      countDistinct(Symbol("c")).as("cnt2"))
    .analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    // cnt2 is the only distinct group and cnt1 is not distinct: the plan must come back as is.
    comparePlans(rewritten, query)
  }
}

test("conditional: do not rewrite when FILTER already exists") {
  // COUNT(DISTINCT IF(b > threshold, c, NULL)) FILTER (WHERE d = tag)
  def filteredConditionalCount(threshold: Int, tag: String): Expression =
    Count(If(Symbol("b") > threshold, Symbol("c"), Literal(null, IntegerType)))
      .toAggregateExpression(isDistinct = true, filter = Some(Symbol("d") === tag))

  val query = conditionalTestRelation
    .groupBy(Symbol("a"))(
      filteredConditionalCount(1, "x").as("cnt1"),
      filteredConditionalCount(2, "y").as("cnt2"))
    .analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    checkRewrite(rewritten)
    val expandNodes = rewritten.collect { case e: Expand => e }
    // The pre-existing FILTER clauses block canonicalization, so two groups remain.
    assert(expandNodes.head.projections.size == 2)
  }
}

test("conditional: do not rewrite multi-branch CASE WHEN") {
  // CASE WHEN b > first THEN c WHEN b > second THEN a ELSE NULL END
  def twoBranchCase(first: Int, second: Int): CaseWhen = {
    val branches: Seq[(Expression, Expression)] = Seq(
      (Symbol("b") > Literal(first), Symbol("c")),
      (Symbol("b") > Literal(second), Symbol("a")))
    new CaseWhen(branches, Some(Literal(null)))
  }

  val query = conditionalTestRelation
    .groupBy(Symbol("a"))(
      Count(twoBranchCase(1, 2)).toAggregateExpression(isDistinct = true).as("cnt1"),
      Count(twoBranchCase(3, 4)).toAggregateExpression(isDistinct = true).as("cnt2"))
    .analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    checkRewrite(rewritten)
    val expandNodes = rewritten.collect { case e: Expand => e }
    // A CASE with more than one branch is left alone, so both groups survive.
    assert(expandNodes.head.projections.size == 2)
  }
}

test("conditional: do not rewrite SUM(DISTINCT IF(...))") {
  // SUM(DISTINCT IF(b > threshold, c, NULL)): same wrapper, but the aggregate is not COUNT.
  def sumDistinctIf(threshold: Int): Expression =
    Sum(If(Symbol("b") > threshold, Symbol("c"), Literal(null, IntegerType)))
      .toAggregateExpression(isDistinct = true)

  val query = conditionalTestRelation
    .groupBy(Symbol("a"))(
      sumDistinctIf(1).as("sum1"),
      sumDistinctIf(2).as("sum2"))
    .analyze

  withSQLConf(SQLConf.REWRITE_COUNT_DISTINCT_CONDITIONAL_ENABLED.key -> "true") {
    val rewritten = RewriteDistinctAggregates(query)
    checkRewrite(rewritten)
    val expandNodes = rewritten.collect { case e: Expand => e }
    // Only COUNT is canonicalized; the two SUM(DISTINCT IF) stay separate groups.
    assert(expandNodes.head.projections.size == 2)
  }
}
}
Loading