Added patch to allow fastsafetensors in cluster config

This commit is contained in:
eugr
2025-11-26 21:25:04 -08:00
parent 712637a348
commit 6a66a4b66f
3 changed files with 55 additions and 1 deletions

28
fastsafetensors.patch Normal file
View File

@@ -0,0 +1,28 @@
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0809bdfa9..a7878f44f 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -28,6 +28,7 @@ from vllm import envs
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
@@ -770,11 +771,13 @@ def fastsafetensors_weights_iterator(
"""Iterate over the weights in the model safetensor files
using fastsafetensor library."""
if torch.distributed.is_initialized():
- pg = torch.distributed.group.WORLD
+ world = get_world_group()
+ pg = world.device_group
+ device = world.device
else:
pg = SingleGroup()
+ device = torch.device(f"cuda:{pg.rank()}")
- device = torch.device(f"cuda:{pg.rank()}")
weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())