Qwen3-Coder-Next fixes and updated recipe
This commit is contained in:
10
README.md
10
README.md
@@ -164,6 +164,16 @@ Don't do it every time you rebuild, because it will slow down compilation times.
|
|||||||
|
|
||||||
For periodic maintenance, I recommend using a filter: `docker builder prune --filter until=72h`
|
For periodic maintenance, I recommend using a filter: `docker builder prune --filter until=72h`
|
||||||
|
|
||||||
|
### 2026-02-12
|
||||||
|
|
||||||
|
Added a mod for Qwen3-Coder-Next-FP8 that fixes:
|
||||||
|
|
||||||
|
- A bug with the Triton allocator (https://github.com/vllm-project/vllm/issues/33857) that prevented the model from running in a cluster.
|
||||||
|
- A bug that caused a crash when `--enable-prefix-caching` is on (https://github.com/vllm-project/vllm/issues/34361).
|
||||||
|
- A bug that significantly impacted the performance on Spark (https://github.com/vllm-project/vllm/issues/34413).
|
||||||
|
|
||||||
|
This mod was included in `qwen3-coder-next-fp8` recipe.
|
||||||
|
|
||||||
### 2026-02-11
|
### 2026-02-11
|
||||||
|
|
||||||
#### Configurable GPU Architecture
|
#### Configurable GPU Architecture
|
||||||
|
|||||||
1
mods/fix-qwen3-coder-next/_triton_alloc_setup.pth
Normal file
1
mods/fix-qwen3-coder-next/_triton_alloc_setup.pth
Normal file
@@ -0,0 +1 @@
|
|||||||
|
import _triton_alloc_setup
|
||||||
9
mods/fix-qwen3-coder-next/_triton_alloc_setup.py
Normal file
9
mods/fix-qwen3-coder-next/_triton_alloc_setup.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
try:
|
||||||
|
import triton.runtime._allocation as _alloc
|
||||||
|
import torch
|
||||||
|
|
||||||
|
_alloc.NullAllocator.__call__ = staticmethod(
|
||||||
|
lambda size, alignment, stream:
|
||||||
|
torch.cuda.caching_allocator_alloc(size, stream=stream))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
14
mods/fix-qwen3-coder-next/fix_crash.diff
Normal file
14
mods/fix-qwen3-coder-next/fix_crash.diff
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
|
||||||
|
index 0b6b7ed42ac1..b6e0305a312d 100644
|
||||||
|
--- a/vllm/v1/core/single_type_kv_cache_manager.py
|
||||||
|
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
|
||||||
|
@@ -1000,7 +1000,8 @@ def cache_blocks(self, request: Request, num_tokens: int) -> None:
|
||||||
|
for block in self.req_to_blocks[request.request_id][
|
||||||
|
num_cached_blocks_before:num_cached_blocks_after
|
||||||
|
]:
|
||||||
|
- assert block.block_hash is not None
|
||||||
|
+ if block.is_null:
|
||||||
|
+ continue
|
||||||
|
self.cached_blocks_this_step.add(block.block_hash)
|
||||||
|
|
||||||
|
def new_step_starts(self) -> None:
|
||||||
72
mods/fix-qwen3-coder-next/fix_slowness.diff
Normal file
72
mods/fix-qwen3-coder-next/fix_slowness.diff
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
|
||||||
|
index 63aae43c3ddf..6ca3213fbd8d 100644
|
||||||
|
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
|
||||||
|
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
|
||||||
|
@@ -95,19 +95,19 @@ def fused_moe_kernel_gptq_awq(
|
||||||
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
|
# (A has M rows).
|
||||||
|
- stride_am,
|
||||||
|
- stride_ak,
|
||||||
|
- stride_be,
|
||||||
|
- stride_bk,
|
||||||
|
- stride_bn,
|
||||||
|
- stride_cm,
|
||||||
|
- stride_cn,
|
||||||
|
- stride_bse,
|
||||||
|
- stride_bsk,
|
||||||
|
- stride_bsn,
|
||||||
|
- stride_bze,
|
||||||
|
- stride_bzk,
|
||||||
|
- stride_bzn,
|
||||||
|
+ stride_am: tl.int64,
|
||||||
|
+ stride_ak: tl.int64,
|
||||||
|
+ stride_be: tl.int64,
|
||||||
|
+ stride_bk: tl.int64,
|
||||||
|
+ stride_bn: tl.int64,
|
||||||
|
+ stride_cm: tl.int64,
|
||||||
|
+ stride_cn: tl.int64,
|
||||||
|
+ stride_bse: tl.int64,
|
||||||
|
+ stride_bsk: tl.int64,
|
||||||
|
+ stride_bsn: tl.int64,
|
||||||
|
+ stride_bze: tl.int64,
|
||||||
|
+ stride_bzk: tl.int64,
|
||||||
|
+ stride_bzn: tl.int64,
|
||||||
|
block_k_diviable: tl.constexpr,
|
||||||
|
group_size: tl.constexpr,
|
||||||
|
# Meta-parameters
|
||||||
|
@@ -329,20 +329,20 @@ def fused_moe_kernel(
|
||||||
|
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||||
|
# how much to increase `a_ptr` by to get the element one row down
|
||||||
|
# (A has M rows).
|
||||||
|
- stride_am,
|
||||||
|
- stride_ak,
|
||||||
|
- stride_be,
|
||||||
|
- stride_bk,
|
||||||
|
- stride_bn,
|
||||||
|
- stride_cm,
|
||||||
|
- stride_cn,
|
||||||
|
- stride_asm,
|
||||||
|
- stride_ask,
|
||||||
|
- stride_bse,
|
||||||
|
- stride_bsk,
|
||||||
|
- stride_bsn,
|
||||||
|
- stride_bbe, # bias expert stride
|
||||||
|
- stride_bbn, # bias N stride
|
||||||
|
+ stride_am: tl.int64,
|
||||||
|
+ stride_ak: tl.int64,
|
||||||
|
+ stride_be: tl.int64,
|
||||||
|
+ stride_bk: tl.int64,
|
||||||
|
+ stride_bn: tl.int64,
|
||||||
|
+ stride_cm: tl.int64,
|
||||||
|
+ stride_cn: tl.int64,
|
||||||
|
+ stride_asm: tl.int64,
|
||||||
|
+ stride_ask: tl.int64,
|
||||||
|
+ stride_bse: tl.int64,
|
||||||
|
+ stride_bsk: tl.int64,
|
||||||
|
+ stride_bsn: tl.int64,
|
||||||
|
+ stride_bbe: tl.int64, # bias expert stride
|
||||||
|
+ stride_bbn: tl.int64, # bias N stride
|
||||||
|
# Block size for block-wise quantization
|
||||||
|
group_n: tl.constexpr,
|
||||||
|
group_k: tl.constexpr,
|
||||||
11
mods/fix-qwen3-coder-next/run.sh
Normal file
11
mods/fix-qwen3-coder-next/run.sh
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo "Patching Qwen3-Coder-Next crashing on start"
|
||||||
|
patch -p1 -d /usr/local/lib/python3.12/dist-packages < fix_crash.diff || echo "Patch is not applicable, skipping"
|
||||||
|
|
||||||
|
echo "Reverting PR #34279 that causes slowness"
|
||||||
|
patch -p1 -R -d /usr/local/lib/python3.12/dist-packages < fix_slowness.diff || echo "Reversing PR #34279 failed, skipping"
|
||||||
|
|
||||||
|
echo "Fixing Triton allocator bug"
|
||||||
|
cp _triton* /usr/local/lib/python3.12/dist-packages/
|
||||||
@@ -1,30 +1,30 @@
|
|||||||
# Recipe: Qwen3-Coder-Next-FP8
|
# Recipe: Qwen3-Coder-Next-FP8
|
||||||
# Qwen3-Coder-Next model in native FP8 format
|
# Qwen3-Coder-Next model in native FP8 format
|
||||||
# Currently can only be run in solo mode; cluster mode fails with an error - tracking https://github.com/vllm-project/vllm/issues/33857
|
|
||||||
|
|
||||||
recipe_version: "1"
|
recipe_version: "1"
|
||||||
name: Qwen3-Coder-Next-FP8
|
name: Qwen3-Coder-Next-FP8
|
||||||
description: vLLM serving Qwen3-Coder-Next-FP8 on a SINGLE NODE ONLY!
|
description: vLLM serving Qwen3-Coder-Next-FP8
|
||||||
|
|
||||||
# HuggingFace model to download (optional, for --download-model)
|
# HuggingFace model to download (optional, for --download-model)
|
||||||
model: Qwen/Qwen3-Coder-Next-FP8
|
model: Qwen/Qwen3-Coder-Next-FP8
|
||||||
|
|
||||||
# This model can only run on single node (solo)
|
#solo_only: true
|
||||||
solo_only: true
|
|
||||||
|
|
||||||
# Container image to use
|
# Container image to use
|
||||||
container: vllm-node
|
container: vllm-node
|
||||||
|
|
||||||
# No mods required
|
# Mod required to fix slowness and crash in the cluster (tracking https://github.com/vllm-project/vllm/issues/33857)
|
||||||
mods: []
|
mods:
|
||||||
|
- mods/fix-qwen3-coder-next
|
||||||
|
|
||||||
# Default settings (can be overridden via CLI)
|
# Default settings (can be overridden via CLI)
|
||||||
defaults:
|
defaults:
|
||||||
port: 8000
|
port: 8000
|
||||||
host: 0.0.0.0
|
host: 0.0.0.0
|
||||||
tensor_parallel: 1
|
tensor_parallel: 2
|
||||||
gpu_memory_utilization: 0.7
|
gpu_memory_utilization: 0.7
|
||||||
max_model_len: 131072
|
max_model_len: 262144
|
||||||
|
|
||||||
# Environment variables
|
# Environment variables
|
||||||
env: {}
|
env: {}
|
||||||
@@ -40,4 +40,7 @@ command: |
|
|||||||
--load-format fastsafetensors \
|
--load-format fastsafetensors \
|
||||||
--attention-backend flashinfer \
|
--attention-backend flashinfer \
|
||||||
--enable-prefix-caching \
|
--enable-prefix-caching \
|
||||||
--max-model-len {max_model_len}
|
--max-model-len {max_model_len} \
|
||||||
|
-tp {tensor_parallel} \
|
||||||
|
--distributed-executor-backend ray
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user