class CudaPlatformBase(Platform):
# Platform-registry tag for CUDA.
_enum = PlatformEnum.CUDA
# Torch device type / display name for this platform.
device_name: str = "cuda"
device_type: str = "cuda"
# Torch dispatch key string for this platform.
dispatch_key: str = "CUDA"
# Resource name Ray uses when scheduling GPU workers.
ray_device_key: str = "GPU"
# torch.distributed backend used for multi-GPU communication.
dist_backend: str = "nccl"
# Environment variable that restricts which devices are visible.
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
# Ray env vars that suppress Ray's own CUDA_VISIBLE_DEVICES handling.
ray_noset_device_env_vars: list[str] = [
    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
]
@property
def supported_dtypes(self) -> list[torch.dtype]:
    """Dtypes usable on this GPU, most preferred first.

    bfloat16 requires compute capability >= 8.0 (Ampere), float16
    requires >= 6.0 (Pascal); float32 is always included as the
    fallback (pre-Pascal GPUs are FP32-only, though vLLM does not
    actually support them).
    """
    dtypes: list[torch.dtype] = [torch.float32]
    if self.has_device_capability(60):
        # Pascal and newer add half-precision support.
        dtypes.insert(0, torch.float16)
    if self.has_device_capability(80):
        # Ampere and newer add bfloat16 support.
        dtypes.insert(0, torch.bfloat16)
    return dtypes
@classmethod
def set_device(cls, device: torch.device) -> None:
    """
    Set the device for the current platform.

    Besides calling torch.cuda.set_device, this allocates a tiny
    tensor on the device so initialization happens eagerly instead
    of being deferred to the first real allocation.
    """
    torch.cuda.set_device(device)
    # With this trick we can force the device to be set eagerly
    # see https://github.com/pytorch/pytorch/issues/155668
    # for why and when it is needed
    _ = torch.zeros(1, device=device)
@classmethod
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
    """Return the compute capability of *device_id*, or None if unknown.

    Abstract here; concrete CUDA platform variants are expected to
    override it.
    """
    raise NotImplementedError
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
    """Return the marketing name of *device_id*. Abstract stub."""
    raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
    """Return total device memory of *device_id* in bytes. Abstract stub."""
    raise NotImplementedError
@classmethod
def is_fully_connected(cls, device_ids: list[int]) -> bool:
    """Whether every pair in *device_ids* has a direct link. Abstract stub."""
    raise NotImplementedError
@classmethod
def log_warnings(cls):
    """Hook for platform-specific startup warnings; no-op by default."""
    pass
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
    """Apply CUDA-specific defaults to *vllm_config* in place.

    Resolves the "auto" worker class, fills in a default KV-cache
    block size when the user did not supply one, reconciles the block
    size with the attention backend, and disables chunked multimodal
    input for multimodal prefix-LM models.
    """
    model_cfg = vllm_config.model_config
    if vllm_config.parallel_config.worker_cls == "auto":
        vllm_config.parallel_config.worker_cls = (
            "vllm.v1.worker.gpu_worker.Worker"
        )
    kv_cfg = vllm_config.cache_config
    block_size_from_user = kv_cfg.block_size is not None
    if kv_cfg.block_size is None:
        kv_cfg.block_size = 16
    # Ensure block_size is compatible with the attention backend.
    # model_config may be None during testing; hybrid (attention+mamba)
    # models are skipped because HybridAttentionMambaModelConfig
    # manages their block_size.
    if model_cfg is not None and not model_cfg.is_hybrid:
        cls._update_block_size_for_backend(vllm_config, block_size_from_user)
    sched_cfg = vllm_config.scheduler_config
    # model_config may be None during testing.
    if model_cfg is not None and model_cfg.is_mm_prefix_lm:
        if (
            sched_cfg.is_multimodal_model
            and not sched_cfg.disable_chunked_mm_input
        ):
            logger.warning(
                "Forcing --disable_chunked_mm_input for models "
                "with multimodal-bidirectional attention."
            )
            sched_cfg.disable_chunked_mm_input = True
@classmethod
def _update_block_size_for_backend(
    cls,
    vllm_config: "VllmConfig",
    user_specified_block_size: bool,
) -> None:
    """Ensure block_size is compatible with the attention backend.
    If the user specified --block-size, the selector validates/filters
    backends by that block size (raising on incompatibility). Otherwise,
    the backend is selected unconstrained and block_size is set to the
    backend's preferred value.

    Args:
        vllm_config: Full engine config; ``cache_config.block_size``
            may be rewritten in place.
        user_specified_block_size: True when the block size came from
            the user (must be honored), False when it is a default.
    """
    # Local imports; presumably avoid an import cycle with the
    # config/selector modules — TODO confirm.
    from vllm.config.vllm import set_current_vllm_config
    from vllm.v1.attention.selector import AttentionSelectorConfig

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    device_capability = cls.get_device_capability()
    if device_capability is None:
        # Without capability info backends cannot be validated;
        # leave the block size untouched.
        return
    use_mla = model_config.use_mla
    attn_selector_config = AttentionSelectorConfig(
        head_size=model_config.get_head_size(),
        dtype=model_config.dtype,  # type: ignore[arg-type]
        kv_cache_dtype=cache_config.cache_dtype,
        # Only constrain the selector by block size when it was
        # explicitly chosen by the user.
        block_size=cache_config.block_size if user_specified_block_size else None,
        use_mla=use_mla,
        has_sink=False,
        use_sparse=use_mla and hasattr(model_config.hf_config, "index_topk"),
        use_mm_prefix=model_config.is_mm_prefix_lm,
    )
    user_specified_backend = vllm_config.attention_config.backend
    num_heads = model_config.get_num_attention_heads(
        vllm_config.parallel_config,
    )
    with set_current_vllm_config(vllm_config):
        chosen_backend = cls.select_attention_backend(
            selected_backend=user_specified_backend,
            attn_selector_config=attn_selector_config,
            device_capability=device_capability,
            # Don't raise here — we produce better errors below.
            raise_on_invalid=False,
            num_heads=num_heads,
        )
        # If the user's --block-size forced a non-optimal backend,
        # warn them. Only relevant when the user didn't also specify
        # --attention-backend (in which case the choice is explicit).
        if (
            chosen_backend is not None
            and user_specified_block_size
            and user_specified_backend is None
        ):
            # Re-run selection without the block-size constraint to
            # see what would have been picked.
            optimal = cls.select_attention_backend(
                selected_backend=None,
                attn_selector_config=attn_selector_config._replace(
                    block_size=None,
                ),
                device_capability=device_capability,
                raise_on_invalid=False,
                num_heads=num_heads,
            )
            if optimal is not None and optimal != chosen_backend:
                logger.warning(
                    "--block-size %d is not supported by the preferred "
                    "%s backend. Using %s instead, which may result "
                    "in reduced performance. Consider removing "
                    "--block-size to auto-select the optimal "
                    "block size.",
                    cache_config.block_size,
                    optimal.name,
                    chosen_backend.name,
                )
        if chosen_backend is not None:
            if user_specified_block_size:
                # User's block_size is compatible with the chosen
                # backend.
                return
            # User didn't specify --block-size, so auto-select the
            # preferred block size for the chosen backend.
            try:
                backend_class = chosen_backend.get_class()
            except ImportError:
                return  # Will fail later with a better error
            preferred = backend_class.get_preferred_block_size(
                cache_config.block_size,
            )
            if cache_config.block_size != preferred:
                logger.info(
                    "Setting kv cache block size to %d for %s backend.",
                    preferred,
                    chosen_backend.name,
                )
                cache_config.block_size = preferred
            return
        # No valid backend found. If the user didn't constrain the
        # selection, defer the error to get_attn_backend_cls where
        # the full config (including per-layer settings) is
        # available.
        if not user_specified_block_size:
            return
        if user_specified_backend is not None:
            # User specified --block-size and --attention-backend
            # and they are incompatible.
            try:
                backend_class = user_specified_backend.get_class()
                supported = backend_class.get_supported_kernel_block_sizes()
            except ImportError:
                # Can't import the backend to list its sizes; report
                # "None" in the error below.
                supported = None
            raise ValueError(
                f"User-specified --block-size "
                f"{cache_config.block_size} is incompatible with "
                f"the specified --attention-backend "
                f"{user_specified_backend.name} (supported kernel "
                f"block sizes: {supported}). Either remove "
                f"--block-size to auto-select, or choose a "
                f"compatible value."
            )
        else:
            # User specified --block-size but no backend supports
            # it.
            _, invalid_reasons = cls.get_valid_backends(
                device_capability=device_capability,
                attn_selector_config=attn_selector_config,
                num_heads=num_heads,
            )
            reasons_str = ", ".join(
                f"{b.name}: [{', '.join(r)}]" for b, r in invalid_reasons.items()
            )
            raise ValueError(
                f"No valid attention backend found for "
                f"--block-size {cache_config.block_size}. "
                f"Reasons: {{{reasons_str}}}. Either remove "
                f"--block-size to auto-select, or choose a "
                f"compatible value."
            )
@classmethod
def get_current_memory_usage(
    cls, device: torch.types.Device | None = None
) -> float:
    """Return the bytes currently allocated by torch on *device*.

    empty_cache() releases unused cached blocks first, and
    reset_peak_memory_stats() collapses the peak counter to the
    current value, so max_memory_allocated() reports present usage
    rather than a historical peak.
    """
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    return torch.cuda.max_memory_allocated(device)
@classmethod
def get_valid_backends(
    cls,
    device_capability: DeviceCapability,
    attn_selector_config: "AttentionSelectorConfig",
    num_heads: int | None = None,
) -> tuple[
    list[tuple["AttentionBackendEnum", int]],
    dict["AttentionBackendEnum", list[str]],
]:
    """Partition candidate attention backends into valid and invalid.

    Returns a pair: (backend, priority) tuples for every backend that
    validates against the configuration (lower priority index means
    more preferred), and a mapping from each rejected backend to the
    list of reasons it was rejected.
    """
    usable: list[tuple["AttentionBackendEnum", int]] = []
    rejected: dict["AttentionBackendEnum", list[str]] = {}
    candidates = _get_backend_priorities(
        attn_selector_config.use_mla,
        device_capability,
        num_heads,
    )
    for rank, candidate in enumerate(candidates):
        try:
            reasons = candidate.get_class().validate_configuration(
                device_capability=device_capability,
                **attn_selector_config._asdict(),
            )
        except ImportError:
            # Backend module is not installed/importable.
            reasons = ["ImportError"]
        if reasons:
            rejected[candidate] = reasons
        else:
            usable.append((candidate, rank))
    return usable, rejected
@classmethod
def select_attention_backend(
    cls,
    selected_backend: "AttentionBackendEnum | None",
    attn_selector_config: "AttentionSelectorConfig",
    device_capability: "DeviceCapability",
    raise_on_invalid: bool = True,
    num_heads: int | None = None,
) -> "AttentionBackendEnum | None":
    """Pick the attention backend to use for this configuration.

    Args:
        selected_backend: Explicit user choice, or None to auto-select.
        attn_selector_config: Configuration the backend must support.
        device_capability: Device capability info.
        raise_on_invalid: Raise ValueError instead of returning None
            when nothing valid is found.
        num_heads: Attention heads per GPU, used for priority ordering
            on Blackwell GPUs.

    Returns:
        The chosen backend enum, or None when nothing is valid and
        raise_on_invalid is False.
    """
    # An explicit user choice is validated and used verbatim.
    if selected_backend is not None:
        try:
            validation_errors = (
                selected_backend.get_class().validate_configuration(
                    device_capability=device_capability,
                    **attn_selector_config._asdict(),
                )
            )
        except ImportError:
            validation_errors = ["ImportError"]
        if not validation_errors:
            return selected_backend
        if raise_on_invalid:
            raise ValueError(
                f"Selected backend {selected_backend} is not valid for "
                f"this configuration. Reason: {validation_errors}"
            )
        return None
    # Auto-selection: enumerate all valid backends.
    candidates, rejections = cls.get_valid_backends(
        device_capability=device_capability,
        attn_selector_config=attn_selector_config,
        num_heads=num_heads,
    )
    if not candidates:
        if raise_on_invalid:
            per_backend = ", ".join(
                f"{backend.name}: [{', '.join(reasons)}]"
                for backend, reasons in rejections.items()
            )
            reasons_str = "{" + per_backend + "}"
            config_str = repr(attn_selector_config)
            raise ValueError(
                f"No valid attention backend found for {cls.device_name} "
                f"with {config_str}. Reasons: {reasons_str}."
            )
        return None
    # Highest priority == smallest index; min is stable, matching the
    # "first of the sorted list" behavior.
    best_backend, _ = min(candidates, key=lambda pair: pair[1])
    return best_backend
@classmethod
def get_attn_backend_cls(
    cls,
    selected_backend: "AttentionBackendEnum | None",
    attn_selector_config: "AttentionSelectorConfig",
    num_heads: int | None = None,
) -> str:
    """Resolve the attention backend and return its import path.

    Raises (via select_attention_backend) when no backend is valid;
    otherwise logs the selection and returns the dotted path of the
    chosen backend class.
    """
    device_capability = cls.get_device_capability()
    assert device_capability is not None
    chosen_backend = cls.select_attention_backend(
        selected_backend=selected_backend,
        attn_selector_config=attn_selector_config,
        num_heads=num_heads,
        device_capability=device_capability,
        raise_on_invalid=True,
    )
    # raise_on_invalid=True guarantees a non-None result.
    assert chosen_backend is not None
    if selected_backend is not None:
        # Explicit user choice: just confirm it.
        logger.info("Using %s backend.", chosen_backend)
        return chosen_backend.get_path()
    # Auto-selected: re-enumerate candidates purely for logging.
    valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
        device_capability=device_capability,
        attn_selector_config=attn_selector_config,
        num_heads=num_heads,
    )
    rejection_details = ", ".join(
        f"{backend.name}: [{', '.join(reasons)}]"
        for backend, reasons in invalid_reasons.items()
    )
    reasons_str = "{" + rejection_details + "}"
    config_str = repr(attn_selector_config)
    logger.debug_once(
        f"Some attention backends are not valid for {cls.device_name} with "
        f"{config_str}. Reasons: {reasons_str}."
    )
    logger.info_once(
        "Using %s attention backend out of potential backends: %s",
        chosen_backend.name,
        tuple(b[0].name for b in valid_backends_priorities),
        scope="local",
    )
    return chosen_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
    """Backends usable for ViT attention, most preferred first."""
    # TORCH_SDPA is last: it serves as the unconditional fallback.
    return [
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.TRITON_ATTN,
        AttentionBackendEnum.TORCH_SDPA,
    ]
@classmethod
def get_vit_attn_backend(
    cls,
    head_size: int,
    dtype: torch.dtype,
    backend: "AttentionBackendEnum | None" = None,
) -> "AttentionBackendEnum":
    """Pick an attention backend for vision-transformer layers.

    An explicitly requested *backend* is checked against the supported
    list and returned as-is. Otherwise each supported backend (except
    the TORCH_SDPA fallback) is probed for head-size, dtype and
    compute-capability support; the first match wins and TORCH_SDPA
    is the final fallback.
    """
    if backend is not None:
        assert backend in cls.get_supported_vit_attn_backends(), (
            f"Backend {backend} is not supported for vit attention. "
            f"Supported backends are: {cls.get_supported_vit_attn_backends()}"
        )
        logger.info_once(f"Using backend {backend} for vit attention")
        return backend
    cc = cls.get_device_capability()
    for candidate in cls.get_supported_vit_attn_backends():
        if candidate == AttentionBackendEnum.TORCH_SDPA:
            # SDPA is the unconditional fallback; no need to probe it.
            continue
        try:
            impl = candidate.get_class()
            supported = impl.supports_head_size(head_size) and impl.supports_dtype(
                dtype
            )
            if cc is not None:
                supported = supported and impl.supports_compute_capability(cc)
            if supported:
                logger.info_once(f"Using backend {candidate} for vit attention")
                return candidate
        except ImportError:
            # Backend not installed; try the next one.
            pass
    return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_punica_wrapper(cls) -> str:
    """Import path of the LoRA punica wrapper used on CUDA GPUs."""
    return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
@classmethod
def get_device_communicator_cls(cls) -> str:
    """Import path of the CUDA device communicator class."""
    return (
        "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
    )
@classmethod
def supports_fp8(cls) -> bool:
    """True when the GPU has compute capability >= 8.9 (Ada/Hopper),
    the threshold used here for FP8 support."""
    return cls.has_device_capability(89)
@classmethod
def use_custom_allreduce(cls) -> bool:
    """CUDA always enables the custom all-reduce path."""
    return True
@classmethod
def opaque_attention_op(cls) -> bool:
    """CUDA treats the attention op as opaque (always True here)."""
    return True
@classmethod
def get_static_graph_wrapper_cls(cls) -> str:
    """Import path of the CUDA-graph static graph wrapper."""
    return "vllm.compilation.cuda_graph.CUDAGraphWrapper"
@classmethod
def device_count(cls) -> int:
    """Number of CUDA devices, via the stateless counting helper."""
    return cuda_device_count_stateless()
@classmethod
def check_if_supports_dtype(cls, dtype: torch.dtype):
    """Raise ValueError if *dtype* cannot run on this GPU.

    Only bfloat16 is gated: it needs compute capability >= 8.0
    (Ampere). Every other dtype passes silently.
    """
    if dtype != torch.bfloat16:
        return
    if cls.has_device_capability(80):
        return
    capability = cls.get_device_capability()
    gpu_name = cls.get_device_name()
    if capability is None:
        compute_str = "does not have a compute capability"
    else:
        compute_str = f"has compute capability {capability.as_version_str()}"
    raise ValueError(
        "Bfloat16 is only supported on GPUs "
        "with compute capability of at least 8.0. "
        f"Your {gpu_name} GPU {compute_str}. "
        "You can use float16 instead by explicitly setting the "
        "`dtype` flag in CLI, for example: --dtype=half."
    )
@classmethod
def insert_blocks_to_device(
    cls,
    src_cache: torch.Tensor,
    dst_cache: torch.Tensor,
    src_block_indices: torch.Tensor,
    dst_block_indices: torch.Tensor,
) -> None:
    """Copy blocks from src_cache to dst_cache on GPU."""
    # Gather the selected blocks, move them to the destination's
    # device, then scatter into the destination slots.
    selected = src_cache[:, src_block_indices].to(dst_cache.device)
    dst_cache[:, dst_block_indices] = selected
@classmethod
def swap_out_blocks_to_host(
    cls,
    src_cache: torch.Tensor,
    dst_cache: torch.Tensor,
    src_block_indices: torch.Tensor,
    dst_block_indices: torch.Tensor,
) -> None:
    """Copy blocks from GPU to host (CPU)."""
    # Gather the selected blocks, bring them to CPU memory, then
    # scatter into the host-side cache slots.
    selected = src_cache[:, src_block_indices].cpu()
    dst_cache[:, dst_block_indices] = selected
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Hybrid KV cache is supported on CUDA (always True)."""
    return True
@classmethod
def support_static_graph_mode(cls) -> bool:
    """Static (CUDA-graph) execution mode is supported (always True)."""
    return True