Skip to content

Fix: Plumb vmem limit and bkv_compute_in to the custom kernel#416

Open
eltsai wants to merge 1 commit into
mainfrom
plumb_vmem_limit
Open

Fix: Plumb vmem limit and bkv_compute_in to the custom kernel#416
eltsai wants to merge 1 commit into
mainfrom
plumb_vmem_limit

Conversation

@eltsai

@eltsai eltsai commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Right now the vmem_limit_bytes and block_kv_compute_in from the argument was never passed into the kernel.

This PR plumbs vmem_limit_bytes from the flash_block_sizes config down to the Mosaic
CompilerParams so the custom splash kernel can run under a raised scoped-VMEM
limit.

  • custom_splash_attention.py: vmem_limit_bytes param through both forwards ->
    pltpu.CompilerParams, make_splash_mha, tpu_custom_attention (+ flash_block_sizes
    read), make_custom_splash_sdpa (+ kwargs read).
  • attention_flax.py: read vmem_limit_bytes in wrap_ulysses_attention, pass to
    make_splash_mha.
  • max_utils.py: CustomFlashBlockSizes frozen carrier so vmem_limit_bytes (and
    block_kv_compute_in / heads_per_tile) survive config -> kernel instead of being
    dropped by the JAX BlockSizes dataclass.

By raising the vmem to the 64 MB limit, our grid search finds a better config:

sweep1 sweep2
export BQ=8448
export BKV=1024
export BKV_COMPUTE=1024
export BKV_COMPUTE_IN=1024
export HEADS_PER_TILE=1

E2E result:

==================================================
  TIMING SUMMARY
==================================================
  Load (checkpoint):     141.2s
  Compile:               183.3s
  Inference:             126.9s
  ────────────────────────────────────────
  Conditioning:            2.2s
    - VAE Encode:          0.0s
  Denoise Total:         122.6s
  VAE Decode:              2.1s
    - TPU Compute:         1.7s
    - Host Formatting:     0.4s
==================================================

Comparing to previous config (BQ=4864), the denoising steps is 3.7% faster (122.6 sec vs 127.3 sec).

@eltsai eltsai requested a review from entrpn as a code owner June 9, 2026 22:18
@github-actions

github-actions Bot commented Jun 9, 2026

Copy link
Copy Markdown

@eltsai eltsai requested a review from Perseus14 June 9, 2026 22:18
@Perseus14

Copy link
Copy Markdown
Collaborator

Pretty great find @eltsai, awesome work!

Few questions

  • Could you clarify on the usage of vmem_limit? Do we add this param in the config file under flash_block_sizes, if so can we update the configs as well (if it doesn't affect other attention mechanisms)?
  • Does setting the XLA flag for vmem_limit_bytes (eg: xla_tpu_scoped_vmem_limit_kib) affect this value?
  • Can we update the PR title and description as block_kv_compute_in seems to be already passed into the kernel?

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 I'm sorry @Perseus14, but I was unable to process your request. Please see the logs for more details.

@github-actions

Copy link
Copy Markdown

🤖 Hi @Perseus14, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

📋 Review Summary

This pull request correctly plumbs the vmem_limit_bytes configuration option and other block-level tuning parameters (such as block_kv_compute_in and heads_per_tile) all the way down to the custom TPU Pallas splash attention kernel. This allows executing under raised scoped VMEM limits, unlocking optimal layout sweeping and a ~3.7% performance improvement in denoising.

🔍 General Feedback

  • Preservation of Configuration Parameters: Introducing the hashable, frozen CustomFlashBlockSizes carrier is an excellent design decision to prevent JAX from silently dropping these custom properties.
  • Robust Config Extraction: The extraction of custom properties in attention_flax.py correctly handles both raw dicts and custom objects, which provides excellent resilience against different configuration entry points.
  • Consistency Check: Please verify that tpu_custom_attention (often used by TorchAX or standalone tests) uses the same object-safe extraction logic to prevent AttributeError crashes, as detailed in the inline comment.

block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 With the introduction of the frozen dataclass CustomFlashBlockSizes in max_utils.py to preserve custom properties across JAX boundaries, passing flash_block_sizes to tpu_custom_attention (e.g., via the TorchAX SDPA wrapper path in make_custom_splash_sdpa) will result in a runtime crash. Since flash_block_sizes is no longer guaranteed to be a dictionary, calling .get() on a CustomFlashBlockSizes instance will raise an AttributeError.

We should safely extract values by checking if flash_block_sizes is a dictionary or an object, similar to how it is done in attention_flax.py.

The entire block starting from line 597 should be refactored as follows:

  if flash_block_sizes is not None:
    if isinstance(flash_block_sizes, dict):
      block_q = flash_block_sizes.get("block_q", block_q)
      block_kv = flash_block_sizes.get("block_kv", block_kv)
      block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
      block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
      heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
      vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
    else:
      block_q = getattr(flash_block_sizes, "block_q", block_q)
      block_kv = getattr(flash_block_sizes, "block_kv", block_kv)
      block_kv_compute = getattr(flash_block_sizes, "block_kv_compute", block_kv_compute)
      block_kv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", block_kv_compute_in)
      heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
      vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants