Applied new fastsafetensors fix to mxfp4 build; disabled wheel builds by default
This commit is contained in:
33
fastsafetensors_mxfp4.patch
Normal file
33
fastsafetensors_mxfp4.patch
Normal file
@@ -0,0 +1,33 @@
|
||||
--- a/vllm/model_executor/model_loader/weight_utils.py
|
||||
+++ b/vllm/model_executor/model_loader/weight_utils.py
|
||||
@@ -8,6 +8,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
+import re
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
@@ -786,6 +786,14 @@
|
||||
loader.add_filenames(rank_file_map)
|
||||
return loader
|
||||
|
||||
+def _natural_sort_key(filepath: str) -> list:
|
||||
+ """Natural sort key for filenames with numeric components, such as
|
||||
+ model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']"""
|
||||
+ return [
|
||||
+ int(s) if s.isdigit() else s
|
||||
+ for s in re.split(r"(\d+)", os.path.basename(filepath))
|
||||
+ ]
|
||||
+
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
@@ -801,6 +809,7 @@
|
||||
pg = SingleGroup()
|
||||
device = torch.device(f"cuda:{pg.rank()}")
|
||||
|
||||
+ hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
|
||||
weight_files_sub_lists = [
|
||||
hf_weights_files[i : i + pg.size()]
|
||||
for i in range(0, len(hf_weights_files), pg.size())
|
||||
Reference in New Issue
Block a user