Production-tested recipe for running Qwen3.5-397B-A17B with INT4 AutoRound quantization across 4 DGX Spark nodes using tensor parallelism.

Performance (4× DGX Spark, driver 580.126.09):
- Single user: 37 tok/s
- 4 concurrent: ~26 tok/s per user, ~103 tok/s aggregate

The Marlin TP fix resolves the MIN_THREAD_N=64 constraint that breaks the in_proj_ba layers at TP=4 (output_size = 128/4 = 32 < 64). Solution: ReplicatedLinear for the B/A projections, applied via diff patches.

Key config (see the launch sketch below):
- VLLM_MARLIN_USE_ATOMIC_ADD=1 (required for Marlin correctness)
- KV cache FP8, prefix caching enabled
- gpu_memory_utilization 0.78 (UMA safe margin)
- CUDAGraphs enabled (default, requires driver 580.x)

Note: driver 590.x has a CUDAGraph capture deadlock on GB10 unified memory. Stay on driver 580.126.09.
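For reference, a minimal sketch of where these settings plug into vLLM's offline Python API. This is not the production launch command: the model path, prompt, and max_model_len are placeholders, and the actual recipe serves the model across 4 DGX Spark nodes, which additionally requires a Ray cluster and `vllm serve` rather than the single-process `LLM` class.

```python
# Minimal sketch, assuming a local AutoRound INT4 checkpoint path (placeholder).
import os

# Required for Marlin correctness per the recipe; set before vLLM initializes.
os.environ["VLLM_MARLIN_USE_ATOMIC_ADD"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen3.5-397B-A17B-INT4-AutoRound",  # placeholder checkpoint path/id
    tensor_parallel_size=4,                    # TP=4 across the four GPUs
    kv_cache_dtype="fp8",                      # KV cache FP8
    enable_prefix_caching=True,                # prefix caching enabled
    gpu_memory_utilization=0.78,               # UMA safe margin on GB10
    max_model_len=32768,                       # placeholder; size to the UMA budget
)

out = llm.generate(["Hello"], SamplingParams(max_tokens=64))
print(out[0].outputs[0].text)
```

With `vllm serve`, the same knobs map to `--tensor-parallel-size 4`, `--kv-cache-dtype fp8`, `--enable-prefix-caching`, and `--gpu-memory-utilization 0.78`; CUDAGraphs stay at their default (enabled).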
```diff
--- qwen3_next.py.orig	2026-03-03 00:00:00.000000000 +0000
+++ qwen3_next.py	2026-03-03 00:00:00.000000000 +0000
@@ -411,15 +411,22 @@
             quant_config=quant_config,
             prefix=f"{prefix}.in_proj_qkvz",
         )
-        # ba_proj doesn't support blockwise fp8 quantization.
-        # # in_proj_ba is defined as MergedColumnParallelLinear for
-        # compatibility with Qwen3_5.
-        self.in_proj_ba = MergedColumnParallelLinear(
+        # ba_proj: Use ReplicatedLinear to avoid Marlin TP split constraint
+        # (num_v_heads=64 is too small for TP=4 Marlin min_thread_n=64).
+        # Each rank loads full weights and slices in forward.
+        self.in_proj_b = ReplicatedLinear(
             input_size=self.hidden_size,
-            output_sizes=[self.num_v_heads] * 2,
+            output_size=self.num_v_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.in_proj_ba",
+            prefix=f"{prefix}.in_proj_b",
+        )
+        self.in_proj_a = ReplicatedLinear(
+            input_size=self.hidden_size,
+            output_size=self.num_v_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.in_proj_a",
         )
 
         query_key_settings = (self.key_dim, 0, False)
@@ -584,7 +591,15 @@
         # Part 1: Input Projection
         # ============================================================
         projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
-        projected_states_ba, _ = self.in_proj_ba(hidden_states)
+        # Replicated B/A projections — full output, sliced to local TP partition
+        b_full, _ = self.in_proj_b(hidden_states)
+        a_full, _ = self.in_proj_a(hidden_states)
+        _ba_chunk = self.num_v_heads // self.tp_size
+        _ba_start = self.tp_rank * _ba_chunk
+        projected_states_ba = torch.cat([
+            b_full[:, _ba_start:_ba_start+_ba_chunk],
+            a_full[:, _ba_start:_ba_start+_ba_chunk],
+        ], dim=-1)
         query, key, value, z, b, a = self.fix_query_key_value_ordering(
             projected_states_qkvz, projected_states_ba
         )
@@ -1326,7 +1341,6 @@
         ],
         "gate_up_proj": ["gate_proj", "up_proj"],
         "in_proj_qkvz": ["in_proj_qkvz"],
-        "in_proj_ba": ["in_proj_ba"],
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
```
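As a sanity check on the forward-path change, here is a small standalone sketch in plain PyTorch (not vLLM code; random weights, with the 64-head / TP=4 numbers taken from the comments in the patch). It shows that slicing the replicated B/A outputs to the local TP partition reproduces the per-rank layout the original MergedColumnParallelLinear would have produced, so `fix_query_key_value_ordering` still sees the expected inputs.

```python
# Standalone sanity sketch (assumed shapes, random weights; not vLLM code).
# MergedColumnParallelLinear([num_v_heads] * 2) gives rank r the rows
# [r*chunk:(r+1)*chunk] of B followed by the same slice of A. The patched path
# computes the full replicated B/A outputs and slices them the same way.
import torch

num_v_heads, tp_size, hidden_size = 64, 4, 32
chunk = num_v_heads // tp_size

x = torch.randn(3, hidden_size)               # (tokens, hidden) stand-in
w_b = torch.randn(num_v_heads, hidden_size)   # full B projection weight
w_a = torch.randn(num_v_heads, hidden_size)   # full A projection weight

b_full = x @ w_b.T                            # replicated in_proj_b output
a_full = x @ w_a.T                            # replicated in_proj_a output

for rank in range(tp_size):
    # What the original merged column-parallel layer would compute on this rank.
    w_rank = torch.cat([w_b[rank * chunk:(rank + 1) * chunk],
                        w_a[rank * chunk:(rank + 1) * chunk]])
    expected = x @ w_rank.T

    # The patched slicing logic from the forward pass.
    start = rank * chunk
    projected_ba = torch.cat([b_full[:, start:start + chunk],
                              a_full[:, start:start + chunk]], dim=-1)

    torch.testing.assert_close(projected_ba, expected)

print("per-rank B/A slices match the merged column-parallel layout")
```

The trade-off of the ReplicatedLinear approach is that every rank stores and computes the full B/A projection, but these layers are tiny (output dimension num_v_heads per projection), so the duplication is negligible next to avoiding the Marlin MIN_THREAD_N=64 failure.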