def select_fp8_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
"""
    Select the primary FP8 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
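
    Returns the selected backend and, when it uses the modular-kernel
    interface, the corresponding kernel class (None otherwise, e.g. for
    FlashInfer TRTLLM or Fp8MoeBackend.NONE).

    Example (illustrative only; real call sites may differ):

        backend, kernel_cls = select_fp8_moe_backend(
            moe_config, weight_key, activation_key, allow_vllm_cutlass=True
        )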
"""
k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
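    # When LoRA is enabled, always use the Triton backend and skip the
    # selection logic below.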
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
    # NOTE: backends are tried in the following priority order.
AVAILABLE_BACKENDS = [
Fp8MoeBackend.AITER,
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
Fp8MoeBackend.DEEPGEMM,
Fp8MoeBackend.BATCHED_DEEPGEMM,
Fp8MoeBackend.VLLM_CUTLASS,
Fp8MoeBackend.BATCHED_VLLM_CUTLASS,
Fp8MoeBackend.TRITON,
Fp8MoeBackend.BATCHED_TRITON,
Fp8MoeBackend.MARLIN,
Fp8MoeBackend.XPU,
]
    # NOTE(rob): We need to peek into the prepare/finalize (P/F) selection
    # to determine if we are using the batched or standard expert format,
    # which is not ideal. Once we unify TP + DP/EP, we can select P/F first.
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if config.moe_parallel_config.use_batched_activation_format
else mk.FusedMoEActivationFormat.Standard
)
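    # Helpers for building consistent "selected" / "unsupported" log messages.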
def _make_log_backend(backend: Fp8MoeBackend):
available_backend_strs = [b.value for b in AVAILABLE_BACKENDS]
return (
f"Using {backend.value} Fp8 MoE backend out "
f"of potential backends: {available_backend_strs}."
)
def _make_log_unsupported(backend: Fp8MoeBackend, reason: str | None) -> str:
if reason:
return (
f"FP8 MoE backend {backend.value} does not support the "
f"deployment configuration since {reason}."
)
else:
return (
f"FP8 MoE backend '{backend.value}' does not support the "
"deployment configuration."
)
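    # Validate an explicitly requested backend: log and return it if it is
    # supported, otherwise raise so that misconfiguration fails loudly
    # instead of silently falling back to another backend.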
def _return_or_raise(
backend: Fp8MoeBackend,
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit FlashInfer FP8 configuration.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP8"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP8:
            # If the user explicitly rejects FlashInfer, remove its backends.
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_TRTLLM)
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.FLASHINFER_CUTLASS)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
            # If the user is explicit about the backend, validate it.
fi_backend = get_flashinfer_moe_backend()
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
backend = Fp8MoeBackend.FLASHINFER_TRTLLM
supported, reason = is_supported_config_trtllm_fp8(
config, weight_key, activation_key, activation_format
)
if supported:
                logger.info_once(_make_log_backend(backend), scope="local")
return backend, None
else:
raise ValueError(_make_log_unsupported(backend, reason))
elif fi_backend == FlashinferMoeBackend.CUTLASS:
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
else:
assert fi_backend == FlashinferMoeBackend.CUTEDSL
raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
else:
# If the user is not explicit about the backend, try both.
for backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
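            # FLASHINFER_TRTLLM has no modular kernel class, so it is
            # validated with a standalone support check.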
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(
_make_log_unsupported(backend, reason), scope="local"
)
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP8=1, but no "
"FlashInfer FP8 MoE backend supports the configuration."
)
# Handle explicit DeepGEMM FP8 configuration.
if envs.is_set("VLLM_USE_DEEP_GEMM") or envs.is_set("VLLM_MOE_USE_DEEP_GEMM"):
if not envs.VLLM_USE_DEEP_GEMM or not envs.VLLM_MOE_USE_DEEP_GEMM:
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.DEEPGEMM)
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_DEEPGEMM)
else:
backend = (
Fp8MoeBackend.DEEPGEMM
if activation_format == mk.FusedMoEActivationFormat.Standard
else Fp8MoeBackend.BATCHED_DEEPGEMM
)
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
# Handle explicit MARLIN FP8 configuration.
if envs.VLLM_TEST_FORCE_FP8_MARLIN:
backend = Fp8MoeBackend.MARLIN
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
# Handle explicit AITER FP8 configuration.
if envs.is_set("VLLM_ROCM_USE_AITER") or envs.is_set("VLLM_ROCM_USE_AITER_MOE"):
if not envs.VLLM_ROCM_USE_AITER or not envs.VLLM_ROCM_USE_AITER_MOE:
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.AITER)
else:
backend = Fp8MoeBackend.AITER
return _return_or_raise(
backend, config, weight_key, activation_key, activation_format
)
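    # The vLLM CUTLASS experts are opt-in via the allow_vllm_cutlass
    # argument; drop them from the candidate list otherwise.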
if not allow_vllm_cutlass:
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.VLLM_CUTLASS)
AVAILABLE_BACKENDS.remove(Fp8MoeBackend.BATCHED_VLLM_CUTLASS)
    # Otherwise, try the remaining backends in priority order and pick the
    # first one that supports the configuration.
for backend in AVAILABLE_BACKENDS:
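        # As above, FLASHINFER_TRTLLM is validated with its standalone
        # helper and returns no kernel class.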
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
k_cls = None
supported, reason = is_supported_config_trtllm_fp8(
config,
weight_key,
activation_key,
activation_format,
)
else:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls,
config,
weight_key,
activation_key,
activation_format,
)
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
return backend, k_cls
else:
logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
    # TODO(rob): per discussion with the TPU team, we need a way for OOT
    # plugins to register MoE backends, rather than maintaining an explicit
    # AVAILABLE_BACKENDS list. Returning `Fp8MoeBackend.NONE` here is a
    # temporary measure until those registration APIs are complete.
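    # On CUDA and ROCm, failing to find a backend is a hard error; other
    # platforms fall through and return NONE.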
if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
"No FP8 MoE backend supports the deployment configuration."
)
return Fp8MoeBackend.NONE, None