# Patch against qwen3_next.py (three hunks). Lines before the `---` header
# are explanatory preamble only.
#
# Hunk 1 (@@ -411,15 +411,22 @@):
#   Removes the single `self.in_proj_ba` MergedColumnParallelLinear
#   (output_sizes=[self.num_v_heads] * 2) and adds two ReplicatedLinear
#   layers, `self.in_proj_b` and `self.in_proj_a`. Each takes
#   input_size=self.hidden_size, output_size=self.num_v_heads, bias=False and
#   the same quant_config; prefixes become "{prefix}.in_proj_b" and
#   "{prefix}.in_proj_a". The motivation stated by the added code comment is
#   avoiding a Marlin tensor-parallel split constraint.
#
# Hunk 2 (@@ -584,7 +591,15 @@):
#   Replaces the single `self.in_proj_ba(hidden_states)` call with calls to
#   both new layers. From each output it takes columns
#   [tp_rank * chunk, tp_rank * chunk + chunk), where
#   chunk = self.num_v_heads // self.tp_size, and concatenates the b slice
#   followed by the a slice along the last dim. The result is passed to
#   `self.fix_query_key_value_ordering` as `projected_states_ba`.
#
# Hunk 3 (@@ -1326,7 +1341,6 @@):
#   Deletes the "in_proj_ba": ["in_proj_ba"] entry from a dict literal that
#   also holds "gate_up_proj" and "in_proj_qkvz" entries (the dict's name is
#   outside this hunk's context).
#
# NOTE(review): `fix_query_key_value_ordering` is not shown here. Confirm it
#   accepts the [b_slice | a_slice] concatenated layout produced by hunk 2;
#   if it expects b/a interleaved per head group, the split would be wrong.
# NOTE(review): `//` truncates. If num_v_heads is not a multiple of tp_size
#   the trailing heads are never selected by any rank; confirm divisibility
#   is enforced elsewhere.
# NOTE(review): the `[:, start:end]` slices index dim 1, so they presumably
#   assume 2-D (tokens, features) projection outputs; TODO confirm.
# NOTE(review): parameter names change from `in_proj_ba` to `in_proj_b` /
#   `in_proj_a`. Checkpoints that store a fused `in_proj_ba` weight would need
#   matching handling in the weight loader, which this patch does not touch;
#   verify.
# NOTE(review): in this copy the header and hunk lines appear joined onto one
#   physical line; the original line breaks must be restored before the patch
#   can be applied.
--- qwen3_next.py.orig 2026-03-03 00:00:00.000000000 +0000 +++ qwen3_next.py 2026-03-03 00:00:00.000000000 +0000 @@ -411,15 +411,22 @@ quant_config=quant_config, prefix=f"{prefix}.in_proj_qkvz", ) - # ba_proj doesn't support blockwise fp8 quantization. - # # in_proj_ba is defined as MergedColumnParallelLinear for - # compatibility with Qwen3_5. - self.in_proj_ba = MergedColumnParallelLinear( + # ba_proj: Use ReplicatedLinear to avoid Marlin TP split constraint + # (num_v_heads=64 is too small for TP=4 Marlin min_thread_n=64). + # Each rank loads full weights and slices in forward. + self.in_proj_b = ReplicatedLinear( input_size=self.hidden_size, - output_sizes=[self.num_v_heads] * 2, + output_size=self.num_v_heads, bias=False, quant_config=quant_config, - prefix=f"{prefix}.in_proj_ba", + prefix=f"{prefix}.in_proj_b", + ) + self.in_proj_a = ReplicatedLinear( + input_size=self.hidden_size, + output_size=self.num_v_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_a", ) query_key_settings = (self.key_dim, 0, False) @@ -584,7 +591,15 @@ # Part 1: Input Projection # ============================================================ projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) + # Replicated B/A projections — full output, sliced to local TP partition + b_full, _ = self.in_proj_b(hidden_states) + a_full, _ = self.in_proj_a(hidden_states) + _ba_chunk = self.num_v_heads // self.tp_size + _ba_start = self.tp_rank * _ba_chunk + projected_states_ba = torch.cat([ + b_full[:, _ba_start:_ba_start+_ba_chunk], + a_full[:, _ba_start:_ba_start+_ba_chunk], + ], dim=-1) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) @@ -1326,7 +1341,6 @@ ], "gate_up_proj": ["gate_proj", "up_proj"], "in_proj_qkvz": ["in_proj_qkvz"], - "in_proj_ba": ["in_proj_ba"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):