29 lines
1.2 KiB
Diff
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 0809bdfa9..a7878f44f 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -28,6 +28,7 @@ from vllm import envs
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (
     QuantizationConfig,
@@ -770,11 +771,13 @@ def fastsafetensors_weights_iterator(
     """Iterate over the weights in the model safetensor files
     using fastsafetensor library."""
     if torch.distributed.is_initialized():
-        pg = torch.distributed.group.WORLD
+        world = get_world_group()
+        pg = world.device_group
+        device = world.device
     else:
         pg = SingleGroup()
+        device = torch.device(f"cuda:{pg.rank()}")

-    device = torch.device(f"cuda:{pg.rank()}")
     weight_files_sub_lists = [
         hf_weights_files[i : i + pg.size()]
         for i in range(0, len(hf_weights_files), pg.size())