--- a/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
@@ -135,7 +135,13 @@
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)

         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if vllm_is_batch_invariant() else 4
+        if vllm_is_batch_invariant():
+            num_kv_splits = 1
+        else:
+            # Dynamic splits: ~1.5K tokens per split, clamped to [32, 128].
+            # NOTE(review): the floor of 32 splits also applies to very short
+            # sequences (< 1.5K tokens) — confirm that is intended.
+            # .max().item() forces a host-device sync, so only pay it on the
+            # branch that actually uses the value.
+            max_seq_len = int(attn_metadata.decode.seq_lens.max().item())
+            num_kv_splits = max(32, min(128, max_seq_len // 1500))

         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(