
vllm.model_executor.models.config

DeepseekV32ForCausalLM

Bases: VerifyAndUpdateConfig

Source code in vllm/model_executor/models/config.py
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Update the fp8 kv-cache to the custom "fp8_ds_mla" format for DeepSeekV3.2
        """
        hf_config = vllm_config.model_config.hf_config

        # Mirror the check in vllm/model_executor/models/deepseek_v2.py
        is_v32 = hasattr(hf_config, "index_topk")
        assert is_v32

        # For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
        cache_config = vllm_config.cache_config
        if cache_config.cache_dtype.startswith("fp8"):
            cache_config.cache_dtype = "fp8_ds_mla"
            logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
        if cache_config.cache_dtype == "bfloat16":
            cache_config.cache_dtype = "auto"
            logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")

verify_and_update_config classmethod

verify_and_update_config(vllm_config: VllmConfig) -> None

Update the fp8 kv-cache to the custom "fp8_ds_mla" format for DeepSeekV3.2

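To make the remapping concrete, here is a minimal, self-contained sketch of the dtype rewrite performed above. DummyCacheConfig is a hypothetical stand-in for vLLM's CacheConfig, used for illustration only: any requested fp8 variant is rewritten to the custom "fp8_ds_mla" layout, while an explicit bfloat16 request falls back to "auto".

from dataclasses import dataclass

@dataclass
class DummyCacheConfig:
    # Hypothetical stand-in for vllm.config.CacheConfig (illustration only).
    cache_dtype: str = "auto"

def remap_kv_cache_dtype(cache_config: DummyCacheConfig) -> None:
    # Mirrors the DeepSeekV3.2 logic above: fp8 variants use the custom
    # "fp8_ds_mla" layout; an explicit bfloat16 request maps back to "auto".
    if cache_config.cache_dtype.startswith("fp8"):
        cache_config.cache_dtype = "fp8_ds_mla"
    if cache_config.cache_dtype == "bfloat16":
        cache_config.cache_dtype = "auto"

for requested in ("fp8", "fp8_e4m3", "bfloat16", "auto"):
    cfg = DummyCacheConfig(cache_dtype=requested)
    remap_kv_cache_dtype(cfg)
    print(f"{requested!r} -> {cfg.cache_dtype!r}")
# 'fp8' -> 'fp8_ds_mla', 'fp8_e4m3' -> 'fp8_ds_mla',
# 'bfloat16' -> 'auto', 'auto' -> 'auto'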

HybridAttentionMambaModelConfig

Bases: VerifyAndUpdateConfig

Source code in vllm/model_executor/models/config.py
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Ensure that the page size of the attention layers is greater than
        or equal to that of the mamba layers. If not, automatically set the
        attention block size so that it is. If the attention page size is
        strictly greater than the mamba page size, pad the mamba page size
        to make them equal.

        Args:
            vllm_config: vLLM Config
        """
        # Save the user input before it gets modified by MambaModelConfig
        mamba_block_size = vllm_config.cache_config.mamba_block_size
        # Apply mamba-specific defaults (cache mode and block size)
        MambaModelConfig.verify_and_update_config(vllm_config)

        attention_config = vllm_config.attention_config
        cache_config = vllm_config.cache_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config

        if cache_config.cache_dtype == "auto":
            kv_cache_dtype = model_config.dtype
        else:
            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]

        # get attention page size (for 1 token)
        # Attention backend constraints:
        # - FlashAttention (FA) requires block size to be multiple of 16
        # - MLA (Multi-head Latent Attention) requires larger alignment:
        #   * CUTLASS_MLA backend: kernel_block_size 128 alignment
        #   * Other MLA backends: kernel_block_size 64 alignment
        if model_config.use_mla:
            use_cutlass_mla = (
                attention_config.backend == AttentionBackendEnum.CUTLASS_MLA
            )
            kernel_block_alignment_size = 128 if use_cutlass_mla else 64
            attn_page_size_1_token = MLAAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes
        else:
            kernel_block_alignment_size = 16
            if (
                current_platform.is_device_capability_family(100)
                and model_config.get_head_size() == 256
                and (
                    attention_config.backend is None
                    or attention_config.backend == AttentionBackendEnum.FLASHINFER
                )
            ):
                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
                # head size 256 with block size 16 is not supported on Blackwell.
                kernel_block_alignment_size = 32
            attn_page_size_1_token = FullAttentionSpec(
                block_size=1,
                num_kv_heads=model_config.get_num_kv_heads(parallel_config),
                head_size=model_config.get_head_size(),
                dtype=kv_cache_dtype,
            ).page_size_bytes

        model_cls, _ = ModelRegistry.resolve_model_cls(
            model_config.architecture,
            model_config=model_config,
        )

        # get mamba page size
        mamba_page_size = MambaSpec(
            shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
            dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config),
            block_size=-1,  # block_size doesn't matter for mamba page size
        ).page_size_bytes

        # The model may be marked as is_hybrid, but the mamba layers can be
        # skipped via config; in that case, return directly.
        if mamba_page_size == 0:
            return

        if cache_config.mamba_cache_mode == "all":
            # With prefix caching, select attention block size to
            # optimize for mamba kernel performance

            # Mamba2 SSD kernel uses a chunk_size, e.g. 256
            # Align the block to the kernel: use lowest multiple of chunk_size
            # of attention tokens that would fit mamba_page_size:
            # e.g. for mamba page size = 788kB
            #          attn_1_token = 2kB -> fits ~394 tokens
            #      then round up to a multiple of 256 -> 512 tokens
            # End result:
            #  attn_block_size = 512
            #  mamba_block_size = 512 (aligned to a multiple of chunk_size)
            # TODO(tdoublep): this constraint can be relaxed fairly
            # easily by changing the way we layout chunks in the
            # mamba2 kernels.

            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
            attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
            chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
            attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
            cache_config.mamba_block_size = attn_block_size
        else:
            # Without prefix caching, select minimum valid attention block size
            # to minimize mamba state padding

            # Calculate minimum attention block size that satisfies both:
            # 1. Backend alignment requirements (kernel_block_alignment_size)
            # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size)
            attn_block_size = kernel_block_alignment_size * cdiv(
                mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
            )

        # override attention block size if either (a) the
        # user has not set it or (b) the user has set it
        # too small.
        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
            cache_config.block_size = attn_block_size
            logger.info(
                "Setting attention block size to %d tokens "
                "to ensure that attention page size is >= mamba page size.",
                attn_block_size,
            )

        # By default, mamba block size will be set to max_model_len.
        # When enabling prefix caching and using align mamba cache
        # mode, we align mamba block size to the block size as the
        # basic granularity for prefix caching.
        if cache_config.mamba_cache_mode == "align":
            cache_config.mamba_block_size = cache_config.block_size

        # compute new attention page size
        attn_page_size = cache_config.block_size * attn_page_size_1_token

        assert attn_page_size >= mamba_page_size

        if attn_page_size == mamba_page_size:
            # don't need to pad mamba page size
            return

        # pad mamba page size to exactly match attention
        if (
            cache_config.mamba_page_size_padded is None
            or cache_config.mamba_page_size_padded != attn_page_size
        ):
            cache_config.mamba_page_size_padded = attn_page_size
            mamba_padding_pct = (
                100 * (attn_page_size - mamba_page_size) / mamba_page_size
            )
            logger.info(
                "Padding mamba page size by %.2f%% to ensure "
                "that mamba page size and attention page size are "
                "exactly equal.",
                mamba_padding_pct,
            )

verify_and_update_config classmethod

verify_and_update_config(vllm_config: VllmConfig) -> None

Ensure that the page size of the attention layers is greater than or equal to that of the mamba layers. If not, automatically set the attention block size so that it is. If the attention page size is strictly greater than the mamba page size, pad the mamba page size to make them equal.

Parameters:

Name         Type        Description  Default
vllm_config  VllmConfig  vLLM Config  required
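
As a worked illustration of the block-size selection and padding above, the sketch below reproduces the arithmetic from the inline comments in the source. The concrete numbers are illustrative values taken from those comments (a mamba page of roughly 788 kB and about 2 kB of KV cache per attention token; for full attention the per-token size is on the order of 2 * num_kv_heads * head_size * bytes_per_element), and cdiv here is a local stand-in for vLLM's utility of the same name.

from math import lcm

def cdiv(a: int, b: int) -> int:
    # Ceiling division; a local stand-in for vllm.utils.cdiv.
    return -(a // -b)

# Illustrative numbers taken from the comments in the source above.
mamba_page_size = 788 * 1024        # ~788 kB of mamba state per sequence
attn_page_size_1_token = 2 * 1024   # ~2 kB of KV cache per attention token
mamba_chunk_size = 256              # Mamba2 SSD kernel chunk size
kernel_block_alignment_size = 16    # FlashAttention block-size multiple

# Prefix-caching ("all") path: round the number of attention tokens that
# covers one mamba page up to a multiple of the alignment-adjusted chunk size.
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(mamba_chunk_size, kernel_block_alignment_size)
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
print(attn_tokens_per_mamba_state, attn_block_size)  # 394 512

# The attention page is now strictly larger than the mamba page, so the
# mamba page is padded up to match it exactly.
attn_page_size = attn_block_size * attn_page_size_1_token
mamba_padding_pct = 100 * (attn_page_size - mamba_page_size) / mamba_page_size
print(attn_page_size, round(mamba_padding_pct, 2))  # 1048576 29.95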

MambaModelConfig

Bases: VerifyAndUpdateConfig

Source code in vllm/model_executor/models/config.py
class MambaModelConfig(VerifyAndUpdateConfig):
    @classmethod
    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
        """
        Configure the mamba cache mode and the mamba block size based on
        whether prefix caching is enabled and whether the model supports it.

        Args:
            vllm_config: vLLM Config
        """
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config

        if cache_config.enable_prefix_caching:
            if cache_config.mamba_cache_mode == "none":
                cache_config.mamba_cache_mode = (
                    "all" if model_config.supports_mamba_prefix_caching else "align"
                )
                logger.warning(
                    "Mamba cache mode is set to '%s' for %s by default "
                    "when prefix caching is enabled",
                    cache_config.mamba_cache_mode,
                    model_config.architecture,
                )
            if (
                cache_config.mamba_cache_mode == "all"
                and not model_config.supports_mamba_prefix_caching
            ):
                cache_config.mamba_cache_mode = "align"
                logger.warning(
                    "Hybrid or mamba-based model detected without support "
                    "for prefix caching with Mamba cache 'all' mode: "
                    "falling back to 'align' mode."
                )
            if cache_config.mamba_cache_mode == "align":
                assert vllm_config.scheduler_config.enable_chunked_prefill, (
                    "Chunked prefill is required for mamba cache mode 'align'."
                )
                assert not vllm_config.speculative_config, (
                    "Mamba cache mode 'align' is currently not compatible "
                    "with speculative decoding."
                )
            logger.info(
                "Warning: Prefix caching in Mamba cache '%s' "
                "mode is currently enabled. "
                "Its support for Mamba layers is experimental. "
                "Please report any issues you may observe.",
                cache_config.mamba_cache_mode,
            )
            # By default, mamba block size will be set to max_model_len (see
            # below). When enabling prefix caching, we align mamba block size
            # to the block size as the basic granularity for prefix caching.
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = cache_config.block_size
        else:
            if cache_config.mamba_cache_mode != "none":
                cache_config.mamba_cache_mode = "none"
                logger.warning(
                    "Mamba cache mode is set to 'none' when prefix caching is disabled"
                )
            if cache_config.mamba_block_size is None:
                cache_config.mamba_block_size = model_config.max_model_len

verify_and_update_config classmethod

verify_and_update_config(vllm_config: VllmConfig) -> None

Configure the mamba cache mode and the mamba block size based on whether prefix caching is enabled and whether the model supports it.

Parameters:

Name         Type        Description  Default
vllm_config  VllmConfig  vLLM Config  required
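
The mode selection above reduces to a small decision function. The sketch below is a hypothetical helper written for this page (not part of vLLM) and assumes the user did not set mamba_block_size explicitly; it shows which cache mode and mamba block size result from the prefix-caching setting, model support, and the attention block size.

def pick_mamba_cache_settings(
    enable_prefix_caching: bool,
    supports_mamba_prefix_caching: bool,
    block_size: int,
    max_model_len: int,
    requested_mode: str = "none",
) -> tuple[str, int]:
    # Hypothetical summary of the selection rules above; not a vLLM API.
    if enable_prefix_caching:
        mode = requested_mode
        if mode == "none":
            mode = "all" if supports_mamba_prefix_caching else "align"
        if mode == "all" and not supports_mamba_prefix_caching:
            mode = "align"
        # With prefix caching, the mamba block size defaults to the attention
        # block size, which is the basic granularity for prefix caching.
        return mode, block_size
    # Without prefix caching, a single mamba block covers the whole sequence.
    return "none", max_model_len

print(pick_mamba_cache_settings(True, True, 16, 32768))   # ('all', 16)
print(pick_mamba_cache_settings(True, False, 16, 32768))  # ('align', 16)
print(pick_mamba_cache_settings(False, True, 16, 32768))  # ('none', 32768)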

NemotronHForCausalLMConfig

Bases: VerifyAndUpdateConfig

Source code in vllm/model_executor/models/config.py
class NemotronHForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto'
        (or not explicitly set), to the value specified in the HF config, or to
        float16 if not specified.
        """
        cache_config = vllm_config.cache_config
        if cache_config.mamba_ssm_cache_dtype == "auto":
            hf_config = vllm_config.model_config.hf_config
            mamba_ssm_cache_dtype = getattr(
                hf_config, "mamba_ssm_cache_dtype", "float16"
            )
            logger.info(
                "Updating mamba_ssm_cache_dtype to '%s' for NemotronH model",
                mamba_ssm_cache_dtype,
            )
            cache_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype

verify_and_update_config staticmethod

verify_and_update_config(vllm_config: VllmConfig) -> None

Update mamba_ssm_cache_dtype for NemotronH models when set to 'auto' (or not explicitly set), to the value specified in the HF config, or to float16 if not specified.

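The fallback amounts to a single getattr with a default. The sketch below uses SimpleNamespace as a stand-in for the HF config object and a hypothetical resolve_mamba_ssm_cache_dtype helper to show the three possible outcomes: the HF value when present, float16 when absent, and an explicit user setting left untouched.

from types import SimpleNamespace

def resolve_mamba_ssm_cache_dtype(hf_config, requested: str = "auto") -> str:
    # Mirrors the NemotronH rule above: only an 'auto' request is rewritten,
    # using the HF config value if present and falling back to float16.
    if requested != "auto":
        return requested
    return getattr(hf_config, "mamba_ssm_cache_dtype", "float16")

print(resolve_mamba_ssm_cache_dtype(SimpleNamespace()))          # 'float16'
print(resolve_mamba_ssm_cache_dtype(
    SimpleNamespace(mamba_ssm_cache_dtype="bfloat16")))          # 'bfloat16'
print(resolve_mamba_ssm_cache_dtype(
    SimpleNamespace(), requested="float32"))                     # 'float32'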