diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0809bdfa9..a7878f44f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -28,6 +28,7 @@ from vllm import envs from vllm.config import ModelConfig from vllm.config.load import LoadConfig from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed.parallel_state import get_world_group from vllm.logger import init_logger from vllm.model_executor.layers.quantization import ( QuantizationConfig, @@ -770,11 +771,13 @@ def fastsafetensors_weights_iterator( """Iterate over the weights in the model safetensor files using fastsafetensor library.""" if torch.distributed.is_initialized(): - pg = torch.distributed.group.WORLD + world = get_world_group() + pg = world.device_group + device = world.device else: pg = SingleGroup() + device = torch.device(f"cuda:{pg.rank()}") - device = torch.device(f"cuda:{pg.rank()}") weight_files_sub_lists = [ hf_weights_files[i : i + pg.size()] for i in range(0, len(hf_weights_files), pg.size())