Applied new fastsafetensors fix to mxfp4 build; disabled wheel builds by default

This commit is contained in:
Eugene Rakhmatulin
2026-02-09 23:47:06 -08:00
parent 74876dd442
commit ace16f3a8f
7 changed files with 71 additions and 35 deletions

View File

@@ -1,28 +1,12 @@
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
-index 0809bdfa9..a7878f44f 100644
+index d43656c4f382..7025efd1c2de 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -28,6 +28,7 @@ from vllm import envs
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
@@ -770,11 +771,13 @@ def fastsafetensors_weights_iterator(
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
if torch.distributed.is_initialized():
- pg = torch.distributed.group.WORLD
+ world = get_world_group()
+ pg = world.device_group
+ device = world.device
else:
@@ -826,6 +829,7 @@ def fastsafetensors_weights_iterator(
pg = SingleGroup()
- device = torch.device(f"cuda:{pg.rank()}")
+ device = torch.device(f"cuda:{current_platform.current_device()}")
device = torch.device(f"cuda:{current_platform.current_device()}")
+ hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())