Support other CUDA architectures via the --gpu-arch flag
This commit is contained in:
@@ -60,11 +60,13 @@ copy_to_host() {
|
||||
fi
|
||||
}
|
||||
BUILD_JOBS="16"
|
||||
# Default CUDA architecture list. Overridden by --gpu-arch and passed to the
# build as the TORCH_CUDA_ARCH_LIST and FLASHINFER_CUDA_ARCH_LIST build arguments.
GPU_ARCH_LIST="12.1a"
|
||||
|
||||
# Help function
|
||||
usage() {
|
||||
echo "Usage: $0 [OPTIONS]"
|
||||
echo " -t, --tag <tag> : Image tag (default: 'vllm-node')"
|
||||
echo " --gpu-arch <arch> : GPU architecture (default: '12.1a')"
|
||||
echo " --rebuild-deps : Set cache bust for dependencies"
|
||||
echo " --rebuild-vllm : Set cache bust for vllm"
|
||||
echo " --triton-ref <ref> : Triton commit SHA, branch or tag (default: 'v3.5.1')"
|
||||
@@ -88,6 +90,7 @@ usage() {
|
||||
while [[ "$#" -gt 0 ]]; do
|
||||
case $1 in
|
||||
-t|--tag) IMAGE_TAG="$2"; shift ;;
|
||||
--gpu-arch) GPU_ARCH_LIST="$2"; shift ;;
|
||||
--rebuild-deps) REBUILD_DEPS=true ;;
|
||||
--rebuild-vllm) REBUILD_VLLM=true ;;
|
||||
--triton-ref) TRITON_REF="$2"; TRITON_REF_SET=true; shift ;;
|
||||
@@ -227,6 +230,10 @@ if [ "$NO_BUILD" = false ]; then
|
||||
# Add BUILD_JOBS to build arguments
|
||||
CMD+=("--build-arg" "BUILD_JOBS=$BUILD_JOBS")
|
||||
|
||||
# Add GPU architecture to build arguments
|
||||
CMD+=("--build-arg" "TORCH_CUDA_ARCH_LIST=$GPU_ARCH_LIST")
|
||||
CMD+=("--build-arg" "FLASHINFER_CUDA_ARCH_LIST=$GPU_ARCH_LIST")
|
||||
|
||||
if [ "$PRE_FLASHINFER" = true ]; then
|
||||
echo "Using pre-release FlashInfer..."
|
||||
CMD+=("--build-arg" "FLASHINFER_PRE=--pre")
|
||||
|
||||
Reference in New Issue
Block a user