14 lines
749 B
Diff
14 lines
749 B
Diff
--- a/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
|
|
+++ b/usr/local/lib/python3.12/dist-packages/vllm/v1/attention/backends/mla/triton_mla.py
|
|
@@ -135,7 +135,9 @@
|
|
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
|
|
|
|
# For batch invariance, use only 1 split to ensure deterministic reduction
|
|
- num_kv_splits = 1 if vllm_is_batch_invariant() else 4
|
|
+ # Dynamic splits: longest decode seq_len // 1500 (~1.5K tokens/split), then clamped to [32, 128]
|
|
+ max_seq_len = int(attn_metadata.decode.seq_lens.max().item())
|
|
+ num_kv_splits = 1 if vllm_is_batch_invariant() else max(32, min(128, max_seq_len // 1500))
|
|
|
|
# TODO(lucas) Allocate ahead of time
|
|
attn_logits = torch.empty(
|