--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -8,6 +8,7 @@
 import hashlib
 import json
 import os
+import re
 import tempfile
 import time
 from collections import defaultdict
@@ -786,6 +787,22 @@
     loader.add_filenames(rank_file_map)
     return loader
 
+
+def _natural_sort_key(filepath: str) -> list:
+    """Natural sort key for filenames with numeric components, such as
+    model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']
+
+    Only the basename is used, so files with the same name in different
+    directories compare equal under this key.
+    """
+    # isdecimal() (not isdigit()) is true for exactly the characters the
+    # regex digit class captures and int() can parse; isdigit() is also true
+    # for e.g. superscript digits, on which int() raises ValueError.
+    return [
+        int(s) if s.isdecimal() else s
+        for s in re.split(r"(\d+)", os.path.basename(filepath))
+    ]
+
 
 
 def fastsafetensors_weights_iterator(
     hf_weights_files: list[str],
@@ -801,6 +818,13 @@
         pg = SingleGroup()
 
     device = torch.device(f"cuda:{pg.rank()}")
+    # Break ties between equal natural keys (e.g. "x01" vs "x1", or the same
+    # basename in two directories) with the full path, so the order never
+    # depends on the order in which the caller listed the files.
+    hf_weights_files = sorted(
+        hf_weights_files,
+        key=lambda f: (_natural_sort_key(f), f),
+    )
     weight_files_sub_lists = [
         hf_weights_files[i : i + pg.size()]
         for i in range(0, len(hf_weights_files), pg.size())