spark-vllm-docker/mods/fix-qwen35-tp4-marlin/qwen3_5.patch
sonusflow 006734910c Add Qwen3.5-397B INT4-AutoRound TP=4 recipe and Marlin fix
Production-tested recipe for running Qwen3.5-397B-A17B with INT4 AutoRound
quantization across 4 DGX Spark nodes using tensor parallelism.

Performance (4× DGX Spark, driver 580.126.09):
- Single user: 37 tok/s
- 4 concurrent users: ~26 tok/s per user, ~103 tok/s aggregate
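
The per-user and aggregate figures can be cross-checked with a little arithmetic. This is a small sketch; the helper name `scaling_summary` is illustrative and not part of the recipe. With 4 concurrent users, ~26 tok/s per user gives ~104 tok/s, which is consistent with the reported ~103 given rounding. Batching therefore yields about 2.8x the single-user throughput, while each user keeps about 70% of the single-user speed.

```python
# Throughput figures reported by the recipe (4x DGX Spark, TP=4).
SINGLE_USER_TOK_S = 37.0      # one request in flight
PER_USER_TOK_S_AT_4 = 26.0    # approximate, per user with 4 concurrent requests
AGGREGATE_TOK_S_AT_4 = 103.0  # approximate, summed over the 4 users


def scaling_summary(single: float, per_user: float, aggregate: float, users: int) -> dict:
    """Compare concurrent throughput against the single-user baseline."""
    return {
        # Total throughput relative to one user: what batching buys overall.
        "aggregate_speedup": aggregate / single,
        # Per-user speed relative to running alone: how much each user slows down.
        "per_user_retention": per_user / single,
        # Consistency check: per-user rate times user count should be near the aggregate.
        "implied_aggregate": per_user * users,
    }


s = scaling_summary(SINGLE_USER_TOK_S, PER_USER_TOK_S_AT_4, AGGREGATE_TOK_S_AT_4, 4)
print(f"aggregate speedup:  {s['aggregate_speedup']:.2f}x")         # 2.78x
print(f"per-user retention: {s['per_user_retention']:.0%}")         # 70%
print(f"implied aggregate:  {s['implied_aggregate']:.0f} tok/s")    # 104 tok/s
```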

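The Marlin TP fix works around Marlin's MIN_THREAD_N=64 constraint, which
breaks the fused in_proj_ba layer at TP=4: its per-rank output size is
128/4 = 32, below the minimum of 64. Solution: replace the fused layer with
separate ReplicatedLinear layers for the B and A projections. Each rank
computes the full output and slices out its own heads. Applied via diff
patches.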
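
The shape problem can be illustrated with a small check. This is a hypothetical sketch, not vLLM's actual validation code. It assumes the constraint is "each rank's output size must be a multiple of MIN_THREAD_N". It also assumes 64 value heads, which follows from the fused size 128 being B and A concatenated (2 x 64).

```python
# Hypothetical model of the Marlin N-dimension constraint described above.
# NOT vLLM's real check; it only encodes "per-rank output % 64 == 0".
MIN_THREAD_N = 64


def marlin_n_ok(output_size_per_rank: int, min_thread_n: int = MIN_THREAD_N) -> bool:
    """Return True if a per-rank output dimension satisfies the assumed constraint."""
    return output_size_per_rank > 0 and output_size_per_rank % min_thread_n == 0


NUM_V_HEADS = 64  # assumption: implied by fused output_size 128 = 2 * 64
TP_SIZE = 4

# Fused in_proj_ba: B and A outputs concatenated, then split across TP ranks.
fused_per_rank = (2 * NUM_V_HEADS) // TP_SIZE  # 32
# ReplicatedLinear: no TP split, so each of B and A keeps its full width per rank.
replicated_per_rank = NUM_V_HEADS  # 64

print(fused_per_rank, marlin_n_ok(fused_per_rank))            # 32 False
print(replicated_per_rank, marlin_n_ok(replicated_per_rank))  # 64 True
```

This only shows why the replicated layout clears the size check. Whether the B/A layers are actually executed through Marlin in this checkpoint is not stated in the commit.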

Key config:
- VLLM_MARLIN_USE_ATOMIC_ADD=1 (required for Marlin correctness)
- FP8 KV cache, prefix caching enabled
- gpu_memory_utilization 0.78 (safety margin for GB10 unified memory)
- CUDA graphs enabled (vLLM default; requires driver 580.x)
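
The key config maps onto a `vllm serve` invocation roughly as follows. This is a hedged sketch. `build_launch` and the model path are hypothetical placeholders. The commit does not show the recipe's real launch command or any multi-node (e.g. Ray) flags. Only flags corresponding to the list above are included.

```python
import shlex

# Placeholder: the recipe's actual model path is not given in the commit.
MODEL = "/path/to/Qwen3.5-397B-A17B-INT4-AutoRound"


def build_launch(model: str) -> tuple[dict[str, str], list[str]]:
    """Hypothetical helper: environment and argv for the settings listed above."""
    env = {
        # Required for Marlin correctness, per the recipe.
        "VLLM_MARLIN_USE_ATOMIC_ADD": "1",
    }
    argv = [
        "vllm", "serve", model,
        "--tensor-parallel-size", "4",       # one TP rank per DGX Spark node
        "--kv-cache-dtype", "fp8",           # FP8 KV cache
        "--enable-prefix-caching",           # prefix caching enabled
        "--gpu-memory-utilization", "0.78",  # headroom on unified memory
        # CUDA graphs are the default, so --enforce-eager is deliberately absent.
    ]
    return env, argv


env, argv = build_launch(MODEL)
print(" ".join(f"{k}={v}" for k, v in env.items()), shlex.join(argv))
```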

Note: Driver 590.x deadlocks during CUDA graph capture on GB10 unified memory.
Stay on driver 580.126.09.
2026-03-09 21:30:28 +00:00

--- qwen3_5.py.orig 2026-03-03 00:00:00.000000000 +0000
+++ qwen3_5.py 2026-03-03 00:00:00.000000000 +0000
@@ -166,11 +166,13 @@
         z_size = self.value_dim // self.tp_size
         mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
         z = z.reshape(z.size(0), -1, self.head_v_dim)
-        ba, _ = self.in_proj_ba(hidden_states)
-        b, a = ba.chunk(2, dim=-1)
-
-        b = b.contiguous()
-        a = a.contiguous()
+        # Replicated B/A projections — full output, sliced to local TP partition
+        b_full, _ = self.in_proj_b(hidden_states)
+        a_full, _ = self.in_proj_a(hidden_states)
+        _ba_chunk = self.num_v_heads // self.tp_size
+        _ba_start = self.tp_rank * _ba_chunk
+        b = b_full[:, _ba_start:_ba_start+_ba_chunk].contiguous()
+        a = a_full[:, _ba_start:_ba_start+_ba_chunk].contiguous()
 
         # ============================================================
         # Part 2: Core Attention (Custom Op)
@@ -374,8 +376,6 @@
             # GDN
             ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)),
             ("in_proj_qkvz", "in_proj_z", 3),
-            ("in_proj_ba", "in_proj_b", 0),
-            ("in_proj_ba", "in_proj_a", 1),
         ]
 
         params_dict = dict(self.named_parameters())
@@ -530,7 +530,6 @@
         "gate_up_proj": ["gate_proj", "up_proj"],
         # GDN fused projections.
         "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
-        "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -630,7 +629,6 @@
 class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid):
     packed_modules_mapping = Qwen3VLForConditionalGeneration.packed_modules_mapping | {
         "in_proj_qkvz": ["in_proj_qkv", "in_proj_z"],
-        "in_proj_ba": ["in_proj_b", "in_proj_a"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
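
The first hunk relies on one equivalence: computing the full replicated projection and slicing `[_ba_start, _ba_start + _ba_chunk)` yields the same values a column-parallel shard of that projection would. The sketch below checks this in pure Python, without torch, for B; A is identical. The sizes are small illustrative values, not the model's real dimensions.

```python
import random

# Illustrative sizes only (not Qwen3.5's real dimensions).
HIDDEN, NUM_V_HEADS, TP_SIZE, TOKENS = 8, 8, 4, 3


def matmul_t(x, w):
    """x: [tokens][hidden], w: [out][hidden] -> x @ w.T as [tokens][out]."""
    return [[sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in w] for row in x]


rng = random.Random(0)
x = [[rng.uniform(-1, 1) for _ in range(HIDDEN)] for _ in range(TOKENS)]
w_b = [[rng.uniform(-1, 1) for _ in range(HIDDEN)] for _ in range(NUM_V_HEADS)]


def replicated_then_slice(tp_rank):
    """What the patch does: full output on every rank, then keep local heads."""
    b_full = matmul_t(x, w_b)     # like self.in_proj_b(hidden_states)
    chunk = NUM_V_HEADS // TP_SIZE  # _ba_chunk
    start = tp_rank * chunk         # _ba_start
    return [row[start:start + chunk] for row in b_full]


def column_parallel_shard(tp_rank):
    """Reference: each rank holds only its rows of the weight matrix."""
    chunk = NUM_V_HEADS // TP_SIZE
    local_w = w_b[tp_rank * chunk:(tp_rank + 1) * chunk]
    return matmul_t(x, local_w)


for rank in range(TP_SIZE):
    assert replicated_then_slice(rank) == column_parallel_shard(rank)
print("replicated + slice matches the column-parallel shard on every rank")
```

The trade-off is that every rank stores and computes the full B and A projections. These layers are only `num_v_heads` wide, so the redundancy is small.

The remaining hunks drop `in_proj_ba` from `stacked_params_mapping` and `packed_modules_mapping`. Presumably this lets checkpoint tensors named `in_proj_b` and `in_proj_a` load directly into the new separate modules. The `__init__` changes that define `self.in_proj_b`, `self.in_proj_a` and `self.tp_rank` are not part of this patch.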