diff --git a/src/maxdiffusion/kernels/custom_splash_attention.py b/src/maxdiffusion/kernels/custom_splash_attention.py index 5cf8cb6bd..fb50a51a9 100644 --- a/src/maxdiffusion/kernels/custom_splash_attention.py +++ b/src/maxdiffusion/kernels/custom_splash_attention.py @@ -359,6 +359,7 @@ def _splash_attention_forward( kv_seq_len: int | None = None, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + vmem_limit_bytes: int | None = None, ): num_q_heads, padded_q_seq_len, head_dim_qk = q.shape head_dim_v = v.shape[-1] @@ -429,6 +430,7 @@ def v_index_map(h, i, j, *_): flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler}, disable_bounds_checks=True, skip_device_barrier=True, + vmem_limit_bytes=vmem_limit_bytes, ), out_shape=out_shapes, )(q, k, v) @@ -446,6 +448,7 @@ def _splash_attention_forward_mhpt( kv_seq_len: int | None = None, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + vmem_limit_bytes: int | None = None, ): num_q_heads, padded_q_seq_len, head_dim_qk = q.shape head_dim_v = v.shape[-1] @@ -518,6 +521,7 @@ def out_index_map(h, i, j, *_): flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler}, disable_bounds_checks=True, skip_device_barrier=True, + vmem_limit_bytes=vmem_limit_bytes, ), out_shape=out_shapes, )(q, k, v) @@ -532,6 +536,7 @@ def make_splash_mha( heads_per_tile: int = 1, use_base2_exp: bool = True, use_experimental_scheduler: bool = False, + vmem_limit_bytes: int | None = None, ): def _splash_attention(q, k, v): if heads_per_tile > 1: @@ -546,6 +551,7 @@ def _splash_attention(q, k, v): kv_seq_len=orig_kv_seq_len, use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, ) return _splash_attention_forward( q, @@ -557,6 +563,7 @@ def _splash_attention(q, k, v): kv_seq_len=orig_kv_seq_len, use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, ) return _splash_attention @@ -581,6 +588,7 @@ def tpu_custom_attention( heads_per_tile=None, use_base2_exp=True, use_experimental_scheduler=False, + vmem_limit_bytes=None, flash_block_sizes=None, ): _LOG2_E = 1.44269504 @@ -592,6 +600,7 @@ def tpu_custom_attention( block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute) block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in) heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) + vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes) block_q = block_q if block_q is not None else DEFAULT_BQSIZE block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE @@ -639,6 +648,7 @@ def _kernel_3d(q_3d, k_3d, v_3d): heads_per_tile=heads_per_tile, use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, ) out = splash_kernel( q_3d_padded.astype(jnp.bfloat16), @@ -706,6 +716,7 @@ def make_custom_splash_sdpa(mesh, env, **kwargs): use_k_smooth = kwargs.get("use_k_smooth", True) use_base2_exp = kwargs.get("use_base2_exp", True) use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False) + vmem_limit_bytes = kwargs.get("vmem_limit_bytes", None) def _simple_attention(q, k, v, scale=None): s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1]) @@ -747,6 +758,7 @@ def _sdpa( heads_per_tile=hpt, use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, flash_block_sizes=flash_block_sizes, ) return env.j2t_iso(result) diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 79de06f6b..8f1df7b48 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -17,6 +17,7 @@ # pylint: disable=bare-except, consider-using-generator """ Common Max Utils needed by multiple modules""" +import dataclasses import functools from functools import partial, reduce from contextlib import nullcontext @@ -612,12 +613,44 @@ def value_or_none(flash_block_sizes, key): return None +@dataclasses.dataclass(frozen=True) +class CustomFlashBlockSizes: + """Hashable carrier for the custom splash kernel's block sizes. + + The JAX `splash_attention_kernel.BlockSizes` is frozen + slotted and only has + fields for block_q/block_kv/block_kv_compute — it silently drops + block_kv_compute_in, heads_per_tile, and vmem_limit_bytes, which the custom + kernel needs. A plain dict would carry them but is unhashable (it ends up in + nnx's static graphdef, which jit requires to be hashable). This frozen + dataclass is hashable and is read via getattr in wrap_ulysses_attention. + """ + + block_q: int | None = None + block_kv: int | None = None + block_kv_compute: int | None = None + block_kv_compute_in: int | None = None + heads_per_tile: int | None = None + vmem_limit_bytes: int | None = None + + def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: attention_is_tokamax = "tokamax" in config.attention user_block_sizes: Dict[str, int] = config.flash_block_sizes + # The custom splash kernel reads flash_block_sizes via getattr and needs + # fields the JAX BlockSizes dataclass cannot hold. Return a frozen, hashable + # carrier so they survive the trip to wrap_ulysses_attention. + if "custom" in config.attention: + return CustomFlashBlockSizes( + block_q=user_block_sizes.get("block_q"), + block_kv=user_block_sizes.get("block_kv"), + block_kv_compute=user_block_sizes.get("block_kv_compute"), + block_kv_compute_in=user_block_sizes.get("block_kv_compute_in"), + heads_per_tile=user_block_sizes.get("heads_per_tile"), + vmem_limit_bytes=user_block_sizes.get("vmem_limit_bytes"), + ) if attention_is_tokamax: max_logging.log( "Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index e82b2b0a9..c8d94bef4 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -581,6 +581,7 @@ def wrap_ulysses_attention(query, key, value): bkv_compute = 1024 bkv_compute_in = 1024 heads_per_tile = 1 + vmem_limit_bytes = None if flash_block_sizes is not None: if isinstance(flash_block_sizes, dict): @@ -589,12 +590,14 @@ def wrap_ulysses_attention(query, key, value): bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute) bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in) heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) + vmem_limit_bytes = flash_block_sizes.get("vmem_limit_bytes", vmem_limit_bytes) else: bq = getattr(flash_block_sizes, "block_q", bq) bkv = getattr(flash_block_sizes, "block_kv", bkv) bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) + vmem_limit_bytes = getattr(flash_block_sizes, "vmem_limit_bytes", vmem_limit_bytes) if use_base2_exp: query = query * LOG2E @@ -613,6 +616,7 @@ def wrap_ulysses_attention(query, key, value): heads_per_tile=heads_per_tile, use_base2_exp=use_base2_exp, use_experimental_scheduler=use_experimental_scheduler, + vmem_limit_bytes=vmem_limit_bytes, ) vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))