Supporting other CUDA archs via --gpu-arch flag

This commit is contained in:
Eugene Rakhmatulin
2026-02-11 13:10:41 -08:00
parent c6b245cfe8
commit 3b1e49dcb0
5 changed files with 47 additions and 13 deletions

View File

@@ -60,11 +60,13 @@ copy_to_host() {
fi
}
BUILD_JOBS="16"
GPU_ARCH_LIST="12.1a"
# Help function
usage() {
echo "Usage: $0 [OPTIONS]"
echo " -t, --tag <tag> : Image tag (default: 'vllm-node')"
echo " --gpu-arch <arch> : GPU architecture (default: '12.1a')"
echo " --rebuild-deps : Set cache bust for dependencies"
echo " --rebuild-vllm : Set cache bust for vllm"
echo " --triton-ref <ref> : Triton commit SHA, branch or tag (default: 'v3.5.1')"
@@ -88,6 +90,7 @@ usage() {
while [[ "$#" -gt 0 ]]; do
case $1 in
-t|--tag) IMAGE_TAG="$2"; shift ;;
--gpu-arch) GPU_ARCH_LIST="$2"; shift ;;
--rebuild-deps) REBUILD_DEPS=true ;;
--rebuild-vllm) REBUILD_VLLM=true ;;
--triton-ref) TRITON_REF="$2"; TRITON_REF_SET=true; shift ;;
@@ -227,6 +230,10 @@ if [ "$NO_BUILD" = false ]; then
# Add BUILD_JOBS to build arguments
CMD+=("--build-arg" "BUILD_JOBS=$BUILD_JOBS")
# Add GPU architecture to build arguments
CMD+=("--build-arg" "TORCH_CUDA_ARCH_LIST=$GPU_ARCH_LIST")
CMD+=("--build-arg" "FLASHINFER_CUDA_ARCH_LIST=$GPU_ARCH_LIST")
if [ "$PRE_FLASHINFER" = true ]; then
echo "Using pre-release FlashInfer..."
CMD+=("--build-arg" "FLASHINFER_PRE=--pre")