Files
2026-01-29 18:18:00 -08:00

14 lines
749 B
Diff

--- a/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
@@ -135,7 +135,9 @@
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
# For batch invariance, use only 1 split to ensure deterministic reduction
- num_kv_splits = 1 if vllm_is_batch_invariant() else 4
+ # Dynamic splits: target ~1.5K tokens per split (max_seq_len // 1500), clamped to [32, 128]; the 32-split floor applies whenever max_seq_len < 48000
+ max_seq_len = int(attn_metadata.decode.seq_lens.max().item())
+ num_kv_splits = 1 if vllm_is_batch_invariant() else max(32, min(128, max_seq_len // 1500))
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(