Skip to content

vllm.model_executor.layers.fused_moe.cutlass_moe

CUTLASS based Fused MoE kernels.

CutlassBatchedExpertsFp8

Bases: CutlassExpertsFp8Base

Batched CUTLASS FP8 fused MoE expert implementation.

Source code in vllm/model_executor/layers/fused_moe/cutlass_moe.py
class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
    """CUTLASS FP8 fused MoE experts that use the batched activation format."""

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Accept every parallel configuration.

        The BATCHED activation format works with EP because expert_map is
        not used to identify experts (that information is encoded/managed
        by the P/F logic).
        """
        return True

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format handled by this class (BatchedExperts)."""
        return mk.FusedMoEActivationFormat.BatchedExperts

    def supports_chunking(self) -> bool:
        """Report that chunking is not supported."""
        return False

    def supports_expert_map(self) -> bool:
        """Report that an expert map is not supported."""
        return False

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """Return ``self.out_dtype`` when it is set, otherwise ``act_dtype``."""
        if self.out_dtype is not None:
            return self.out_dtype
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the ``(workspace1, workspace2, output)`` shapes.

        Every shape is three dimensional with the leading dimension taken
        from ``self.moe_config.num_local_experts``. The two workspaces have
        ``M * self.num_dispatchers`` rows; the output has ``M`` rows and
        ``K`` columns.

        ``topk``, ``global_num_experts``, ``local_num_experts`` and
        ``expert_tokens_meta`` are not read by this implementation.
        """
        dispatchers = self.num_dispatchers
        assert dispatchers is not None
        local_experts = self.moe_config.num_local_experts
        act_out = self.adjust_N_for_activation(N, activation)
        # Both workspaces share the same row count.
        rows = M * dispatchers
        return (
            (local_experts, rows, max(N, K)),
            (local_experts, rows, max(act_out, K)),
            (local_experts, M, K),
        )

CutlassExpertsFp4

Bases: FusedMoEPermuteExpertsUnpermute

CUTLASS FP4 fused MoE expert implementation.

Source code in vllm/model_executor/layers/fused_moe/cutlass_moe.py
class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
    """Fused MoE experts backed by the CUTLASS FP4 kernels."""

    @property
    def expects_unquantized_inputs(self) -> bool:
        """Always ``True`` for this implementation."""
        return True

    @staticmethod
    def _supports_current_device() -> bool:
        """Return whether the current platform can run this implementation.

        Requires a CUDA platform whose device capability family is one of
        100, 110 or 120.
        """
        platform = current_platform
        return platform.is_cuda() and any(
            platform.is_device_capability_family(family)
            for family in (100, 110, 120)
        )

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        """Report that the no-act-and-mul mode is not supported."""
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only the (kNvfp4Static, kNvfp4Dynamic) weight/activation pair."""
        supported_scheme = (kNvfp4Static, kNvfp4Dynamic)
        return (weight_key, activation_key) == supported_scheme

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        """Return whether ``activation`` is SILU, GELU or SWIGLUOAI."""
        supported = (
            MoEActivation.SILU,
            MoEActivation.GELU,
            MoEActivation.SWIGLUOAI,
        )
        return activation in supported

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Return whether the parallel configuration is usable (``ep_size == 1``)."""
        # CutlassExpertsFp4 does not support expert map, which is
        # needed for STANDARD activation format kernels in EP mode.
        return moe_parallel_config.ep_size == 1

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format handled by this class (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    def supports_expert_map(self) -> bool:
        """Report that an expert map is not supported."""
        return False

    def supports_chunking(self) -> bool:
        """Report that chunking is supported."""
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a ``TopKWeightAndReduceNoOP`` instance."""
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """Return ``act_dtype`` unchanged."""
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the ``(workspace1, workspace2, output)`` shapes.

        Both workspaces have ``M * topk`` rows; the output is ``(M, K)``.
        ``global_num_experts``, ``local_num_experts``, ``expert_tokens_meta``
        and ``activation`` are not read by this implementation.
        """
        rows = M * topk
        return (
            (rows, max(2 * N, K)),
            (rows, N),
            (M, K),
        )

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,  # unused
        a2_scale: torch.Tensor | None,  # unused
        workspace13: torch.Tensor | None,
        workspace2: torch.Tensor | None,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        """Run the FP4 MoE computation by delegating to ``run_cutlass_moe_fp4``.

        The global scales, block scales and alphas are taken from this
        object (``a1_gscale``, ``a2_gscale``, ``w1_scale``, ``w2_scale``,
        ``g1_alphas``, ``g2_alphas``). ``global_num_experts``,
        ``expert_map``, ``a1q_scale``, ``a2_scale`` and
        ``expert_tokens_meta`` are not read here.
        """
        # The `n` reported by moe_problem_size is discarded and recomputed
        # from w2 below.
        num_experts, num_tokens, _, hidden_dim, _ = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )
        # NOTE(review): presumably w2 packs two fp4 values per element along
        # its last dim, hence the factor of two -- confirm.
        intermediate_dim = w2.shape[2] * 2

        run_cutlass_moe_fp4(
            output=output,
            a=hidden_states,
            a1_gscale=self.a1_gscale,
            w1_fp4=w1,
            w1_blockscale=self.w1_scale,
            w1_alphas=self.g1_alphas,
            a2_gscale=self.a2_gscale,
            w2_fp4=w2,
            w2_blockscale=self.w2_scale,
            w2_alphas=self.g2_alphas,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            activation=activation,
            workspace13=workspace13,
            workspace2=workspace2,
            m=num_tokens,
            n=intermediate_dim,
            k=hidden_dim,
            e=num_experts,
            device=hidden_states.device,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )

CutlassExpertsFp8

Bases: CutlassExpertsFp8Base

CUTLASS FP8 fused MoE expert implementation.

Source code in vllm/model_executor/layers/fused_moe/cutlass_moe.py
class CutlassExpertsFp8(CutlassExpertsFp8Base):
    """CUTLASS FP8 fused MoE experts that use the standard activation format."""

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        """Return the activation format handled by this class (Standard)."""
        return mk.FusedMoEActivationFormat.Standard

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        """Reject configurations using FI all2allv or DeepEP HT kernels."""
        # CutlassExpertsFp8 does not support expert map, which is
        # needed for STANDARD activation format kernels in DP/EP mode.
        # Note that the BATCHED activation format does not use
        # the expert map for identifying experts.
        return (
            not moe_parallel_config.use_fi_all2allv_kernels
            and not moe_parallel_config.use_deepep_ht_kernels
        )

    def supports_chunking(self) -> bool:
        """Report that chunking is supported."""
        return True

    def supports_expert_map(self) -> bool:
        """Report that an expert map is not supported."""
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        """Return a ``TopKWeightAndReduceNoOP`` instance."""
        # topk weights and reduction are fused in moe_unpermute cuda kernel
        return TopKWeightAndReduceNoOP()

    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
        """Return ``self.out_dtype`` when it is set, otherwise ``act_dtype``."""
        if self.out_dtype is not None:
            return self.out_dtype
        return act_dtype

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        """Return the ``(workspace1, workspace2, output)`` shapes.

        Both workspaces have ``M * topk`` rows; the output is ``(M, K)``.
        ``global_num_experts``, ``local_num_experts`` and
        ``expert_tokens_meta`` are not read by this implementation.
        """
        act_out = self.adjust_N_for_activation(N, activation)
        rows = M * topk
        return (
            (rows, max(N, K)),
            (rows, max(act_out, K)),
            (M, K),
        )

cutlass_moe_w4a8_fp8

cutlass_moe_w4a8_fp8(
    a: Tensor,
    w1_q: Tensor,
    w2_q: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    a_strides1: Tensor,
    a_strides2: Tensor,
    b_strides1: Tensor,
    b_strides2: Tensor,
    c_strides1: Tensor,
    c_strides2: Tensor,
    s_strides1: Tensor,
    s_strides2: Tensor,
    quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    activation: MoEActivation = SILU,
    expert_map: Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    group_size: int = 128,
) -> Tensor

This function computes a w4a8-quantized Mixture of Experts (MoE) layer using two sets of quantized weights, w1_q and w2_q, and a top-k gating mechanism. The matrix multiplications are implemented with the CUTLASS mixed-dtype grouped GEMM.

  • a (torch.Tensor): The input tensor to the MoE layer. Shape: [M, K]
  • w1_q (torch.Tensor): The first set of fp8-quantized expert weights. Shape: [num_experts, 2*N, K // packed_factor]
  • w2_q (torch.Tensor): The second set of fp8-quantized expert weights. Shape: [num_experts, K, N // packed_factor]
  • topk_weights (torch.Tensor): The weights of each token->expert mapping.
  • topk_ids (torch.Tensor): The token->expert mappings.
  • a_strides1 (torch.Tensor): The input strides for the first gemm. Shape: [num_experts]
  • a_strides2 (torch.Tensor): The input strides for the second gemm. Shape: [num_experts]
  • b_strides1 (torch.Tensor): The packed layout for the first gemm weights. Shape: [num_experts, 3] dtype: torch.int32
  • b_strides2 (torch.Tensor): The packed layout for the second gemm weights. Shape: [num_experts, 3] dtype: torch.int32
  • c_strides1 (torch.Tensor): The output strides for the first gemm. Shape: [num_experts]
  • c_strides2 (torch.Tensor): The output strides for the second gemm. Shape: [num_experts]
  • s_strides1 (torch.Tensor): strides for the group-wise scales for the first gemm. Shape: [num_experts, 2] dtype: torch.int64
  • s_strides2 (torch.Tensor): strides for the group-wise scales for the second gemm. Shape: [num_experts, 2] dtype: torch.int64
  • per_act_token (Optional[bool]): Whether the scale is per-token or per-tensor.
  • activation (MoEActivation): The activation function to use.
  • expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] is -1, it means that this Rank is not responsible for global expert-id i.
  • apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1.
  • global_num_experts (int): The total number of experts.
  • group_size (int): The number of weights per scale factor

Returns: - torch.Tensor: The bf16 output tensor after applying the MoE layer.

Source code in vllm/model_executor/layers/fused_moe/cutlass_moe.py
def cutlass_moe_w4a8_fp8(
    a: torch.Tensor,
    w1_q: torch.Tensor,
    w2_q: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    a_strides1: torch.Tensor,
    a_strides2: torch.Tensor,
    b_strides1: torch.Tensor,
    b_strides2: torch.Tensor,
    c_strides1: torch.Tensor,
    c_strides2: torch.Tensor,
    s_strides1: torch.Tensor,
    s_strides2: torch.Tensor,
    quant_config: FusedMoEQuantConfig,
    moe_config: FusedMoEConfig,
    activation: MoEActivation = MoEActivation.SILU,
    expert_map: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    group_size: int = 128,
) -> torch.Tensor:
    """
    Compute a w4a8-quantized Mixture of Experts (MoE) layer from two sets of
    quantized weights, w1_q and w2_q, and a top-k gating mechanism. The
    matrix multiplications are implemented with CUTLASS mixed-dtype grouped
    gemm.

    The work is done by a `mk.FusedMoEModularKernel` assembled from
    `MoEPrepareAndFinalizeNoEP` and `CutlassExpertsW4A8Fp8`, created with
    `inplace=False`.

    Parameters:
    - a (torch.Tensor): The input tensor to the MoE layer.
        Shape: [M, K]
    - w1_q (torch.Tensor): The first set of quantized expert weights.
        Shape: [num_experts, 2*N, K // packed_factor]
    - w2_q (torch.Tensor): The second set of quantized expert weights.
        Shape: [num_experts, K, N // packed_factor]
    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
    - topk_ids (torch.Tensor): The token->expert mappings.
    - a_strides1 / a_strides2 (torch.Tensor): The input strides for the
        first / second gemm.
        Shape: [num_experts]
    - b_strides1 / b_strides2 (torch.Tensor): The packed layout for the
        first / second gemm weights.
        Shape: [num_experts, 3]
        dtype: torch.int32
    - c_strides1 / c_strides2 (torch.Tensor): The output strides for the
        first / second gemm.
        Shape: [num_experts]
    - s_strides1 / s_strides2 (torch.Tensor): Strides for the group-wise
        scales for the first / second gemm.
        Shape: [num_experts, 2]
        dtype: torch.int64
    - quant_config (FusedMoEQuantConfig): Quantization configuration
        forwarded to the experts implementation. Must not be None.
    - moe_config (FusedMoEConfig): MoE configuration forwarded to the
        experts implementation.
    - activation (MoEActivation): The activation function to use.
    - expert_map (Optional[torch.Tensor]): In the case of Expert parallel,
        every Rank is responsible for a subset of experts. expert_map is a
        mapping from global expert-id to local expert-id. When expert_map[i]
        is -1, it means that this Rank is not responsible for global
        expert-id i.
    - apply_router_weight_on_input (bool): When true, the topk weights are
        applied directly on the inputs. This is only applicable when topk is 1.
    - global_num_experts (int): The total number of experts. When -1, the
        leading dimension of w1_q is used instead.
    - group_size (int): The number of weights per scale factor

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer. The
        experts are constructed with `out_dtype=a.dtype`.
        NOTE(review): earlier documentation described the output as bf16 --
        confirm against callers.
    """
    assert quant_config is not None

    # -1 is the "not provided" sentinel: fall back to the expert count
    # carried by the weights.
    if global_num_experts == -1:
        num_experts = w1_q.size(0)
    else:
        num_experts = global_num_experts

    prepare_finalize = MoEPrepareAndFinalizeNoEP()
    experts = CutlassExpertsW4A8Fp8(
        out_dtype=a.dtype,
        a_strides1=a_strides1,
        a_strides2=a_strides2,
        b_strides1=b_strides1,
        b_strides2=b_strides2,
        c_strides1=c_strides1,
        c_strides2=c_strides2,
        s_strides1=s_strides1,
        s_strides2=s_strides2,
        moe_config=moe_config,
        quant_config=quant_config,
        group_size=group_size,
    )
    kernel = mk.FusedMoEModularKernel(prepare_finalize, experts, inplace=False)

    return kernel(
        a,
        w1_q,
        w2_q,
        topk_weights,
        topk_ids,
        activation=activation,
        global_num_experts=num_experts,
        expert_map=expert_map,
        apply_router_weight_on_input=apply_router_weight_on_input,
    )

run_cutlass_moe_fp4

run_cutlass_moe_fp4(
    output: Tensor,
    a: Tensor,
    a1_gscale: Tensor,
    w1_fp4: Tensor,
    w1_blockscale: Tensor,
    w1_alphas: Tensor,
    a2_gscale: Tensor,
    w2_fp4: Tensor,
    w2_blockscale: Tensor,
    w2_alphas: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    activation: MoEActivation,
    workspace13: Tensor,
    workspace2: Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: device,
    apply_router_weight_on_input: bool = False,
) -> None

MoE implementation for FP4 Inputs

Gemm 1

a: Input tensor: [m, k] (half/bfloat16) a1_gscale: Activation scale per expert: [e] (float32) w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) (Note: n is the up projection output dim, k is the input dim in full precision) w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) (Block size = 16 for NVFP4)

Gemm 2

a2_gscale: Activation scale per expert: [e] w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3

topk_weights: [m, topk] dtype: float8 topk_ids: [m, topk] dtype: float8

m, n, k: Unquantized weight shapes, dtype: int e: number of experts, dtype: int

Assumes that topk < k < n, to satisfy the up/down projection expectations.

Source code in vllm/model_executor/layers/fused_moe/cutlass_moe.py
def run_cutlass_moe_fp4(
    output: torch.Tensor,
    a: torch.Tensor,
    a1_gscale: torch.Tensor,
    w1_fp4: torch.Tensor,
    w1_blockscale: torch.Tensor,
    w1_alphas: torch.Tensor,
    a2_gscale: torch.Tensor,
    w2_fp4: torch.Tensor,
    w2_blockscale: torch.Tensor,
    w2_alphas: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: MoEActivation,
    workspace13: torch.Tensor,
    workspace2: torch.Tensor,
    m: int,
    n: int,
    k: int,
    e: int,
    device: torch.device,
    apply_router_weight_on_input: bool = False,
) -> None:
    """
    MoE implementation for FP4 Inputs

    The result is written into `output`; nothing is returned.

    # Gemm 1
    a: Input tensor: [m, k] (half/bfloat16). When
       `apply_router_weight_on_input` is set, `a` is multiplied by
       `topk_weights` IN PLACE before being permuted.
    a1_gscale: Activation scale per expert: [e]  (float32)
    w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k]
    w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1)
    (Note: `n` is the up projection output dim, `k` is the input dim in
     full precision)
    w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3)
                   (Block size = 16 for NVFP4)
    w1_alphas: passed through to the first grouped gemm.

    # Gemm 2
    a2_gscale: Activation scale per expert: [e]
    w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n]
    w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1)
    w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3
    w2_alphas: passed through to the second grouped gemm.

    topk_weights: [m, topk]; cast to `a.dtype` before use.
    topk_ids: [m, topk]; the dtype is not checked here.

    output: [m, k] destination tensor; its dtype must equal `a.dtype`.
    activation: SILU takes the fused SiLU+Mul+NVFP4 quantization path; any
        other value goes through `apply_moe_activation` followed by a
        separate quantization step.
    workspace13: scratch buffer viewed as [m * topk, 2 * n] for the first
        gemm output and as [m * topk, k] for the second gemm output.
    workspace2: scratch buffer viewed as [m * topk, n] for the activation
        output (only written on the non-SILU path).

    m, n, k: Unquantized weight shapes, dtype: int
    e: number of experts, dtype: int
    device: device on which the bookkeeping tensors are allocated.
    apply_router_weight_on_input: apply the router weights to the input
        instead of to the output. Only implemented for topk == 1.

    assumes that topk < k < n to satisfy - up/down projection expectations.

    Raises AssertionError when any of the shape/dtype checks below fails.
    """
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8"
    assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8"
    assert (
        w1_fp4.ndim == 3
        and w2_fp4.ndim == 3
        and w1_blockscale.ndim == 3
        and w2_blockscale.ndim == 3
    ), "All Weights must be of rank 3 for cutlass_moe_fp4"
    m_a, k_a = a.shape
    e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
    e_w2, k_w2, half_n_w2 = w2_fp4.shape

    # The message is a single string (implicit concatenation). A comma
    # between the two literals would turn it into a tuple.
    assert e_w1 == e_w2 and e_w1 == e, (
        "Number of experts must match"
        f" between weights. {e_w1}, {e_w2}, {e}"
    )
    assert k_a == half_k_w1 * 2 and k == k_w2, (
        "Hidden size mismatch between a, w1 and w2"
    )
    assert nx2_w1 == n * 2 and half_n_w2 * 2 == n, "mismatch in expected `n`"
    assert m == m_a, "input shape mismatch"
    assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
    assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
    assert topk_weights.size(0) == m and topk_ids.size(0) == m, (
        "topk must be provided for each row of a"
    )
    out_dtype = a.dtype
    num_topk = topk_ids.size(1)

    expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    blockscale_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
    # Problem size:  (num_experts, (m,2n,k))
    problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
    # Problem size:  (num_experts, (m,n,k))
    problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)

    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

    if apply_router_weight_on_input:
        # TODO: this only works for topK=1, will need to update for topK>1
        assert num_topk == 1, (
            "apply_router_weight_on_input is only implemented for topk=1"
        )
        # In-place: the caller's tensor is modified.
        a.mul_(topk_weights.to(out_dtype))

    # problem shapes should have [m, n, k]
    # Note that problem sizes are based on logical number of elements.
    ops.get_cutlass_moe_mm_data(
        topk_ids,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
        a_map,
        c_map,
        e,
        n,
        k,
        blockscale_offsets,
    )

    a = ops.shuffle_rows(a, a_map)
    rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant(
        a,
        a1_gscale,
        expert_offsets,
        blockscale_offsets,
        num_topk,
    )
    # c1 and c3 are both carved out of workspace13; c3 is only written
    # after c1 has been consumed by the activation/quantization step.
    c1 = _resize_cache(workspace13, (m * num_topk, n * 2))
    c2 = _resize_cache(workspace2, (m * num_topk, n))
    c3 = _resize_cache(workspace13, (m * num_topk, k))
    ops.cutlass_fp4_moe_mm(
        c1,
        rep_a_fp4,
        w1_fp4,
        rep_a_blockscale,
        w1_blockscale,
        w1_alphas,
        problem_sizes1,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del rep_a_fp4, rep_a_blockscale
    if activation == MoEActivation.SILU:
        # Fused SiLU+Mul+NVFP4 quantization
        # Note: c2 workspace is no longer needed since SiLU is fused with quantization.
        # c3 reuses workspace13 after c1 is consumed.
        int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
            c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
        )
    else:
        apply_moe_activation(activation, c2, c1)
        int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
            c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
        )

    ops.cutlass_fp4_moe_mm(
        c3,
        int_fp4,
        w2_fp4,
        int_blockscale,
        w2_blockscale,
        w2_alphas,
        problem_sizes2,
        expert_offsets[:-1],
        blockscale_offsets[:-1],
    )
    del int_fp4, int_blockscale

    # Undo the expert-major permutation applied to the rows of `a`.
    c3 = ops.shuffle_rows(c3, c_map)

    assert output.dtype == out_dtype
    per_choice = c3.view(m, num_topk, k)
    if not apply_router_weight_on_input:
        # Router weights were not folded into the input, so apply them to
        # each expert's contribution before reducing over topk.
        per_choice = per_choice * topk_weights.view(m, num_topk, 1).to(out_dtype)
    output.copy_(per_choice.sum(dim=1), non_blocking=True)