
vllm.model_executor.layers.fused_moe.fused_marlin_moe

Fused MoE utilities for GPTQ.

BatchedMarlinExperts

Bases: MarlinExpertsBase

Batched Marlin-based fused MoE expert implementation.

Source code in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
class BatchedMarlinExperts(MarlinExpertsBase):
    """Batched Marlin-based fused MoE expert implementation."""

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int,
        num_dispatchers: int,
        w13_g_idx: torch.Tensor | None = None,
        w2_g_idx: torch.Tensor | None = None,
        w13_g_idx_sort_indices: torch.Tensor | None = None,
        w2_g_idx_sort_indices: torch.Tensor | None = None,
        is_k_full: bool = True,
    ):
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
            w13_g_idx=w13_g_idx,
            w2_g_idx=w2_g_idx,
            w13_g_idx_sort_indices=w13_g_idx_sort_indices,
            w2_g_idx_sort_indices=w2_g_idx_sort_indices,
            is_k_full=is_k_full,
        )

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceDelegate()

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    def supports_chunking(self) -> bool:
        return False

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        assert self.num_dispatchers is not None
        assert self.max_num_tokens is not None
        num_dispatchers = self.num_dispatchers
        num_experts = local_num_experts
        max_num_tokens = self.max_num_tokens
        workspace13 = (num_experts * max_num_tokens * num_dispatchers, max(K, N * 2))
        workspace2 = (num_experts * max_num_tokens * num_dispatchers, N)
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert expert_tokens_meta is not None, "Num valid tokens per batch is required"
        return batched_fused_marlin_moe(
            hidden_states=hidden_states,
            expert_num_tokens=expert_tokens_meta.expert_num_tokens,
            w1=w1,
            w2=w2,
            bias1=self.w1_bias,
            bias2=self.w2_bias,
            w1_scale=self.w1_scale,
            w2_scale=self.w2_scale,
            quant_type_id=self.quant_type_id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            expert_map=expert_map,
            output=output,
            intermediate_cache13=workspace13,
            intermediate_cache2=workspace2,
            g_idx1=self.w13_g_idx,
            g_idx2=self.w2_g_idx,
            sort_indices1=self.w13_g_idx_sort_indices,
            sort_indices2=self.w2_g_idx_sort_indices,
            is_k_full=self.is_k_full,
        )
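
For a concrete feel of the buffer sizes above, the following sketch evaluates the same shape arithmetic as BatchedMarlinExperts.workspace_shapes() for hypothetical values of K, N, local_num_experts, max_num_tokens and num_dispatchers (the numbers are illustrative only):

# Hypothetical sizes, for illustration only.
K, N = 4096, 14336                 # hidden size and per-partition intermediate size
local_num_experts = 8              # experts held by this rank
max_num_tokens = 256               # max tokens per expert batch
num_dispatchers = 2                # number of dispatching ranks

rows = local_num_experts * max_num_tokens * num_dispatchers
workspace13 = (rows, max(K, N * 2))   # holds the w1/w3 GEMM outputs
workspace2 = (rows, N)                # holds the activation output
output = (local_num_experts, max_num_tokens * num_dispatchers, K)
print(workspace13, workspace2, output)
# (4096, 28672) (4096, 14336) (8, 512, 4096)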

MarlinExperts

Bases: MarlinExpertsBase

Marlin-based fused MoE expert implementation.

Source code in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
class MarlinExperts(MarlinExpertsBase):
    """Marlin-based fused MoE expert implementation."""

    def supports_expert_map(self) -> bool:
        return True

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

    def supports_chunking(self) -> bool:
        return True

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # The Modular Kernel provisions the output buffer from workspace1.
        # However, in the fused_marlin_moe() function the final torch.sum()
        # is defined essentially as,
        # `torch.sum(workspace1, dim=1, out=output)`
        # Having overlapping input and output tensors for torch.sum seems
        # error-prone and depends on how torch.sum is implemented.
        # For this reason we swap the workspaces and let the output buffer be
        # provisioned from workspace2.

        # Workspace/IntermediateCache allocation matching fused_marlin_moe()
        # workspace1 = (M * topk * max(2 * N, K),)
        # workspace2 = (M * topk, N)

        # Workspace/IntermediateCache allocation accounting for output buffer
        # provisioning
        workspace1 = (M * topk, max(N, K))
        workspace2 = (M * topk * max(2 * N, K),)
        output = (M, K)

        return (workspace1, workspace2, output)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert self.w1_scale is not None
        assert self.w2_scale is not None
        return fused_marlin_moe(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            bias1=self.w1_bias,
            bias2=self.w2_bias,
            w1_scale=self.w1_scale,
            w2_scale=self.w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            global_scale1=self.g1_alphas,
            global_scale2=self.g2_alphas,
            quant_type_id=self.quant_type_id,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            activation=activation,
            activation_func=self.activation,
            moe_sum=self.moe_sum,
            expert_map=expert_map,
            output=output,
            # Workspaces are swapped in workspace_shapes() to account for proper
            # output buffer allocation. Please refer to workspace_shapes().
            intermediate_cache13=workspace2,
            intermediate_cache2=workspace13,
            g_idx1=self.w13_g_idx,
            g_idx2=self.w2_g_idx,
            sort_indices1=self.w13_g_idx_sort_indices,
            sort_indices2=self.w2_g_idx_sort_indices,
            is_k_full=self.is_k_full,
        )

    def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)
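
To see the swap described in workspace_shapes() in concrete numbers, the sketch below plugs hypothetical M, N, K and topk values into the same arithmetic (illustrative only): workspace1 is the smaller per-(token, expert) buffer from which the Modular Kernel also provisions the output, while workspace2 is the large flat buffer that fused_marlin_moe() uses as intermediate_cache13.

# Hypothetical sizes, for illustration only.
M, topk = 128, 2        # tokens and experts per token
N, K = 14336, 4096      # intermediate and hidden sizes

workspace1 = (M * topk, max(N, K))        # handed to fused_marlin_moe as intermediate_cache2
workspace2 = (M * topk * max(2 * N, K),)  # handed to fused_marlin_moe as intermediate_cache13
output = (M, K)
print(workspace1, workspace2, output)
# (256, 14336) (7340032,) (128, 4096)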

batched_fused_marlin_moe

batched_fused_marlin_moe(
    hidden_states: Tensor,
    expert_num_tokens: Tensor,
    w1: Tensor,
    w2: Tensor,
    bias1: Tensor | None,
    bias2: Tensor | None,
    w1_scale: Tensor,
    w2_scale: Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: MoEActivation = SILU,
    expert_map: Tensor | None = None,
    global_scale1: Tensor | None = None,
    global_scale2: Tensor | None = None,
    g_idx1: Tensor | None = None,
    g_idx2: Tensor | None = None,
    sort_indices1: Tensor | None = None,
    sort_indices2: Tensor | None = None,
    w1_zeros: Tensor | None = None,
    w2_zeros: Tensor | None = None,
    workspace: Tensor | None = None,
    intermediate_cache13: Tensor | None = None,
    intermediate_cache2: Tensor | None = None,
    is_k_full: bool = True,
    output: Tensor | None = None,
    inplace: bool = False,
) -> Tensor

This function massages the inputs so the batched hidden_states can be presented as a 2D contiguous tensor that could be used with _fused_marlin_moe.

Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately use ops.moe_wna16_marlin_gemm for the GEMM operation, and ops.moe_wna16_marlin_gemm supports only 2D contiguous hidden_states. Note that the moe_align_block_size function indicates:

- What rows of the A matrix (hidden_states) to access during the matmul, via the sorted_ids output.
- What expert_id to use for each block matmul, via the expert_ids output.

In the batched version, the tokens are already grouped/batched by the experts they subscribe to. Because of this, we can represent the batched hidden_states tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape [B * MAX_TOKENS_PER_BATCH, K]. We may treat this as a 2D contiguous tensor with topk=1, as each token (row in the tensor) subscribes to exactly one expert_id (which is the batch_id). Using the expert_num_tokens tensor, which indicates how many tokens are actually valid in each batch, the batched_moe_align_block_size function constructs the sorted_ids and expert_ids tensors so that only relevant/valid rows of A (hidden_states) are accessed and processed with the correct expert_ids.
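
As a rough illustration of that reshaping, here is a minimal torch-only sketch with made-up sizes (it mimics the layout, not the actual batched_moe_align_block_size implementation):

import torch

B, MAX_TOKENS_PER_BATCH, K = 4, 8, 16            # hypothetical sizes
hidden_states = torch.randn(B, MAX_TOKENS_PER_BATCH, K)
expert_num_tokens = torch.tensor([3, 0, 8, 5])   # valid tokens in each expert batch

# View the batched tensor as a plain 2D activation matrix; each row now
# belongs to exactly one expert, so the implicit topk is 1.
a2d = hidden_states.view(-1, K)                  # [B * MAX_TOKENS_PER_BATCH, K]
row_expert = torch.arange(B).repeat_interleave(MAX_TOKENS_PER_BATCH)
row_valid = (torch.arange(MAX_TOKENS_PER_BATCH)[None, :]
             < expert_num_tokens[:, None]).reshape(-1)
# Only rows where row_valid is True would be gathered into sorted_ids and
# paired with expert_ids taken from row_expert by the alignment step.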

Source code in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
def batched_fused_marlin_moe(
    hidden_states: torch.Tensor,
    expert_num_tokens: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: torch.Tensor | None,
    bias2: torch.Tensor | None,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: MoEActivation = MoEActivation.SILU,
    expert_map: torch.Tensor | None = None,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
    g_idx1: torch.Tensor | None = None,
    g_idx2: torch.Tensor | None = None,
    sort_indices1: torch.Tensor | None = None,
    sort_indices2: torch.Tensor | None = None,
    w1_zeros: torch.Tensor | None = None,
    w2_zeros: torch.Tensor | None = None,
    workspace: torch.Tensor | None = None,
    intermediate_cache13: torch.Tensor | None = None,
    intermediate_cache2: torch.Tensor | None = None,
    is_k_full: bool = True,
    output: torch.Tensor | None = None,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function massages the inputs so the batched hidden_states can be
    presented as a 2D contiguous tensor that could be used with
    _fused_marlin_moe.

    Note that both batched_fused_marlin_moe and fused_marlin_moe ultimately
    use `ops.moe_wna16_marlin_gemm` for the gemm operation and
    `ops.moe_wna16_marlin_gemm` supports only 2D contiguous hidden_states.
    Note that the moe_align_block_size function indicates,
        - What rows of the A matrix (hidden_states) to access during the
        matmul, via sorted_ids output.
        - What expert_id to use for each block matmul, via expert_ids output.

    In the batched version, the tokens are already grouped/batched by experts
    they subscribe to. Due to this, we can represent the batched hidden_states
    tensor of shape [B, MAX_TOKENS_PER_BATCH, K] as a 2D tensor of shape
    [B * MAX_TOKENS_PER_BATCH, K]. We may treat this as a 2D contiguous tensor
    with topk=1, as each token (row in the tensor) subscribes to exactly one
    expert_id (which is the batch_id). Using the expert_num_tokens tensor, which
    indicates how many tokens are actually valid in each batch, the
    batched_moe_align_block_size function constructs the sorted_ids and
    expert_ids tensors, so only relevant/valid rows of A (hidden_states)
    are accessed and are processed with the correct expert_ids.
    """

    assert hidden_states.ndim == 3, (
        f"hidden states must be batched. e.g. [B, MAX_TOKENS, K]."
        f"But got {hidden_states.size()}"
    )
    if inplace:
        assert output is None, "Conflicting request."

    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [
        scalar_types.uint4,
        scalar_types.uint8b128,
        scalar_types.uint4b8,
        scalar_types.float8_e4m3fn,
        scalar_types.float4_e2m1f,
    ]

    bit4_scalar_types = [
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.float4_e2m1f,
    ]
    num_bits = 4 if quant_type in bit4_scalar_types else 8

    B, BATCH_TOKENS_MAX, K = hidden_states.size()
    M = hidden_states.view(-1, K).size(0)
    E = w1.size(0)

    # Check constraints.
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert expert_num_tokens.size(0) == E
    assert B == E, (
        "Batch must be as big as number of experts as the tokens"
        "are sorted into the batch/expert they belong to"
    )
    assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
    assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert num_bits in [4, 8]

    # Technically, the tokens are already separated by their expert ids.
    # Hidden-States can just be squeezed to have just 2 dimensions,
    # [B * MAX_TOKENS, K] and top_k can be interpreted as just 1.
    topk = 1

    # TODO(varun) : Choose a decent block size like in fused_marlin_moe
    block_size_m = 64

    sorted_token_ids, expert_ids, num_tokens_post_padded = batched_moe_align_block_size(
        max_tokens_per_batch=BATCH_TOKENS_MAX,
        block_size=block_size_m,
        expert_num_tokens=expert_num_tokens,
    )

    if output is None and inplace:
        output = hidden_states

    # TODO (varun): This can be avoided by plumbing the marlin kernel to
    # ignore topk_weights when topk_weights_ptr is a nullptr.
    topk_weights = torch.ones(
        (M, topk), device=hidden_states.device, dtype=torch.float32
    )

    assert activation is not None
    output = _fused_marlin_moe(
        hidden_states=hidden_states.view(-1, K),
        w1=w1,
        w2=w2,
        bias1=bias1,
        bias2=bias2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        topk_weights=topk_weights,
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        activation=activation,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=w1_zeros,
        w2_zeros=w2_zeros,
        workspace=workspace,
        intermediate_cache13=intermediate_cache13,
        intermediate_cache2=intermediate_cache2,
        output=output.view(-1, K) if output is not None else output,
        is_k_full=is_k_full,
    )

    output = output.view(B, BATCH_TOKENS_MAX, K)

    return output

fused_marlin_moe

fused_marlin_moe(
    hidden_states: Tensor,
    w1: Tensor,
    w2: Tensor,
    bias1: Tensor | None,
    bias2: Tensor | None,
    w1_scale: Tensor,
    w2_scale: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: MoEActivation = SILU,
    activation_func: Callable[
        [MoEActivation, Tensor, Tensor], None
    ] = apply_moe_activation,
    moe_sum: Callable[[Tensor, Tensor], None] | None = None,
    expert_map: Tensor | None = None,
    input_global_scale1: Tensor | None = None,
    input_global_scale2: Tensor | None = None,
    global_scale1: Tensor | None = None,
    global_scale2: Tensor | None = None,
    g_idx1: Tensor | None = None,
    g_idx2: Tensor | None = None,
    sort_indices1: Tensor | None = None,
    sort_indices2: Tensor | None = None,
    w1_zeros: Tensor | None = None,
    w2_zeros: Tensor | None = None,
    workspace: Tensor | None = None,
    intermediate_cache13: Tensor | None = None,
    intermediate_cache2: Tensor | None = None,
    is_k_full: bool = True,
    output: Tensor | None = None,
    input_dtype: dtype | None = None,
    inplace: bool = False,
) -> Tensor

This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.

Parameters:

- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- w1_scale (torch.Tensor): Scale to be used for w1.
- w2_scale (torch.Tensor): Scale to be used for w2.
- g_idx1 (torch.Tensor|None): The first set of act_order indices.
- g_idx2 (torch.Tensor|None): The second set of act_order indices.
- sort_indices1 (torch.Tensor|None): The first act_order input permutation.
- sort_indices2 (torch.Tensor|None): The second act_order input permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of the top-k elements.
- w1_zeros (torch.Tensor|None): Optional zero points to be used for w1.
- w2_zeros (torch.Tensor|None): Optional zero points to be used for w2.
- num_bits (int): The number of bits used for expert weight quantization.

Returns:

- torch.Tensor: The output tensor after applying the MoE layer.
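
The topk_weights and topk_ids inputs come from the router's top-k gating. A minimal sketch of producing them from hypothetical router logits with plain torch (vLLM normally computes these with its own routing kernels, so this is only illustrative):

import torch

M, E, topk = 16, 64, 2                     # hypothetical sizes
router_logits = torch.randn(M, E)
topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), k=topk, dim=-1)
topk_weights = topk_weights.to(torch.float32)  # fused_marlin_moe asserts float32 weights
# topk_weights: [M, topk] gating weights; topk_ids: [M, topk] expert indices.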

Source code in vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    bias1: torch.Tensor | None,
    bias2: torch.Tensor | None,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    quant_type_id: int,
    apply_router_weight_on_input: bool = False,
    global_num_experts: int = -1,
    activation: MoEActivation = MoEActivation.SILU,
    activation_func: Callable[
        [MoEActivation, torch.Tensor, torch.Tensor], None
    ] = apply_moe_activation,
    moe_sum: Callable[[torch.Tensor, torch.Tensor], None] | None = None,
    expert_map: torch.Tensor | None = None,
    input_global_scale1: torch.Tensor | None = None,
    input_global_scale2: torch.Tensor | None = None,
    global_scale1: torch.Tensor | None = None,
    global_scale2: torch.Tensor | None = None,
    g_idx1: torch.Tensor | None = None,
    g_idx2: torch.Tensor | None = None,
    sort_indices1: torch.Tensor | None = None,
    sort_indices2: torch.Tensor | None = None,
    w1_zeros: torch.Tensor | None = None,
    w2_zeros: torch.Tensor | None = None,
    workspace: torch.Tensor | None = None,
    intermediate_cache13: torch.Tensor | None = None,
    intermediate_cache2: torch.Tensor | None = None,
    is_k_full: bool = True,
    output: torch.Tensor | None = None,
    input_dtype: torch.dtype | None = None,
    inplace: bool = False,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - g_idx1 (torch.Tensor|None): The first set of act_order indices.
    - g_idx2 (torch.Tensor|None): The second set of act_order indices.
    - sort_indices1 (torch.Tensor|None): The first act_order input
        permutation.
    - sort_indices2 (torch.Tensor|None): The second act_order input
        permutation.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of the top-k elements.
    - w1_zeros (torch.Tensor|None): Optional zero points to be used for w1.
    - w2_zeros (torch.Tensor|None): Optional zero points to be used for w2.
    - num_bits (int): The number of bits used for expert weight quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    if inplace:
        assert output is None, "Conflicting request"
        assert not disable_inplace()

    quant_type = ScalarType.from_id(quant_type_id)
    assert quant_type in [
        scalar_types.uint4,
        scalar_types.uint8b128,
        scalar_types.uint4b8,
        scalar_types.float8_e4m3fn,
        scalar_types.float4_e2m1f,
    ]

    bit4_scalar_types = [
        scalar_types.uint4,
        scalar_types.uint4b8,
        scalar_types.float4_e2m1f,
    ]
    num_bits = 4 if quant_type in bit4_scalar_types else 8

    M, K = hidden_states.size()
    E = w1.size(0)
    topk = topk_ids.size(1)

    # Check constraints.
    assert w1.size(1) * 16 == K, "Hidden size mismatch w1"
    assert w2.size(2) // (num_bits // 2) == K, "Hidden size mismatch w2"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
    assert num_bits in [4, 8]
    assert topk_weights.dtype == torch.float32

    # M block size selection logic
    # TODO: tune this further for specific models
    for block_size_m in [8, 16, 32, 48, 64]:
        if M * topk / E / block_size_m < 0.9:
            break

    if input_dtype is not None and input_dtype.itemsize == 1:
        block_size_m = max(block_size_m, 16)

    if global_num_experts == -1:
        global_num_experts = E
    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids,
        block_size_m,
        global_num_experts,
        expert_map,
        ignore_invalid_experts=True,
    )

    assert activation is not None
    moe_output = _fused_marlin_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        bias1=bias1,
        bias2=bias2,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        topk_weights=topk_weights,
        num_topk=topk,
        quant_type=quant_type,
        apply_router_weight_on_input=apply_router_weight_on_input,
        expert_map=expert_map,
        block_size_m=block_size_m,
        sorted_token_ids=sorted_token_ids,
        expert_ids=expert_ids,
        num_tokens_post_padded=num_tokens_post_padded,
        activation=activation,
        activation_func=activation_func,
        input_global_scale1=input_global_scale1,
        input_global_scale2=input_global_scale2,
        global_scale1=global_scale1,
        global_scale2=global_scale2,
        g_idx1=g_idx1,
        g_idx2=g_idx2,
        sort_indices1=sort_indices1,
        sort_indices2=sort_indices2,
        w1_zeros=w1_zeros,
        w2_zeros=w2_zeros,
        workspace=workspace,
        intermediate_cache13=intermediate_cache13,
        intermediate_cache2=intermediate_cache2,
        output=None,
        input_dtype=input_dtype,
        is_k_full=is_k_full,
    ).view(-1, topk, K)

    if output is None:
        output = hidden_states if inplace else torch.empty_like(hidden_states)

    if moe_sum is None:
        return torch.sum(moe_output.view(-1, topk, K), dim=1, out=output)
    else:
        return moe_sum(moe_output, output)
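
As a side note on the M block size heuristic inside fused_marlin_moe(): it walks the candidate block sizes from small to large and stops at the first one where the average number of tokens routed to each expert occupies less than ~0.9 of a block. A standalone sketch of the same loop (the helper name and token counts are made up for illustration):

def select_block_size_m(M: int, topk: int, E: int) -> int:
    # Same candidate list and stopping rule as in fused_marlin_moe().
    for block_size_m in [8, 16, 32, 48, 64]:
        if M * topk / E / block_size_m < 0.9:
            break
    return block_size_m

print(select_block_size_m(M=8, topk=2, E=64))     # -> 8   (very few tokens per expert)
print(select_block_size_m(M=512, topk=2, E=64))   # -> 32
print(select_block_size_m(M=4096, topk=8, E=64))  # -> 64  (falls through to the largest)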