Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/maxdiffusion/kernels/custom_splash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ def _splash_attention_forward(
kv_seq_len: int | None = None,
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
):
num_q_heads, padded_q_seq_len, head_dim_qk = q.shape
head_dim_v = v.shape[-1]
Expand Down Expand Up @@ -429,6 +430,7 @@ def v_index_map(h, i, j, *_):
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
disable_bounds_checks=True,
skip_device_barrier=True,
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=out_shapes,
)(q, k, v)
Expand All @@ -446,6 +448,7 @@ def _splash_attention_forward_mhpt(
kv_seq_len: int | None = None,
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
):
num_q_heads, padded_q_seq_len, head_dim_qk = q.shape
head_dim_v = v.shape[-1]
Expand Down Expand Up @@ -518,6 +521,7 @@ def out_index_map(h, i, j, *_):
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
disable_bounds_checks=True,
skip_device_barrier=True,
vmem_limit_bytes=vmem_limit_bytes,
),
out_shape=out_shapes,
)(q, k, v)
Expand All @@ -532,6 +536,7 @@ def make_splash_mha(
heads_per_tile: int = 1,
use_base2_exp: bool = True,
use_experimental_scheduler: bool = False,
vmem_limit_bytes: int | None = None,
):
def _splash_attention(q, k, v):
if heads_per_tile > 1:
Expand All @@ -546,6 +551,7 @@ def _splash_attention(q, k, v):
kv_seq_len=orig_kv_seq_len,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
)
return _splash_attention_forward(
q,
Expand All @@ -557,6 +563,7 @@ def _splash_attention(q, k, v):
kv_seq_len=orig_kv_seq_len,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
)

return _splash_attention
Expand All @@ -581,6 +588,7 @@ def tpu_custom_attention(
heads_per_tile=None,
use_base2_exp=True,
use_experimental_scheduler=False,
vmem_limit_bytes=None,
flash_block_sizes=None,
):
_LOG2_E = 1.44269504
Expand All @@ -592,6 +600,7 @@ def tpu_custom_attention(
block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 With the introduction of the frozen dataclass CustomFlashBlockSizes in max_utils.py to preserve custom properties across JAX boundaries, passing flash_block_sizes to tpu_custom_attention (e.g., via the TorchAX SDPA wrapper path in make_custom_splash_sdpa) will result in a runtime crash. Since flash_block_sizes is no longer guaranteed to be a dictionary, calling .get() on a CustomFlashBlockSizes instance will raise an AttributeError.

We should safely extract values by checking if flash_block_sizes is a dictionary or an object, similar to how it is done in attention_flax.py.

The entire block starting from line 597 should be refactored as follows:

  if flash_block_sizes is not None:
    if isinstance(flash_block_sizes, dict):
      block_q = flash_block_sizes.get("block_q", block_q)
      block_kv = flash_block_sizes.get("block_kv", block_kv)
      block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute)
      block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in)
      heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
      vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
    else:
      block_q = getattr(flash_block_sizes, "block_q", block_q)
      block_kv = getattr(flash_block_sizes, "block_kv", block_kv)
      block_kv_compute = getattr(flash_block_sizes, "block_kv_compute", block_kv_compute)
      block_kv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", block_kv_compute_in)
      heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
      vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes)

block_q = block_q if block_q is not None else DEFAULT_BQSIZE
block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE
Expand Down Expand Up @@ -639,6 +648,7 @@ def _kernel_3d(q_3d, k_3d, v_3d):
heads_per_tile=heads_per_tile,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
)
out = splash_kernel(
q_3d_padded.astype(jnp.bfloat16),
Expand Down Expand Up @@ -706,6 +716,7 @@ def make_custom_splash_sdpa(mesh, env, **kwargs):
use_k_smooth = kwargs.get("use_k_smooth", True)
use_base2_exp = kwargs.get("use_base2_exp", True)
use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False)
vmem_limit_bytes = kwargs.get("vmem_limit_bytes", None)

def _simple_attention(q, k, v, scale=None):
s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1])
Expand Down Expand Up @@ -747,6 +758,7 @@ def _sdpa(
heads_per_tile=hpt,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
flash_block_sizes=flash_block_sizes,
)
return env.j2t_iso(result)
Expand Down
33 changes: 33 additions & 0 deletions src/maxdiffusion/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# pylint: disable=bare-except, consider-using-generator
""" Common Max Utils needed by multiple modules"""
import dataclasses
import functools
from functools import partial, reduce
from contextlib import nullcontext
Expand Down Expand Up @@ -612,12 +613,44 @@ def value_or_none(flash_block_sizes, key):
return None


@dataclasses.dataclass(frozen=True)
class CustomFlashBlockSizes:
"""Hashable carrier for the custom splash kernel's block sizes.

The JAX `splash_attention_kernel.BlockSizes` is frozen + slotted and only has
fields for block_q/block_kv/block_kv_compute — it silently drops
block_kv_compute_in, heads_per_tile, and vmem_limit_bytes, which the custom
kernel needs. A plain dict would carry them but is unhashable (it ends up in
nnx's static graphdef, which jit requires to be hashable). This frozen
dataclass is hashable and is read via getattr in wrap_ulysses_attention.
"""

block_q: int | None = None
block_kv: int | None = None
block_kv_compute: int | None = None
block_kv_compute_in: int | None = None
heads_per_tile: int | None = None
vmem_limit_bytes: int | None = None


def get_flash_block_sizes(config):
"""Create custom flash attention BlockSizes."""
flash_block_sizes = None
if len(config.flash_block_sizes.keys()) > 0:
attention_is_tokamax = "tokamax" in config.attention
user_block_sizes: Dict[str, int] = config.flash_block_sizes
# The custom splash kernel reads flash_block_sizes via getattr and needs
# fields the JAX BlockSizes dataclass cannot hold. Return a frozen, hashable
# carrier so they survive the trip to wrap_ulysses_attention.
if "custom" in config.attention:
return CustomFlashBlockSizes(
block_q=user_block_sizes.get("block_q"),
block_kv=user_block_sizes.get("block_kv"),
block_kv_compute=user_block_sizes.get("block_kv_compute"),
block_kv_compute_in=user_block_sizes.get("block_kv_compute_in"),
heads_per_tile=user_block_sizes.get("heads_per_tile"),
vmem_limit_bytes=user_block_sizes.get("vmem_limit_bytes"),
)
if attention_is_tokamax:
max_logging.log(
"Tokamax kernel specified, Note: Tokamax only supports fused backward kernel."
Expand Down
4 changes: 4 additions & 0 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def wrap_ulysses_attention(query, key, value):
bkv_compute = 1024
bkv_compute_in = 1024
heads_per_tile = 1
vmem_limit_bytes = None

if flash_block_sizes is not None:
if isinstance(flash_block_sizes, dict):
Expand All @@ -589,12 +590,14 @@ def wrap_ulysses_attention(query, key, value):
bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute)
bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in)
heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile)
vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes)
else:
bq = getattr(flash_block_sizes, "block_q", bq)
bkv = getattr(flash_block_sizes, "block_kv", bkv)
bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute)
bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in)
heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes)

if use_base2_exp:
query = query * LOG2E
Expand All @@ -613,6 +616,7 @@ def wrap_ulysses_attention(query, key, value):
heads_per_tile=heads_per_tile,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
vmem_limit_bytes=vmem_limit_bytes,
)

vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
Expand Down
Loading