Caching cubins during build for reuse
This commit is contained in:
@@ -135,20 +135,27 @@ RUN --mount=type=cache,id=repo-cache,target=/repo-cache \
|
|||||||
|
|
||||||
WORKDIR /workspace/flashinfer
|
WORKDIR /workspace/flashinfer
|
||||||
|
|
||||||
|
# Apply patch to avoid re-downloading existing cubins
|
||||||
|
COPY flashinfer_cache.patch .
|
||||||
|
RUN patch -p1 < flashinfer_cache.patch
|
||||||
|
|
||||||
# flashinfer-python
|
# flashinfer-python
|
||||||
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
||||||
--mount=type=cache,id=ccache,target=/root/.ccache \
|
--mount=type=cache,id=ccache,target=/root/.ccache \
|
||||||
|
--mount=type=cache,id=cubins-cache,target=/workspace/flashinfer/flashinfer-cubin/flashinfer_cubin/cubins \
|
||||||
sed -i -e 's/license = "Apache-2.0"/license = { text = "Apache-2.0" }/' -e '/license-files/d' pyproject.toml && \
|
sed -i -e 's/license = "Apache-2.0"/license = { text = "Apache-2.0" }/' -e '/license-files/d' pyproject.toml && \
|
||||||
uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
||||||
|
|
||||||
# flashinfer-cubin
|
# flashinfer-cubin
|
||||||
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
||||||
--mount=type=cache,id=ccache,target=/root/.ccache \
|
--mount=type=cache,id=ccache,target=/root/.ccache \
|
||||||
|
--mount=type=cache,id=cubins-cache,target=/workspace/flashinfer/flashinfer-cubin/flashinfer_cubin/cubins \
|
||||||
cd flashinfer-cubin && uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
cd flashinfer-cubin && uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
||||||
|
|
||||||
# flashinfer-jit-cache
|
# flashinfer-jit-cache
|
||||||
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
||||||
--mount=type=cache,id=ccache,target=/root/.ccache \
|
--mount=type=cache,id=ccache,target=/root/.ccache \
|
||||||
|
--mount=type=cache,id=cubins-cache,target=/workspace/flashinfer/flashinfer-cubin/flashinfer_cubin/cubins \
|
||||||
cd flashinfer-jit-cache && \
|
cd flashinfer-jit-cache && \
|
||||||
uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
uv build --no-build-isolation --wheel . --out-dir=/workspace/wheels -v
|
||||||
|
|
||||||
@@ -301,7 +308,7 @@ RUN chmod +x $VLLM_BASE_DIR/run-cluster-node.sh
|
|||||||
|
|
||||||
# Final extra deps
|
# Final extra deps
|
||||||
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
||||||
uv pip install ray[default] fastsafetensors
|
uv pip install ray[default] fastsafetensors nvidia-nvshmem-cu13
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ RUN chmod +x $VLLM_BASE_DIR/run-cluster-node.sh
|
|||||||
|
|
||||||
# Final extra deps
|
# Final extra deps
|
||||||
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,id=uv-cache,target=/root/.cache/uv \
|
||||||
uv pip install ray[default] fastsafetensors
|
uv pip install ray[default] fastsafetensors nvidia-nvshmem-cu13
|
||||||
|
|
||||||
# If not compiling Triton
|
# If not compiling Triton
|
||||||
# remove triton-kernels as they are not compatible with this vLLM version yet
|
# remove triton-kernels as they are not compatible with this vLLM version yet
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ EXP_MXFP4=false
|
|||||||
TRITON_REF_SET=false
|
TRITON_REF_SET=false
|
||||||
VLLM_REF_SET=false
|
VLLM_REF_SET=false
|
||||||
VLLM_PRS=""
|
VLLM_PRS=""
|
||||||
|
FULL_LOG=false
|
||||||
|
|
||||||
cleanup() {
|
cleanup() {
|
||||||
if [ -n "$TMP_IMAGE" ] && [ -f "$TMP_IMAGE" ]; then
|
if [ -n "$TMP_IMAGE" ] && [ -f "$TMP_IMAGE" ]; then
|
||||||
@@ -81,6 +82,7 @@ usage() {
|
|||||||
echo " --pre-tf, --pre-transformers : Install transformers 5.0.0rc0 or higher"
|
echo " --pre-tf, --pre-transformers : Install transformers 5.0.0rc0 or higher"
|
||||||
echo " --exp-mxfp4, --experimental-mxfp4 : Build with experimental native MXFP4 support"
|
echo " --exp-mxfp4, --experimental-mxfp4 : Build with experimental native MXFP4 support"
|
||||||
echo " --apply-vllm-pr <pr-num> : Apply a specific PR patch to vLLM source code. Can be specified multiple times."
|
echo " --apply-vllm-pr <pr-num> : Apply a specific PR patch to vLLM source code. Can be specified multiple times."
|
||||||
|
echo " --full-log : Enable full build logging (--progress=plain)"
|
||||||
echo " --no-build : Skip building, only copy image (requires --copy-to)"
|
echo " --no-build : Skip building, only copy image (requires --copy-to)"
|
||||||
echo " -h, --help : Show this help message"
|
echo " -h, --help : Show this help message"
|
||||||
exit 1
|
exit 1
|
||||||
@@ -158,6 +160,7 @@ while [[ "$#" -gt 0 ]]; do
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
--full-log) FULL_LOG=true ;;
|
||||||
--no-build) NO_BUILD=true ;;
|
--no-build) NO_BUILD=true ;;
|
||||||
-h|--help) usage ;;
|
-h|--help) usage ;;
|
||||||
*) echo "Unknown parameter passed: $1"; usage ;;
|
*) echo "Unknown parameter passed: $1"; usage ;;
|
||||||
@@ -198,6 +201,10 @@ if [ "$NO_BUILD" = false ]; then
|
|||||||
# Construct build command
|
# Construct build command
|
||||||
CMD=("docker" "build" "-t" "$IMAGE_TAG")
|
CMD=("docker" "build" "-t" "$IMAGE_TAG")
|
||||||
|
|
||||||
|
if [ "$FULL_LOG" = true ]; then
|
||||||
|
CMD+=("--progress=plain")
|
||||||
|
fi
|
||||||
|
|
||||||
if [ "$EXP_MXFP4" = true ]; then
|
if [ "$EXP_MXFP4" = true ]; then
|
||||||
echo "Building with experimental MXFP4 support..."
|
echo "Building with experimental MXFP4 support..."
|
||||||
CMD+=("-f" "Dockerfile.mxfp4")
|
CMD+=("-f" "Dockerfile.mxfp4")
|
||||||
|
|||||||
18
flashinfer_cache.patch
Normal file
18
flashinfer_cache.patch
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
--- a/flashinfer/artifacts.py
|
||||||
|
+++ b/flashinfer/artifacts.py
|
||||||
|
@@ -203,9 +203,13 @@
|
||||||
|
with ThreadPoolExecutor(num_threads) as pool:
|
||||||
|
futures = []
|
||||||
|
- for name, _ in cubin_files:
|
||||||
|
- source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
|
||||||
|
- local_path = FLASHINFER_CUBIN_DIR / name
|
||||||
|
+ for name, checksum in cubin_files:
|
||||||
|
+ local_path = FLASHINFER_CUBIN_DIR / name
|
||||||
|
+ if local_path.exists() and verify_cubin(str(local_path), checksum):
|
||||||
|
+ pbar.update(1)
|
||||||
|
+ continue
|
||||||
|
+
|
||||||
|
+ source = safe_urljoin(FLASHINFER_CUBINS_REPOSITORY, name)
|
||||||
|
# Ensure parent directory exists
|
||||||
|
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fut = pool.submit(
|
||||||
Reference in New Issue
Block a user