
[SPARK-56898] Rewrite COUNT(DISTINCT IF) to COUNT(DISTINCT) FILTER for Expand reduction#55925

Open
xumingming wants to merge 2 commits into apache:master from xumingming:count-distinct-filter-rewrite

@xumingming xumingming commented May 17, 2026


What changes were proposed in this pull request?

Adds RewriteCountDistinctConditional optimizer rule that canonicalizes:

  COUNT(DISTINCT IF(cond, base, NULL))
  COUNT(DISTINCT CASE WHEN cond THEN base END)

into:

  COUNT(DISTINCT base) FILTER (WHERE cond)

This reduces the number of distinct groups seen by RewriteDistinctAggregates from N (one per unique conditional expression) down to 1 (all share the same base column), collapsing the Expand factor from Nx to 1x.
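The group-count reduction described above can be illustrated with a small model. This is only a sketch in plain Scala, not Catalyst code: `DistinctCount`, `canonicalize`, and `distinctGroups` are hypothetical names, expressions are represented as strings, and it assumes that distinct groups are keyed by the counted expression alone, with the filter ignored.

```scala
// Simplified model, not Catalyst: an aggregate is described only by the
// expression it counts and an optional FILTER predicate (both as strings).
object DistinctGroupModel {
  final case class DistinctCount(child: String, filter: Option[String])

  // Hypothetical stand-in for the rewrite: matches IF(cond, base, NULL).
  private val IfNull = """IF\((.+), (\w+), NULL\)""".r

  // IF(cond, base, NULL) with no existing filter becomes base + FILTER cond.
  def canonicalize(agg: DistinctCount): DistinctCount = agg match {
    case DistinctCount(IfNull(cond, base), None) => DistinctCount(base, Some(cond))
    case other => other
  }

  // Assumption: groups are keyed by the counted expression, ignoring the filter.
  def distinctGroups(aggs: Seq[DistinctCount]): Int =
    aggs.map(_.child).distinct.size

  val aggs: Seq[DistinctCount] = Seq(
    DistinctCount("IF(dt >= '2026-04-16', order_id, NULL)", None),
    DistinctCount("IF(dt >= '2026-02-15', order_id, NULL)", None),
    DistinctCount("IF(pay_status = 'paid', order_id, NULL)", None)
  )

  def main(args: Array[String]): Unit = {
    println(s"groups before: ${distinctGroups(aggs)}")                   // 3
    println(s"groups after:  ${distinctGroups(aggs.map(canonicalize))}") // 1
  }
}
```

In this model the three conditional counts form three groups before canonicalization and one group afterwards, because all of them count `order_id`.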

Gated by spark.sql.optimizer.rewriteCountDistinctConditional.enabled (default: false).

Includes comprehensive unit tests for rewrite patterns and safety boundaries.

Why are the changes needed?

When a query contains many COUNT(DISTINCT IF(cond_i, col, NULL)) expressions over the same base column, RewriteDistinctAggregates treats each unique IF(...) expression as a distinct group. N conditions → N distinct groups → N× Expand amplification. In production workloads with 25–60 such expressions, this produces multi-terabyte shuffles and hour-long runtimes.

SELECT
  user_id,
  COUNT(DISTINCT IF(dt >= '2026-04-16', order_id, NULL)) AS orders_30d,
  COUNT(DISTINCT IF(dt >= '2026-02-15', order_id, NULL)) AS orders_90d,
  COUNT(DISTINCT IF(pay_status = 'paid', order_id, NULL)) AS orders_paid,
  -- ... 50 more expressions
FROM transactions
GROUP BY user_id 

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Unit tests.

Was this patch authored or co-authored using generative AI tooling?

No.

@xumingming xumingming force-pushed the count-distinct-filter-rewrite branch 3 times, most recently from cc0a893 to 3bc29c1 Compare May 17, 2026 11:58
@xumingming xumingming changed the title [SPARK-56898] feat: rewrite COUNT(DISTINCT IF) to COUNT(DISTINCT) FILTER for Expand reduction [SPARK-56898] Rewrite COUNT(DISTINCT IF) to COUNT(DISTINCT) FILTER for Expand reduction May 17, 2026
@xumingming (Author)
@LuciferYang Can you help take a look at this PR?

… FILTER for Expand reduction

Adds RewriteCountDistinctConditional optimizer rule that canonicalizes:
  COUNT(DISTINCT IF(cond, base, NULL))
  COUNT(DISTINCT CASE WHEN cond THEN base END)
into:
  COUNT(DISTINCT base) FILTER (WHERE cond)

This reduces the number of distinct groups seen by RewriteDistinctAggregates
from N (one per unique conditional expression) down to 1 (all share the same
base column), collapsing the Expand factor from Nx to 1x.

Gated by spark.sql.optimizer.rewriteCountDistinctConditional.enabled (default: false).
@xumingming xumingming force-pushed the count-distinct-filter-rewrite branch from 3bc29c1 to 028cf1b Compare May 19, 2026 09:10
@xumingming (Author)
@cloud-fan Can you help take a look at this PR?

@cloud-fan cloud-fan left a comment

Contributor


0 blocking, 2 non-blocking, 0 nits. A correct, well-tested, narrowly-scoped rule; the main point is structural.

Design / architecture (1)

  • RewriteCountDistinctConditional.scala:45: fold the canonicalization into RewriteDistinctAggregates instead of a separate rule ordered ahead of it in the same Once batch — removes a positional rule-ordering dependency and also eliminates a single-count Expand regression — see inline

Suggestions (1)

  • RewriteCountDistinctConditionalSuite.scala:139: assert(rewritten != optimized) doesn't verify the "single distinct group" collapse the test claims — see inline

Verification

Traced the COUNT(DISTINCT IF(cond, base, NULL)) (and CASE WHEN cond THEN base [ELSE NULL] END) -> COUNT(DISTINCT base) FILTER (WHERE cond) rewrite: since COUNT DISTINCT ignores NULLs, both forms count distinct non-null base over rows where cond is TRUE. Equivalent across cond=TRUE/FALSE/NULL, NULL base, duplicate base, empty input, and nondeterministic cond (evaluated once per row in both). The non-null-else, multi-branch-CASE, pre-existing-FILTER, non-distinct, non-Count, and multi-arg cases are all gated out (RewriteCountDistinctConditional.scala:49-54,78-90). The rule rewrites the aggregate in place (ae.copy, resultId preserved; Count stays LongType non-nullable), so no synthesized-operator / plan-integrity concern.
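The equivalence argument above can be checked against a small model of SQL three-valued logic. This is a sketch in plain Scala under stated assumptions, not Spark code: `Row`, `countDistinctIf`, and `countDistinctFilter` are hypothetical names, `None` stands for SQL NULL, and the nondeterministic-condition case is not modeled.

```scala
object CountDistinctEquivalence {
  // A row: cond is three-valued (None = NULL), base is nullable (None = NULL).
  final case class Row(cond: Option[Boolean], base: Option[Int])

  // COUNT(DISTINCT IF(cond, base, NULL)): IF yields base only when cond is TRUE,
  // and COUNT DISTINCT then ignores the NULLs.
  def countDistinctIf(rows: Seq[Row]): Long =
    rows.flatMap(r => if (r.cond.contains(true)) r.base else None).distinct.size.toLong

  // COUNT(DISTINCT base) FILTER (WHERE cond): keep rows where cond is TRUE,
  // then count the distinct non-NULL base values.
  def countDistinctFilter(rows: Seq[Row]): Long =
    rows.filter(_.cond.contains(true)).flatMap(_.base).distinct.size.toLong

  val rows: Seq[Row] = Seq(
    Row(Some(true), Some(1)),
    Row(Some(true), Some(1)),  // duplicate base
    Row(Some(true), None),     // NULL base
    Row(Some(false), Some(2)), // cond FALSE
    Row(None, Some(3)),        // cond NULL
    Row(Some(true), Some(4))
  )

  def main(args: Array[String]): Unit = {
    println(countDistinctIf(rows))     // 2
    println(countDistinctFilter(rows)) // 2
    println(countDistinctIf(Seq.empty) == countDistinctFilter(Seq.empty)) // true
  }
}
```

Both forms count the values 1 and 4 here, and both return 0 on empty input.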

PR description suggestions

  • Document the win condition and tradeoff: the rewrite reduces work only when multiple conditional distinct counts share a base column, and (as a separate pre-pass) can regress a single such count onto the Expand path — the rationale for the default-off gate.

    if (!SQLConf.get.rewriteCountDistinctConditionalEnabled) {
      return plan
    }
    plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {

Contributor


Consider folding this canonicalization into RewriteDistinctAggregates rather than shipping it as a separate rule that must run before it in the same Once batch.

Rule-ordering dependency. The canonicalization is meaningful only as a pre-pass feeding RewriteDistinctAggregates's group-by key (ExpressionSet of unfoldable children). As a standalone rule, "runs first" rests purely on batch position in Optimizer.scala, not on a structural guarantee. The natural home is right before the groupBy in RewriteDistinctAggregates.rewrite: normalize each distinct Count(IF/CASE(cond, base, NULL)) to Count(base) with filter = cond, and the existing grouping collapses them onto base.

It also fixes a single-count pessimization for free. Run as a separate pre-pass, this rule blindly converts a lone COUNT(DISTINCT IF(...)) — which is planned today with no Expand via planAggregateWithOneDistinct — into a FILTER form that then trips mustRewrite and forces the Expand-based path (~2x row fan-out + an extra aggregation stage). Integrated into rewrite(), the normalization sits behind the mayNeedtoRewrite gate: a lone count (size == 1, no filter) is never entered (rewrite returns the input unchanged), so there's no regression; and for >=2 distinct aggregates — where the rule was already going to expand — canonicalization only reduces the group count (Nx -> 1x) and can never introduce an expansion the baseline wouldn't already do.
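The difference between the two placements can be sketched with a toy gate. This is a simplified model in plain Scala, not the real Spark logic: `DistinctAgg`, `canonicalize`, and `entersExpandRewrite` are hypothetical names, and the gate is assumed to fire for two or more distinct aggregates or for any distinct aggregate carrying a FILTER, as the comment above describes.

```scala
object RewriteGateModel {
  // Hypothetical aggregate: COUNT(DISTINCT IF(ifCond, base, NULL)) when ifCond is set,
  // COUNT(DISTINCT base) FILTER (WHERE filter) when filter is set.
  final case class DistinctAgg(base: String, ifCond: Option[String], filter: Option[String])

  // Move the IF condition into the FILTER slot when no filter exists yet.
  def canonicalize(a: DistinctAgg): DistinctAgg =
    if (a.ifCond.isDefined && a.filter.isEmpty) a.copy(ifCond = None, filter = a.ifCond) else a

  // Assumed gate: the Expand-based rewrite is entered for two or more distinct
  // aggregates, or when any distinct aggregate has a FILTER clause.
  def entersExpandRewrite(aggs: Seq[DistinctAgg]): Boolean =
    aggs.size > 1 || aggs.exists(_.filter.isDefined)

  // Separate pre-pass: canonicalize first, then evaluate the gate.
  def expandWithPrePass(aggs: Seq[DistinctAgg]): Boolean =
    entersExpandRewrite(aggs.map(canonicalize))

  // Integrated: evaluate the gate on the original aggregates; canonicalization
  // would run only after the gate has already passed.
  def expandWhenIntegrated(aggs: Seq[DistinctAgg]): Boolean =
    entersExpandRewrite(aggs)

  val lone: Seq[DistinctAgg] = Seq(DistinctAgg("order_id", Some("dt >= '2026-04-16'"), None))
  val pair: Seq[DistinctAgg] = lone :+ DistinctAgg("order_id", Some("pay_status = 'paid'"), None)

  def main(args: Array[String]): Unit = {
    println(expandWithPrePass(lone))    // true: the pre-pass adds a FILTER and trips the gate
    println(expandWhenIntegrated(lone)) // false: the lone count keeps its existing plan
    println(expandWhenIntegrated(pair)) // true: already expanded in the baseline
  }
}
```

Under these assumptions only the lone count behaves differently between the two placements; with two or more distinct aggregates the gate fires either way.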

val rewritten = RewriteDistinctAggregates(optimized)
// Should be rewritten (not same as input) because there are now multiple
// distinct expressions with the same base column
assert(rewritten != optimized)

Contributor


This assertion only checks that some rewrite happened — it's true whether RewriteDistinctAggregates collapses the three counts into one distinct group or leaves them as three. It doesn't actually verify the "sees only 1 distinct group" collapse that the comment claims and that this PR exists to achieve. Consider asserting the resulting distinct-group / Expand projection count is 1 (e.g. count Expand output projections, or assert all three counts share a single group) so the test guards the collapse rather than just "changed".

@xumingming xumingming force-pushed the count-distinct-filter-rewrite branch from 65155e2 to 5dc1f60 Compare June 18, 2026 14:14
…istinctAggregates

The standalone RewriteCountDistinctConditional rule had two problems:
1. Its ordering guarantee rested purely on batch position in Optimizer.scala
2. It blindly rewrote a lone COUNT(DISTINCT IF(...)), forcing the Expand
   path and regressing performance vs the planAggregateWithOneDistinct fast path

Move the canonicalization logic into RewriteDistinctAggregates.rewrite() via a
new normalizeCountDistinctConditional helper. Since rewrite() is only called when
mayNeedtoRewrite returns true, a lone conditional distinct count never enters the
path and the regression is eliminated. The SQLConf flag and semantics are unchanged.

Delete RewriteCountDistinctConditional.scala and migrate all tests into
RewriteDistinctAggregatesSuite; rename the integration suite to
RewriteDistinctAggregatesConditionalQuerySuite.
@xumingming xumingming force-pushed the count-distinct-filter-rewrite branch from 5dc1f60 to 409c5e8 Compare June 18, 2026 16:22