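Added patch to allow fastsafetensors weight loading in cluster configs (takes the process group and device from vLLM's world group instead of torch.distributed.group.WORLD and cuda:{rank})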
This commit is contained in:
28
fastsafetensors.patch
Normal file
28
fastsafetensors.patch
Normal file
@@ -0,0 +1,28 @@
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0809bdfa9..a7878f44f 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -28,6 +28,7 @@ from vllm import envs
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (
     QuantizationConfig,
@@ -770,11 +771,13 @@ def fastsafetensors_weights_iterator(
     """Iterate over the weights in the model safetensor files
     using fastsafetensor library."""
     if torch.distributed.is_initialized():
-        pg = torch.distributed.group.WORLD
+        world = get_world_group()
+        pg = world.device_group
+        device = world.device
     else:
         pg = SingleGroup()
+        device = torch.device(f"cuda:{pg.rank()}")
 
-    device = torch.device(f"cuda:{pg.rank()}")
     weight_files_sub_lists = [
         hf_weights_files[i : i + pg.size()]
         for i in range(0, len(hf_weights_files), pg.size())
|
||||
Reference in New Issue
Block a user