Added EXPERIMENTAL mod for b12x - initial support
This commit is contained in:
116
mods/exp-b12x/run.sh
Normal file
116
mods/exp-b12x/run.sh
Normal file
@@ -0,0 +1,116 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
SITE_PACKAGES="/usr/local/lib/python3.12/dist-packages"
|
||||
|
||||
echo "=== EXPERIMENTAL b12x-patches mod ==="
|
||||
|
||||
# 0a. Check if b12x support is present in vLLM
|
||||
if [ ! -f "$SITE_PACKAGES/vllm/model_executor/layers/fused_moe/experts/flashinfer_b12x_moe.py" ]; then
|
||||
echo "[b12x ERROR] No b12x support detected; please rebuild with --apply-vllm-pr 40082, e.g.:"
|
||||
echo "./build-and-copy.sh -t vllm-node-40082 --apply-vllm-pr 40082"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# 0b. Check if environment variables are set
|
||||
|
||||
if [[ "$VLLM_NVFP4_GEMM_BACKEND" != "flashinfer-b12x" ]]; then
|
||||
echo "[b12x ERROR] Please set required environment variables to use b12x backend"
|
||||
echo "*** Add the following arguments to launch-cluster.sh:"
|
||||
echo " -e FLASHINFER_DISABLE_VERSION_CHECK=1 -e VLLM_USE_FLASHINFER_MOE_FP16=1 -e VLLM_NVFP4_GEMM_BACKEND=flashinfer-b12x -e VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 -e VLLM_FLASHINFER_ALLREDUCE_BACKEND=trtllm -e VLLM_USE_FLASHINFER_MOE_FP4=1"
|
||||
echo "*** also set the following vLLM parameters:"
|
||||
echo " --moe-backend flashinfer_b12x --attention-backend flashinfer"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 1. Pin nvidia-cutlass-dsl + companion libs to 4.4.2
|
||||
# (4.5.x generates bad PTX on SM121 — `_mma` rejected by ptxas).
|
||||
# All THREE packages must match: the python frontend, the base libs,
|
||||
# and the CUDA 13 libs (which contain the MLIR compiler).
|
||||
# ---------------------------------------------------------------
|
||||
DSL_VER=$(pip show nvidia-cutlass-dsl 2>/dev/null | grep '^Version:' | awk '{print $2}' || true)
|
||||
LIBS_BASE_VER=$(pip show nvidia-cutlass-dsl-libs-base 2>/dev/null | grep '^Version:' | awk '{print $2}' || true)
|
||||
# LIBS_CU13_VER=$(pip show nvidia-cutlass-dsl-libs-cu13 2>/dev/null | grep '^Version:' | awk '{print $2}' || true)
|
||||
if [ "$DSL_VER" != "4.4.2" ] || [ "$LIBS_BASE_VER" != "4.4.2" ] || [ "$LIBS_CU13_VER" != "4.4.2" ]; then
|
||||
echo "[b12x] Pinning nvidia-cutlass-dsl{,-libs-base,-libs-cu13} to 4.4.2"
|
||||
echo "[b12x] current: dsl=${DSL_VER:-none} libs-base=${LIBS_BASE_VER:-none} libs-cu13=${LIBS_CU13_VER:-none}"
|
||||
uv pip install \
|
||||
nvidia-cutlass-dsl==4.4.2 \
|
||||
nvidia-cutlass-dsl-libs-base==4.4.2 \
|
||||
nvidia-cutlass-dsl-libs-cu13==4.4.2 \
|
||||
-q 2>/dev/null || echo "[b12x] WARNING: cutlass-dsl pin install returned non-zero"
|
||||
else
|
||||
echo "[b12x] nvidia-cutlass-dsl + libs already at 4.4.2"
|
||||
fi
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 2. Apply cutlass-dsl SM121 patches
|
||||
# FlashInfer/vLLM install wipes vendored cutlass, so re-apply every time
|
||||
# ---------------------------------------------------------------
|
||||
echo "[b12x] Applying cutlass-dsl SM121 patches..."
|
||||
|
||||
# 2a. warp/mma.py: allow sm_121a alongside sm_120a in both the runtime
|
||||
# arch check and the `admissible_archs` string list (used in error msgs)
|
||||
for f in $(find "$SITE_PACKAGES" -name "mma.py" -path "*/warp/*" 2>/dev/null); do
|
||||
if grep -q "if not arch == Arch.sm_120a:" "$f" 2>/dev/null; then
|
||||
sed -i "s/if not arch == Arch.sm_120a:/if arch not in (Arch.sm_120a, Arch.sm_121a):/" "$f"
|
||||
echo " patched $f (warp sm_121a runtime check)"
|
||||
fi
|
||||
# Add sm_121a to the admissible_archs list if missing
|
||||
if grep -q '"sm_120a",' "$f" 2>/dev/null && ! grep -q '"sm_121a"' "$f" 2>/dev/null; then
|
||||
sed -i 's/^\(\s*\)"sm_120a",$/\1"sm_120a",\n\1"sm_121a",/' "$f"
|
||||
echo " patched $f (warp sm_121a admissible_archs)"
|
||||
fi
|
||||
done
|
||||
|
||||
# 2b. tcgen05/mma.py: add sm_120a and sm_121a to supported arch list
|
||||
for f in $(find "$SITE_PACKAGES" -name "mma.py" -path "*/tcgen05/*" 2>/dev/null); do
|
||||
if ! grep -q "Arch.sm_121a" "$f" 2>/dev/null; then
|
||||
sed -i "/Arch.sm_103a,/a\\ Arch.sm_120a,\n Arch.sm_121a," "$f"
|
||||
echo " patched $f (tcgen05 mma sm_121a)"
|
||||
fi
|
||||
done
|
||||
|
||||
# 2c. tcgen05/copy.py: allow sm_120f family
|
||||
for f in $(find "$SITE_PACKAGES" -name "copy.py" -path "*/tcgen05/*" 2>/dev/null); do
|
||||
if ! grep -q "sm_120f" "$f" 2>/dev/null; then
|
||||
sed -i "s/arch.is_family_of(Arch.sm_110f)/arch.is_family_of(Arch.sm_110f) or arch.is_family_of(Arch.sm_120f)/" "$f"
|
||||
echo " patched $f (tcgen05 copy sm_120f)"
|
||||
fi
|
||||
done
|
||||
|
||||
# Clear pycache so patched code takes effect
|
||||
find "$SITE_PACKAGES" -name "__pycache__" -path "*/cutlass*" -exec rm -rf {} + 2>/dev/null || true
|
||||
find "$SITE_PACKAGES" -name "__pycache__" -path "*/flashinfer*" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
# ---------------------------------------------------------------
|
||||
# 3 Patch FlashInfer's blackwell_sm12x __init__.py to drop the
|
||||
# broken `sm120_moe_dispatch_context` import (FlashInfer main
|
||||
# has a stale __init__ that references a function that no
|
||||
# longer exists in moe_dispatch.py — but the symbol isn't
|
||||
# actually used by anything, so we just remove it from the
|
||||
# import + __all__ list).
|
||||
# ---------------------------------------------------------------
|
||||
SM12X_INIT="$SITE_PACKAGES/flashinfer/fused_moe/cute_dsl/blackwell_sm12x/__init__.py"
|
||||
if [ -f "$SM12X_INIT" ]; then
|
||||
if grep -q "sm120_moe_dispatch_context" "$SM12X_INIT"; then
|
||||
# Drop the line that imports/exports the missing symbol
|
||||
sed -i '/sm120_moe_dispatch_context/d' "$SM12X_INIT"
|
||||
find "$SITE_PACKAGES/flashinfer" -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
echo "[b12x] patched $SM12X_INIT (dropped stale sm120_moe_dispatch_context references)"
|
||||
else
|
||||
echo "[b12x] $SM12X_INIT already cleaned"
|
||||
fi
|
||||
else
|
||||
echo "[b12x] $SM12X_INIT not found (older FlashInfer?), skipping"
|
||||
fi
|
||||
|
||||
if grep -q "if current_platform.has_device_capability(120) and has_flashinfer_b12x_gemm():" $SITE_PACKAGES/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py; then
|
||||
echo "[b12x] Patching vLLM PR 40080 to enable sm121 cap"
|
||||
sed -i "s/if current_platform.has_device_capability(120) and has_flashinfer_b12x_gemm():/if True:/" $SITE_PACKAGES/vllm/model_executor/kernels/linear/nvfp4/flashinfer.py
|
||||
fi
|
||||
|
||||
|
||||
Reference in New Issue
Block a user