diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index 3796265ff..b6dcfb54c 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -45,6 +45,11 @@ class CacheConfig:
     not matter if you have another vLLM instance running on the same GPU. For
     example, if you have two vLLM instances running on the same GPU, you can
     set the GPU memory utilization to 0.5 for each instance."""
+    gpu_memory_utilization_gb: float | None = Field(default=None, gt=0)
+    """Amount of GPU memory to be used in GiB. This provides fine-grained control
+    over GPU memory usage and is particularly useful on unified memory systems
+    where available memory changes dynamically. If specified, it overrides
+    gpu_memory_utilization. Cannot be used simultaneously with kv_cache_memory_bytes."""
     cache_dtype: CacheDType = "auto"
     """Data type for kv cache storage. If "auto", will use model data type.
     CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
@@ -204,6 +209,18 @@ class CacheConfig:
         object.__setattr__(self, "user_specified_block_size", True)
         return self
 
+    @model_validator(mode="after")
+    def _validate_memory_params(self) -> "CacheConfig":
+        if (
+            self.gpu_memory_utilization_gb is not None
+            and self.kv_cache_memory_bytes is not None
+        ):
+            raise ValueError(
+                "Cannot specify both gpu_memory_utilization_gb and "
+                "kv_cache_memory_bytes. Please use only one of them."
+            )
+        return self
+
     @field_validator("cache_dtype", mode="after")
     @classmethod
     def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
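Review note: the mutual-exclusion rule above is easiest to see in isolation. Below is a minimal sketch using plain pydantic v2, so it runs without vLLM installed; CacheConfigSketch and its field defaults are illustrative stand-ins for the real CacheConfig, not part of this patch.

from pydantic import BaseModel, Field, model_validator

class CacheConfigSketch(BaseModel):
    # Stand-ins for the three interacting CacheConfig fields.
    gpu_memory_utilization: float = 0.9
    gpu_memory_utilization_gb: float | None = Field(default=None, gt=0)
    kv_cache_memory_bytes: int | None = None

    @model_validator(mode="after")
    def _validate_memory_params(self) -> "CacheConfigSketch":
        # Same check as the validator added in this diff.
        if (
            self.gpu_memory_utilization_gb is not None
            and self.kv_cache_memory_bytes is not None
        ):
            raise ValueError(
                "Cannot specify both gpu_memory_utilization_gb and "
                "kv_cache_memory_bytes. Please use only one of them."
            )
        return self

CacheConfigSketch(gpu_memory_utilization_gb=16.0)  # accepted
# CacheConfigSketch(gpu_memory_utilization_gb=16.0,
#                   kv_cache_memory_bytes=1 << 30)  # raises ValueError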
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 56bbb7bf5..db5012608 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -454,6 +454,7 @@ class EngineArgs:
     offload_prefetch_step: int = PrefetchOffloadConfig.offload_prefetch_step
     offload_params: set[str] = get_field(PrefetchOffloadConfig, "offload_params")
     gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
+    gpu_memory_utilization_gb: float | None = CacheConfig.gpu_memory_utilization_gb
     kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
     max_num_batched_tokens: int | None = None
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@@ -954,6 +955,9 @@ class EngineArgs:
         cache_group.add_argument(
             "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"]
         )
+        cache_group.add_argument(
+            "--gpu-memory-utilization-gb", **cache_kwargs["gpu_memory_utilization_gb"]
+        )
         cache_group.add_argument(
             "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"]
         )
@@ -1512,6 +1516,7 @@ class EngineArgs:
         cache_config = CacheConfig(
             block_size=self.block_size,  # type: ignore[arg-type]
             gpu_memory_utilization=self.gpu_memory_utilization,
+            gpu_memory_utilization_gb=self.gpu_memory_utilization_gb,
             kv_cache_memory_bytes=self.kv_cache_memory_bytes,
             cache_dtype=resolved_cache_dtype,  # type: ignore[arg-type]
             is_attention_free=model_config.is_attention_free,
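With the plumbing above, the flag should be reachable both from the CLI (something like `vllm serve <model> --gpu-memory-utilization-gb 16`, assuming the standard serve entrypoint) and programmatically. A minimal sketch of the programmatic path; the model name is a placeholder and create_engine_config is assumed to keep its current signature, so treat this as illustrative rather than verified against this branch.

from vllm.engine.arg_utils import EngineArgs

# Reserve an absolute 16 GiB rather than a fraction of total device memory.
args = EngineArgs(
    model="facebook/opt-125m",  # placeholder model
    gpu_memory_utilization_gb=16.0,
)

# Building the config requires a working vLLM install (and typically a GPU).
vllm_config = args.create_engine_config()
assert vllm_config.cache_config.gpu_memory_utilization_gb == 16.0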
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 5909b3043..c2607df6a 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -156,6 +156,11 @@ class LLM:
             values will increase the KV cache size and thus improve the model's
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
+        gpu_memory_utilization_gb: Amount of GPU memory to reserve in GiB.
+            This provides fine-grained control over GPU memory usage and is
+            particularly useful on unified memory systems where available memory
+            changes dynamically. If specified, it overrides gpu_memory_utilization.
+            Cannot be used simultaneously with kv_cache_memory_bytes.
         kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
             this is set to None and vllm can automatically infer the kv cache
             size based on gpu_memory_utilization. However, users may want to
@@ -234,6 +239,7 @@ class LLM:
         chat_template: Path | str | None = None,
         seed: int = 0,
         gpu_memory_utilization: float = 0.92,
+        gpu_memory_utilization_gb: float | None = None,
         cpu_offload_gb: float = 0,
         offload_group_size: int = 0,
         offload_num_in_group: int = 1,
@@ -356,6 +362,7 @@ class LLM:
             tokenizer_revision=tokenizer_revision,
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
+            gpu_memory_utilization_gb=gpu_memory_utilization_gb,
             kv_cache_memory_bytes=kv_cache_memory_bytes,
             cpu_offload_gb=cpu_offload_gb,
             offload_group_size=offload_group_size,
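For the offline API, usage would then look like the sketch below; the model name is a placeholder, and the run needs a device with at least roughly 16 GiB free, per the request_memory check later in this diff.

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",       # placeholder
    gpu_memory_utilization_gb=16.0,  # overrides the 0.92 fractional default
)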
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 2ed7ef7e0..806830b17 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -622,7 +622,8 @@ def _check_enough_kv_cache_memory(
     if available_memory <= 0:
         raise ValueError(
             "No available memory for the cache blocks. "
-            "Try increasing `gpu_memory_utilization` when initializing the engine. "
+            "Try increasing `gpu_memory_utilization` or `gpu_memory_utilization_gb` "
+            "when initializing the engine. "
             "See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
             "for more details."
         )
@@ -643,8 +644,8 @@ def _check_enough_kv_cache_memory(
             f"({max_model_len}), ({format_gib(needed_memory)} GiB KV "
             f"cache is needed, which is larger than the available KV cache "
             f"memory ({format_gib(available_memory)} GiB). {estimated_msg}"
-            f"Try increasing `gpu_memory_utilization` or decreasing `max_model_len` "
-            f"when initializing the engine. "
+            f"Try increasing `gpu_memory_utilization` or `gpu_memory_utilization_gb`, "
+            f"or decreasing `max_model_len` when initializing the engine. "
             f"See https://docs.vllm.ai/en/latest/configuration/conserving_memory/ "
             f"for more details."
         )
@@ -1438,7 +1439,8 @@ def _auto_fit_max_model_len(
     if auto_fit_max <= 0:
         raise ValueError(
             "Cannot auto-fit max_model_len: not enough GPU memory available "
-            "to serve even a single token. Try increasing `gpu_memory_utilization`."
+            "to serve even a single token. Try increasing `gpu_memory_utilization` "
+            "or `gpu_memory_utilization_gb`."
         )
 
     if auto_fit_max >= original_max:
diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py
index 3d065927e..e8cef2ceb 100644
--- a/vllm/v1/utils.py
+++ b/vllm/v1/utils.py
@@ -358,6 +358,7 @@ def report_usage_stats(
             "dtype": str(vllm_config.model_config.dtype),
             "block_size": vllm_config.cache_config.block_size,
             "gpu_memory_utilization": vllm_config.cache_config.gpu_memory_utilization,
+            "gpu_memory_utilization_gb": vllm_config.cache_config.gpu_memory_utilization_gb,
             "kv_cache_memory_bytes": vllm_config.cache_config.kv_cache_memory_bytes,
             # Quantization
             "quantization": vllm_config.model_config.quantization,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b53bd71a1..d28821328 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -5355,8 +5355,8 @@ class GPUModelRunner(
                 raise RuntimeError(
                     "CUDA out of memory occurred when warming up sampler with "
                     f"{num_reqs} dummy requests. Please try lowering "
-                    "`max_num_seqs` or `gpu_memory_utilization` when "
-                    "initializing the engine."
+                    "`max_num_seqs`, `gpu_memory_utilization`, or "
+                    "`gpu_memory_utilization_gb` when initializing the engine."
                 ) from e
             else:
                 raise e
@@ -5434,8 +5434,8 @@ class GPUModelRunner(
                 raise RuntimeError(
                     "CUDA out of memory occurred when warming up pooler "
                     f"({task=}) with {num_reqs} dummy requests. Please try "
-                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
-                    "initializing the engine."
+                    "lowering `max_num_seqs`, `gpu_memory_utilization`, or "
+                    "`gpu_memory_utilization_gb` when initializing the engine."
                 ) from e
             else:
                 raise e
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 842e76549..bf3bb359b 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -357,7 +357,8 @@ class Worker(WorkerBase):
 
         Tip:
             You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
+            by adjusting the `gpu_memory_utilization` or
+            `gpu_memory_utilization_gb` parameter.
         """
         if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes:
             # still need a profile run which compiles the model for
@@ -369,7 +370,8 @@ class Worker(WorkerBase):
                 f"GiB, reserved {format_gib(kv_cache_memory_bytes)} GiB memory for "
                 "KV Cache as specified by kv_cache_memory_bytes config and "
                 "skipped memory profiling. This does not respect the "
-                "gpu_memory_utilization config. Only use kv_cache_memory_bytes "
+                "gpu_memory_utilization or gpu_memory_utilization_gb config. "
+                "Only use kv_cache_memory_bytes "
                 "config when you want manual control of KV cache memory "
                 "size. If OOM'ed, check the difference of initial free "
                 "memory between the current run and the previous run "
diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py
index d06c40ed6..89c94e641 100644
--- a/vllm/v1/worker/utils.py
+++ b/vllm/v1/worker/utils.py
@@ -405,21 +405,43 @@ def request_memory(init_snapshot: MemorySnapshot, cache_config: CacheConfig) ->
     Calculate the amount of memory required by vLLM, then validate
     that the current amount of free memory is sufficient for that.
     """
-    requested_memory = math.ceil(
-        init_snapshot.total_memory * cache_config.gpu_memory_utilization
-    )
-
-    if init_snapshot.free_memory < requested_memory:
-        raise ValueError(
-            f"Free memory on device {init_snapshot.device_} "
-            f"({format_gib(init_snapshot.free_memory)}/"
-            f"{format_gib(init_snapshot.total_memory)} GiB) on startup "
-            f"is less than desired GPU memory utilization "
-            f"({cache_config.gpu_memory_utilization}, "
-            f"{format_gib(requested_memory)} GiB). Decrease GPU memory "
-            f"utilization or reduce GPU memory used by other processes."
-        )
-
+    if cache_config.gpu_memory_utilization_gb is not None:
+        requested_memory = math.ceil(cache_config.gpu_memory_utilization_gb * 1024**3)
+        if requested_memory <= 0:
+            raise ValueError(
+                f"gpu_memory_utilization_gb must be positive, got "
+                f"{cache_config.gpu_memory_utilization_gb} GiB."
+            )
+        if requested_memory > init_snapshot.total_memory:
+            raise ValueError(
+                f"Requested memory ({format_gib(requested_memory)} GiB) exceeds "
+                f"total GPU memory ({format_gib(init_snapshot.total_memory)} GiB). "
+                f"Reduce gpu_memory_utilization_gb to fit within total GPU memory."
+            )
+        safety_margin = 0.5 * 1024**3
+        if requested_memory > init_snapshot.free_memory + safety_margin:
+            raise ValueError(
+                f"Requested memory ({format_gib(requested_memory)} GiB) exceeds "
+                f"available memory ({format_gib(init_snapshot.free_memory)} GiB) "
+                f"plus the safety margin ({format_gib(safety_margin)} GiB). "
+                f"Reduce gpu_memory_utilization_gb or free up GPU memory."
+            )
+    else:
+        requested_memory = math.ceil(
+            init_snapshot.total_memory * cache_config.gpu_memory_utilization
+        )
+
+        if init_snapshot.free_memory < requested_memory:
+            raise ValueError(
+                f"Free memory on device {init_snapshot.device_} "
+                f"({format_gib(init_snapshot.free_memory)}/"
+                f"{format_gib(init_snapshot.total_memory)} GiB) on startup "
+                f"is less than desired GPU memory utilization "
+                f"({cache_config.gpu_memory_utilization}, "
+                f"{format_gib(requested_memory)} GiB). Decrease GPU memory "
+                f"utilization or reduce GPU memory used by other processes."
+            )
+
     return requested_memory
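The accounting in request_memory reduces to a few lines of arithmetic. A standalone sketch with made-up snapshot numbers; the constants mirror the diff, and no vLLM import is needed.

import math

GIB = 1024**3
total_memory = 24 * GIB   # hypothetical 24 GiB device
free_memory = 18 * GIB    # hypothetical 18 GiB currently free
gpu_memory_utilization_gb = 16.0

requested = math.ceil(gpu_memory_utilization_gb * GIB)  # 17_179_869_184 bytes
safety_margin = 0.5 * GIB                               # 536_870_912 bytes

# The two guards from the new branch, written as their passing conditions:
assert requested <= total_memory                  # fits on the device
assert requested <= free_memory + safety_margin   # within the tolerance window

Note the direction of the margin: the check tolerates a request up to 0.5 GiB beyond the instantaneous free-memory reading, which is why the error message speaks of exceeding available memory plus the safety margin.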