Files
spark-vllm-docker/fastsafetensors_mxfp4.patch

34 lines
1.0 KiB
Diff

--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -8,6 +8,7 @@
import hashlib
import json
import os
+import re
import tempfile
import time
from collections import defaultdict
@@ -786,6 +786,14 @@
loader.add_filenames(rank_file_map)
return loader
+def _natural_sort_key(filepath: str) -> list:
+    """Split the basename of ``filepath`` into text and int parts so shards
+    sort numerically: model-00001-of-00005.safetensors ->
+    ['model-', 1, '-of-', 5, '.safetensors']."""
+    # isdecimal() matches regex \d; isdigit() also accepts superscripts int() rejects.
+    parts = re.split(r"(\d+)", os.path.basename(filepath))
+    return [int(s) if s.isdecimal() else s for s in parts]
+
def fastsafetensors_weights_iterator(
hf_weights_files: list[str],
@@ -801,6 +809,7 @@
pg = SingleGroup()
device = torch.device(f"cuda:{pg.rank()}")
+ hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())