Skip to content

vllm.model_executor.models.qwen3_omni_moe_thinker

Inference-only Qwen3-Omni-Moe model (thinker part).

Qwen3OmniMoeAudioAttention

Bases: Module

Multi-headed attention for Qwen3-Omni Audio Encoder using MMEncoderAttention.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeAudioAttention(nn.Module):
    """Self-attention block of the Qwen3-Omni audio encoder.

    The input is projected by a fused QKV linear layer, passed through
    ``MMEncoderAttention`` and finally through a row-parallel output
    projection.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        prefix: str = "",
    ):
        """Build the projections and the attention operator.

        Args:
            config: Audio encoder config; ``d_model`` and
                ``encoder_attention_heads`` are read from it.
            prefix: Dotted name prefix forwarded to the sub-layers.

        Raises:
            ValueError: If ``d_model`` is not a multiple of the head count.
        """
        super().__init__()
        self.embed_dim = config.d_model
        self.num_heads = config.encoder_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        # The attention operator below is sized with the per-rank head count.
        self.num_local_heads = (
            self.num_heads // get_tensor_model_parallel_world_size()
        )

        heads_cover_embed_dim = self.num_heads * self.head_dim == self.embed_dim
        if not heads_cover_embed_dim:
            message = (
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {self.num_heads})."
            )
            raise ValueError(message)

        self.scaling = self.head_dim ** (-0.5)

        # Fused Q/K/V projection; the KV head count equals the query head count.
        self.qkv = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            total_num_kv_heads=self.num_heads,
            bias=True,
            prefix=f"{prefix}.qkv",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            bias=True,
            prefix=f"{prefix}.out_proj",
        )

        self.attn = MMEncoderAttention(
            num_heads=self.num_local_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention over a packed 2-D batch of tokens.

        Args:
            hidden_states: 2-D tensor; the first dimension is the token count.
            cu_seqlens: Forwarded unchanged to the attention operator.
            max_seqlen: Forwarded unchanged to the attention operator.

        Returns:
            The output projection of the attention result, with the same
            leading dimension as ``hidden_states``.
        """
        num_tokens, _ = hidden_states.shape
        fused_qkv, _ = self.qkv(hidden_states)

        # Split the fused projection into thirds and expose the head axis,
        # adding a leading batch dimension of one.
        query, key, value = (
            part.view(1, num_tokens, -1, self.head_dim)
            for part in fused_qkv.chunk(3, dim=-1)
        )

        context = self.attn(
            query=query,
            key=key,
            value=value,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        # Fold heads back into the feature dimension before projecting.
        projected, _ = self.out_proj(context.view(num_tokens, -1))
        return projected

Qwen3OmniMoeAudioEncoder

Bases: Module

vLLM-native Qwen3-Omni Audio Encoder.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeAudioEncoder(nn.Module):
    """vLLM-native Qwen3-Omni Audio Encoder.

    Pipeline: split the mel features into fixed-size chunks, downsample them
    with three stride-2 convolutions, add sinusoidal positions, run the
    transformer layers over the packed valid frames, then apply the output
    projection head.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        prefix: str = "",
    ):
        """Build the convolutional front-end, transformer stack and head.

        Args:
            config: Audio encoder configuration.
            prefix: Dotted name prefix forwarded to the transformer layers.
        """
        super().__init__()

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        self.max_source_positions = config.max_source_positions
        self.n_window = config.n_window
        self.n_window_infer = config.n_window_infer
        self.conv_chunksize = config.conv_chunksize

        # Position embedding
        self.positional_embedding = SinusoidsPositionEmbedding(
            self.max_source_positions, embed_dim
        )

        # Convolutional layers for mel-spectrogram processing
        # (kernel 3, stride 2, padding 1 each).
        self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
        self.conv2d2 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            3,
            2,
            padding=1,
        )
        self.conv2d3 = nn.Conv2d(
            config.downsample_hidden_size,
            config.downsample_hidden_size,
            3,
            2,
            padding=1,
        )

        # The mel axis is halved (rounding up) once per convolution; the
        # remaining bins are flattened together with the channels.
        conv_out_dim = config.downsample_hidden_size * (
            (((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2
        )
        self.conv_out = nn.Linear(conv_out_dim, config.d_model, bias=False)

        # Transformer encoder layers
        self.layers = nn.ModuleList(
            [
                Qwen3OmniMoeAudioEncoderLayer(
                    config,
                    prefix=f"{prefix}.layers.{i}",
                )
                for i in range(config.encoder_layers)
            ]
        )

        # Output layers
        self.ln_post = nn.LayerNorm(config.d_model)
        self.proj1 = nn.Linear(config.d_model, config.d_model)
        self.act = _ACTIVATION_REGISTRY[config.activation_function]
        self.proj2 = nn.Linear(config.d_model, config.output_dim)

        # Get attention backend
        self.attn_backend = get_vit_attn_backend(
            head_size=config.d_model // config.encoder_attention_heads,
            dtype=torch.get_default_dtype(),
        )

    def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
        """Compute max_seqlen only for flash attention backends.

        Returns the largest difference between consecutive ``cu_seqlens``
        entries when the backend is ``FLASH_ATTN`` or ``ROCM_AITER_FA``,
        otherwise ``None``.
        """
        max_seqlen = None
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        return max_seqlen

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first convolution's weight."""
        return self.conv2d1.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the first convolution's weight."""
        return self.conv2d1.weight.device

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.Tensor,
        aftercnn_lens: torch.Tensor,
    ) -> torch.Tensor:
        """Encode packed mel features into audio embeddings.

        Args:
            input_features: 2-D feature tensor; it is transposed and split
                along frames, so it is presumably laid out as
                (num_mel_bins, total_frames) -- TODO confirm with caller.
            feature_lens: Per-audio frame counts before the convolutions.
            aftercnn_lens: Per-audio frame counts after the convolutions.

        Returns:
            Output of ``proj2`` for every valid post-convolution frame,
            packed along the first dimension.
        """
        # Compute chunk information
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        # Index of the last chunk of every audio; that chunk holds the
        # remainder, or a full window when the length divides evenly.
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2

        # Split input features into chunks and pad
        chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
        padded_feature = nn.utils.rnn.pad_sequence(
            chunk_list, batch_first=True
        ).transpose(1, 2)

        # Compute feature lengths after CNN
        feature_lens_after_cnn = self._get_cnn_output_lengths(chunk_lengths)
        # Vectorized mask creation: avoid creating many small tensors
        max_len_after_cnn = feature_lens_after_cnn.max().item()
        indices = torch.arange(max_len_after_cnn, device=padded_feature.device)
        padded_mask_after_cnn = indices.unsqueeze(0) < feature_lens_after_cnn.unsqueeze(
            1
        )

        # Add channel dimension for conv2d
        padded_feature = padded_feature.unsqueeze(1)

        # Apply convolutional layers (chunk if needed to avoid OOM)
        if padded_feature.size(0) <= self.conv_chunksize:
            # Fast path: no chunking needed
            padded_embed = F.gelu(self.conv2d1(padded_feature))
            padded_embed = F.gelu(self.conv2d2(padded_embed))
            padded_embed = F.gelu(self.conv2d3(padded_embed))
        else:
            # Chunked processing to avoid OOM
            padded_embeds = []
            for chunk in padded_feature.split(self.conv_chunksize, dim=0):
                padded_embed = F.gelu(self.conv2d1(chunk))
                padded_embed = F.gelu(self.conv2d2(padded_embed))
                padded_embed = F.gelu(self.conv2d3(padded_embed))
                padded_embeds.append(padded_embed)
            padded_embed = torch.cat(padded_embeds, dim=0)

        # (batch, channels, freq, time) -> (batch, time, channels*freq)
        b, c, f, t = padded_embed.size()
        padded_embed = self.conv_out(
            padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
        )

        # Add positional embedding
        positional_embedding = (
            self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
            .unsqueeze(0)
            .to(padded_embed.dtype)
        )
        padded_embed = padded_embed + positional_embedding

        # Extract valid hidden states and compute cu_seqlens
        hidden_states = padded_embed[padded_mask_after_cnn]

        # Compute cumulative sequence lengths for chunked attention
        cu_chunk_lens = [0]
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (
            self.n_window_infer // (self.n_window * 2)
        )
        # Use tolist() for efficient batch conversion from tensor to Python
        for cnn_len in aftercnn_lens.tolist():
            num_full_chunks = cnn_len // window_aftercnn
            remainder = cnn_len % window_aftercnn
            cu_chunk_lens.extend([window_aftercnn] * num_full_chunks)
            if remainder:
                cu_chunk_lens.append(remainder)
        cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
            -1, dtype=torch.int32
        )

        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        # Apply transformer layers
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                cu_seqlens,
                max_seqlen,
            )

        # Apply output layers
        hidden_states = self.ln_post(hidden_states)
        hidden_states = self.proj1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.proj2(hidden_states)

        return hidden_states

    def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
        """Compute output lengths after the three conv2d layers."""
        lengths = input_lengths
        for _ in range(3):
            # Output length of one kernel-3 / stride-2 / padding-1 convolution.
            lengths = (lengths - 1) // 2 + 1
        return lengths

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights with mapping from HuggingFace format.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensors are
        routed into the fused ``qkv`` parameter. Weights whose name matches
        no parameter of this module are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            Names of the parameters that actually received a weight. Skipped
            names are not included.

        Raises:
            KeyError: If a q/k/v shard maps to a fused parameter name that
                does not exist on this module.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("self_attn.qkv.", "self_attn.q_proj.", "q"),
            ("self_attn.qkv.", "self_attn.k_proj.", "k"),
            ("self_attn.qkv.", "self_attn.v_proj.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict.get(name)
                if param is None:
                    # Nothing to load into: do not report this name as loaded.
                    continue
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

_get_cnn_output_lengths

_get_cnn_output_lengths(input_lengths: Tensor) -> Tensor

Compute output lengths after the three conv2d layers.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_cnn_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
    """Compute output lengths after the three conv2d layers."""
    lengths = input_lengths
    for _ in range(3):
        lengths = (lengths - 1) // 2 + 1
    return lengths

compute_attn_mask_seqlen

compute_attn_mask_seqlen(
    cu_seqlens: Tensor,
) -> Tensor | None

Compute max_seqlen only for flash attention backends.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def compute_attn_mask_seqlen(self, cu_seqlens: torch.Tensor) -> torch.Tensor | None:
    """Return the longest sequence length, but only for flash-attention backends.

    For any backend other than ``FLASH_ATTN`` or ``ROCM_AITER_FA`` the
    result is ``None``.
    """
    backends_needing_max_seqlen = {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }
    if self.attn_backend not in backends_needing_max_seqlen:
        return None

    # Consecutive differences of the cumulative lengths are the
    # individual sequence lengths.
    per_sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    return per_sequence_lengths.max()

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]

Load weights with mapping from HuggingFace format.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load weights with mapping from HuggingFace format.

    Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensors are
    routed into the fused ``qkv`` parameter. Weights whose name matches no
    parameter of this module are skipped.

    Args:
        weights: Iterable of ``(name, tensor)`` pairs.

    Returns:
        Names of the parameters that actually received a weight. Skipped
        names are not included.

    Raises:
        KeyError: If a q/k/v shard maps to a fused parameter name that does
            not exist on this module.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("self_attn.qkv.", "self_attn.q_proj.", "q"),
        ("self_attn.qkv.", "self_attn.k_proj.", "k"),
        ("self_attn.qkv.", "self_attn.v_proj.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict.get(name)
            if param is None:
                # Nothing to load into: do not report this name as loaded.
                continue
            weight_loader = getattr(
                param, "weight_loader", default_weight_loader
            )
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params

Qwen3OmniMoeAudioEncoderLayer

Bases: Module

Transformer encoder layer for Qwen3-Omni Audio Encoder.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
    """Transformer encoder layer for Qwen3-Omni Audio Encoder."""

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = Qwen3OmniMoeAudioAttention(
            config, prefix=f"{prefix}.self_attn"
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.activation_fn = _ACTIVATION_REGISTRY[config.activation_function]
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.encoder_ffn_dim,
            bias=True,
            prefix=f"{prefix}.fc1",
        )
        self.fc2 = RowParallelLinear(
            config.encoder_ffn_dim,
            self.embed_dim,
            bias=True,
            prefix=f"{prefix}.fc2",
        )
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: torch.Tensor | None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor of shape (seq_len, hidden_size)
            cu_seqlens: Cumulative sequence lengths
            max_seqlen: Maximum sequence length in the batch
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        # Clamp for numerical stability with fp16
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(
                hidden_states, min=-clamp_value, max=clamp_value
            )

        return hidden_states

forward

forward(
    hidden_states: Tensor,
    cu_seqlens: Tensor,
    max_seqlen: Tensor | None,
) -> Tensor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | Tensor | Input tensor of shape (seq_len, hidden_size) | required |
| cu_seqlens | Tensor | Cumulative sequence lengths | required |
| max_seqlen | Tensor \| None | Maximum sequence length in the batch | required |

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    max_seqlen: torch.Tensor | None,
) -> torch.Tensor:
    """
    Args:
        hidden_states: Input tensor of shape (seq_len, hidden_size)
        cu_seqlens: Cumulative sequence lengths
        max_seqlen: Maximum sequence length in the batch
    """
    residual = hidden_states
    hidden_states = self.self_attn_layer_norm(hidden_states)
    hidden_states = self.self_attn(
        hidden_states=hidden_states,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states, _ = self.fc1(hidden_states)
    hidden_states = self.activation_fn(hidden_states)
    hidden_states, _ = self.fc2(hidden_states)
    hidden_states = residual + hidden_states

    # Clamp for numerical stability with fp16
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(
            hidden_states, min=-clamp_value, max=clamp_value
        )

    return hidden_states

Qwen3OmniMoeThinkerForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsMRoPE, Qwen3OmniMoeConditionalGenerationMixin, SupportsTranscription

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3OmniMoeThinkerMultiModalProcessor,
    info=Qwen3OmniMoeThinkerProcessingInfo,
    dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder,
)
class Qwen3OmniMoeThinkerForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsMRoPE,
    Qwen3OmniMoeConditionalGenerationMixin,
    SupportsTranscription,
):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "thinker.lm_head.": "language_model.lm_head.",
            "thinker.model.": "language_model.model.",
            "thinker.": "",
        }
    )

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    supported_languages = ISO639_1_SUPPORTED_LANGS

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for a modality name.

        The modality is matched by prefix (``image``, ``video``, ``audio``,
        checked in that order). ``i`` is not used.

        Raises:
            ValueError: If the modality matches none of the prefixes.
        """
        placeholder_by_prefix = (
            ("image", "<|vision_start|><|image_pad|><|vision_end|>"),
            ("video", "<|vision_start|><|video_pad|><|vision_end|>"),
            ("audio", "<|audio_start|><|audio_pad|><|audio_end|>"),
        )
        for modality_prefix, placeholder in placeholder_by_prefix:
            if modality.startswith(modality_prefix):
                return placeholder

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the audio tower, the vision tower and the language model.

        Args:
            vllm_config: Engine configuration; the thinker config is read from
                ``vllm_config.model_config.hf_config.thinker_config``.
            prefix: Name prefix combined with each sub-module name via
                ``maybe_prefix``.
        """
        super().__init__()
        self.vllm_config = vllm_config  # needed for torch compile forward context
        thinker_config: Qwen3OmniMoeThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config
        )
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = thinker_config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = Qwen3OmniMoeAudioEncoder(
                thinker_config.audio_config,
                prefix=maybe_prefix(prefix, "audio_tower"),
            )

        # Deepstack is enabled when the vision config declares
        # `deepstack_visual_indexes`; one level per listed index.
        self.use_deepstack = hasattr(
            thinker_config.vision_config, "deepstack_visual_indexes"
        )
        self.deepstack_num_level = (
            len(thinker_config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        self.visual_dim = thinker_config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen3Omni_VisionTransformer(
                vision_config=thinker_config.vision_config,
                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

            # Preallocate one zero tensor per deepstack level, sized
            # (max_num_batched_tokens, hidden_size). These are kept in a plain
            # Python list; `register_buffer` is not called here.
            if self.use_deepstack:
                self.deepstack_input_embeds = [
                    torch.zeros(
                        vllm_config.scheduler_config.max_num_batched_tokens,
                        thinker_config.text_config.hidden_size,
                    )
                    for _ in range(self.deepstack_num_level)
                ]

        with self._mark_language_model(vllm_config):
            self.language_model = Qwen3MoeLLMForCausalLM(
                vllm_config=vllm_config.with_hf_config(
                    thinker_config.text_config,
                    architectures=["Qwen3MoeForCausalLM"],
                ),
                prefix=maybe_prefix(prefix, "language_model"),
            )

        # Expose the language model's factory directly on this module.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _get_deepstack_input_embeds(
        self,
        num_tokens: int,
    ) -> IntermediateTensors | None:
        """Return the first ``num_tokens`` rows of every deepstack buffer.

        The buffers are only sliced here, not cleared. Returns ``None`` when
        no deepstack buffers exist.
        """
        buffers = getattr(self, "deepstack_input_embeds", None)
        if not buffers:
            return None  # If vision tower is skipped

        # One entry per deepstack level, keyed by the level index.
        sliced = {}
        for idx in range(self.deepstack_num_level):
            sliced[f"deepstack_input_embeds_{idx}"] = buffers[idx][:num_tokens]
        return IntermediateTensors(sliced)

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        if not getattr(self, "deepstack_input_embeds", None):
            return

        # set deepstack_input_embeds to buffer
        num_tokens = deepstack_input_embeds.size(1)
        if num_tokens > self.deepstack_input_embeds[0].size(0):
            self.deepstack_input_embeds = [
                torch.zeros(
                    num_tokens,
                    self.config.text_config.hidden_size,
                    device=self.deepstack_input_embeds[0].device,
                    dtype=self.deepstack_input_embeds[0].dtype,
                )
                for _ in range(self.deepstack_num_level)
            ]
        for idx in range(self.deepstack_num_level):
            self.deepstack_input_embeds[idx][:num_tokens].copy_(
                deepstack_input_embeds[idx]
            )

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        if not getattr(self, "deepstack_input_embeds", None):
            return

        # clear deepstack_input_embeds in buffer
        if num_tokens > 0:
            for idx in range(self.deepstack_num_level):
                self.deepstack_input_embeds[idx][:num_tokens].zero_()

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
            if (
                input_key in ("input_audio_features")
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        """Embed every multimodal input found in ``kwargs``.

        Returns an empty list when no multimodal input is present; otherwise
        a tuple with the embeddings of all items, concatenated in the order
        the modalities appear in the parsed inputs.
        """
        parsed_inputs = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not parsed_inputs:
            return []

        # Name of the method that embeds each modality.
        processor_names = {
            "image": "_process_image_input",
            "video": "_process_video_input",
            "audio": "_process_audio_input",
        }

        # The result is a tuple of tensors, with each tensor corresponding
        # to one multimodal data item (image, video or audio).
        collected: tuple[torch.Tensor, ...] = ()

        # NOTE: Iterating the dict itself keeps the modality order intact.
        for modality, modality_input in parsed_inputs.items():
            processor_name = processor_names.get(modality)
            if processor_name is None:
                continue
            collected += tuple(getattr(self, processor_name)(modality_input))
        return collected

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.language_model.embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        # Detect interleaved audio-in-video early, since it affects
        # both the deepstack path and the final embedding merge.
        video_token_id = self.config.video_token_id
        audio_token_id = self.config.audio_token_id
        is_video = is_multimodal & (input_ids == video_token_id)
        is_audio = is_multimodal & (input_ids == audio_token_id)
        num_video = is_video.sum().item()
        num_audio = is_audio.sum().item()

        is_interleaved = check_interleaved_audio_video(
            is_video, is_audio, num_video, num_audio
        )

        deepstack_input_embeds = None
        # split the feat dim to obtain multi-scale visual feature
        has_vision_embeddings = [
            embeddings.shape[-1] != self.config.text_config.hidden_size
            for embeddings in multimodal_embeddings
        ]
        if self.visual.deepstack_visual_indexes is not None and any(
            has_vision_embeddings
        ):
            multiscale_len = len(self.visual.deepstack_visual_indexes)
            multimodal_embeddings_multiscale = []

            if is_interleaved:
                # Use input_ids-based mask for correct vision positions
                # when audio and video tokens are interleaved.
                is_vision = is_video.clone()
            else:
                is_vision = torch.zeros_like(is_multimodal)
                mm_positions = torch.nonzero(is_multimodal, as_tuple=True)[0]
                mm_position_idx = 0

            for index, embeddings in enumerate(multimodal_embeddings):
                num_tokens = embeddings.shape[0]

                # Vision embeddings
                if embeddings.shape[-1] != self.config.text_config.hidden_size:
                    visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
                    multi_dim = visual_dim * multiscale_len
                    embeddings_main, embeddings_multiscale = torch.split(
                        embeddings, [visual_dim, multi_dim], dim=-1
                    )
                    multimodal_embeddings[index] = embeddings_main
                    multimodal_embeddings_multiscale.append(embeddings_multiscale)
                    if not is_interleaved:
                        current_positions = mm_positions[
                            mm_position_idx : mm_position_idx + num_tokens
                        ]
                        is_vision[current_positions] = True

                # Audio embeddings
                else:
                    if not is_interleaved:
                        current_positions = mm_positions[
                            mm_position_idx : mm_position_idx + num_tokens
                        ]
                        is_vision[current_positions] = False

                if not is_interleaved:
                    mm_position_idx += num_tokens

            deepstack_input_embeds = inputs_embeds.new_zeros(
                inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
            )
            deepstack_input_embeds = _merge_multimodal_embeddings(
                inputs_embeds=deepstack_input_embeds,
                multimodal_embeddings=multimodal_embeddings_multiscale,
                is_multimodal=is_vision,
            )
            deepstack_input_embeds = (
                deepstack_input_embeds.view(
                    inputs_embeds.shape[0], multiscale_len, visual_dim
                )
                .permute(1, 0, 2)
                .contiguous()
            )
            self._set_deepstack_input_embeds(deepstack_input_embeds)

        if is_interleaved:
            return merge_interleaved_embeddings(
                inputs_embeds,
                multimodal_embeddings,
                is_video,
                is_audio,
                is_multimodal,
                num_video,
                num_audio,
            )

        # Default: standard merge (no interleaving)
        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped language model on token ids or precomputed embeddings.

        Args:
            input_ids: Forwarded unchanged to ``self.language_model.model``.
            positions: Forwarded unchanged to ``self.language_model.model``.
            intermediate_tensors: When not None, ``inputs_embeds`` is discarded
                before the language model is called.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted but not used by this method.

        Returns:
            Whatever ``self.language_model.model`` returns.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # Deepstack embeddings are fetched before the call and cleared after it
        # only when embeddings are supplied and the group returned by
        # get_pp_group() reports this as its first rank.
        uses_deepstack = inputs_embeds is not None and get_pp_group().is_first_rank
        deepstack_input_embeds = (
            self._get_deepstack_input_embeds(inputs_embeds.size(0))
            if uses_deepstack
            else None
        )

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if uses_deepstack:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module through ``AutoWeightsLoader``.

        The loader is created with ``skip_prefixes`` for "talker." and
        "code2wav." (presumably so those sub-model weights are ignored) and the
        names are translated with ``self.hf_to_vllm_mapper``.

        Args:
            weights: Iterable of (name, tensor) pairs.

        Returns:
            The value returned by ``AutoWeightsLoader.load_weights``.
        """
        skipped_prefixes = ["talker.", "code2wav."]
        return AutoWeightsLoader(self, skip_prefixes=skipped_prefixes).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def _compute_audio_token_count(self, audio_feature_length: int) -> int:
        """Compute audio tokens from feature length using Qwen3-Omni formula.

        The length is wrapped in a one-element tensor, passed through
        ``_get_feat_extract_output_lengths`` and unwrapped with ``.item()``.
        """
        feature_lengths = torch.tensor([audio_feature_length])
        token_counts = _get_feat_extract_output_lengths(feature_lengths)
        return token_counts.item()

    def _get_audio_for_video_mapping(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> tuple[dict[int, int], set[int]]:
        """
        Map video offset -> paired audio_feature_length for use_audio_in_video.

        When use_audio_in_video=True, audio is interleaved within video.
        The pairing is based on feature order in mm_features.

        Returns:
            Tuple of (video_offset -> audio_feature_length mapping,
                      set of paired audio offsets to skip)
        """
        videos_with_audio = [
            f
            for f in mm_features
            if f.modality == "video"
            and f.data.get("use_audio_in_video")
            and f.data["use_audio_in_video"].data.item()
        ]
        audios = [f for f in mm_features if f.modality == "audio"]

        mapping: dict[int, int] = {}
        paired_audio_offsets: set[int] = set()
        for i, video_f in enumerate(videos_with_audio):
            if i < len(audios):
                audio_len = audios[i].data["audio_feature_lengths"].data.item()
                mapping[video_f.mm_position.offset] = audio_len
                paired_audio_offsets.add(audios[i].mm_position.offset)
        return mapping, paired_audio_offsets

    def iter_mm_features(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, str, dict[str, Any]]]:
        """
        Walk the multimodal features in ascending order of position offset.

        Yields: (offset, modality, feature_data) where feature_data contains:
        - image: {"grid_t", "grid_h", "grid_w", "t_factor"}
        - video: {"grid_t", "grid_h", "grid_w", "t_factor",
                  "use_audio_in_video", "audio_feature_length"}
        - audio: {"audio_feature_length"}

        Grid height and width are divided by the vision config's
        ``spatial_merge_size``. Audio features that
        ``_get_audio_for_video_mapping`` paired with a video are not yielded.
        """
        merge_size = self.config.vision_config.spatial_merge_size
        positions_per_second = self.config.position_id_per_seconds

        ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
        audio_lengths, paired_audio = self._get_audio_for_video_mapping(ordered)

        for feature in ordered:
            offset = feature.mm_position.offset
            kind = feature.modality
            fields = feature.data

            if kind == "image":
                t, h, w = fields["image_grid_thw"].data.tolist()
                image_info = {
                    "grid_t": t,
                    "grid_h": h // merge_size,
                    "grid_w": w // merge_size,
                    "t_factor": positions_per_second,
                }
                yield offset, "image", image_info
            elif kind == "video":
                t, h, w = fields["video_grid_thw"].data.tolist()

                # Seconds per temporal grid step; 2.0 when the field is absent.
                ts_field = fields.get("second_per_grid_ts")
                seconds_per_grid = ts_field.data.item() if ts_field else 2.0

                audio_flag = fields.get("use_audio_in_video")
                has_audio = bool(audio_flag and audio_flag.data.item())

                video_info = {
                    "grid_t": t,
                    "grid_h": h // merge_size,
                    "grid_w": w // merge_size,
                    "t_factor": seconds_per_grid * positions_per_second,
                    "use_audio_in_video": has_audio,
                    "audio_feature_length": audio_lengths.get(offset),
                }
                yield offset, "video", video_info
            elif kind == "audio":
                # Audio already attached to a video is reported with that video.
                if offset in paired_audio:
                    continue
                audio_len = fields["audio_feature_lengths"].data.item()
                yield offset, "audio", {"audio_feature_length": audio_len}

    def _compute_interleaved_positions(
        self, start_idx: int, data: dict[str, Any]
    ) -> tuple[np.ndarray, int]:
        """
        Compute positions for interleaved video+audio using Qwen3 token-by-token
        interleaving logic.

        Returns: (position_ids [3, N], total_token_count)
        """
        grid_t = data["grid_t"]
        grid_h = data["grid_h"]
        grid_w = data["grid_w"]
        t_factor = data["t_factor"]
        audio_feature_length = data["audio_feature_length"]

        audio_len = self._compute_audio_token_count(audio_feature_length)

        h_index = np.tile(
            np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
        ).flatten()
        w_index = np.tile(
            np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
        ).flatten()
        t_index_raw = np.arange(grid_t)
        t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
        t_index = np.repeat(t_index_scaled, grid_h * grid_w)

        video_pos = np.stack([t_index, h_index, w_index]) + start_idx
        audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx

        video_t_values = video_pos[0]
        audio_t_values = audio_pos[0]

        pos_ids_list: list[np.ndarray] = []
        video_idx, audio_idx = 0, 0
        num_video = grid_t * grid_h * grid_w

        while video_idx < num_video and audio_idx < audio_len:
            if video_t_values[video_idx] <= audio_t_values[audio_idx]:
                pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
                video_idx += 1
            else:
                pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
                audio_idx += 1

        if video_idx < num_video:
            pos_ids_list.append(video_pos[:, video_idx:])
        if audio_idx < audio_len:
            pos_ids_list.append(audio_pos[:, audio_idx:])

        total_tokens = num_video + audio_len
        return np.concatenate(pos_ids_list, axis=1), total_tokens

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text settings from the cached processor.

        ``max_audio_clip_s`` and ``sample_rate`` are read from the processor's
        ``feature_extractor`` (``chunk_length`` and ``sampling_rate``);
        ``task_type`` is not consulted.
        """
        feature_extractor = cached_processor_from_config(
            model_config, processor_cls=Qwen3OmniMoeProcessor
        ).feature_extractor
        return SpeechToTextConfig(
            max_audio_clip_s=feature_extractor.chunk_length,
            sample_rate=feature_extractor.sampling_rate,
            min_energy_split_window_size=None,
        )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """
        Construct a transcription/translation prompt for Qwen3-Omni.

        The instruction reads "Transcribe this audio [into <language>]." or
        "Translate this audio [from <language>] [into <to_language>].",
        followed by ``request_prompt`` when one is given. Language codes that
        are missing from ``cls.supported_languages`` are left out of the
        sentence.
        """
        # Default to_language to English for translation
        if task_type == "translate" and to_language is None:
            to_language = "en"

        # Full language names; codes not in the mapping resolve to "".
        source_name = cls.supported_languages.get(language, "")
        target_name = cls.supported_languages.get(to_language, "")

        pieces = [
            "Transcribe" if task_type == "transcribe" else "Translate",
            " this audio",
        ]
        if task_type == "transcribe":
            if source_name:
                pieces.append(f" into {source_name}")
        elif task_type == "translate":
            if source_name:
                pieces.append(f" from {source_name}")
            if target_name:
                pieces.append(f" into {target_name}")
        pieces.append(".")
        if request_prompt:
            pieces.append(f" {request_prompt}")
        instruction = "".join(pieces)

        # Audio placeholder format: <|audio_start|><|audio_pad|><|audio_end|>
        audio_placeholder = "<|audio_start|><|audio_pad|><|audio_end|>"

        processor = cached_processor_from_config(
            model_config, processor_cls=Qwen3OmniMoeProcessor
        )
        prompt = processor.apply_chat_template(
            [{"role": "user", "content": f"{audio_placeholder}{instruction}"}],
            tokenize=False,
            add_generation_prompt=True,
        )

        return cast(
            PromptType,
            {
                "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
                "prompt": prompt,
            },
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE input positions using mm_features directly.

        Text tokens receive consecutive positions that are identical on all
        three axes. Every feature yielded by ``iter_mm_features`` contributes a
        BOS position, its payload positions and an EOS position; image and
        video payloads use (t, h, w) grid coordinates, audio payloads use
        consecutive positions. A video with ``use_audio_in_video`` set is
        wrapped in two BOS and two EOS positions around the interleaved
        video/audio payload.

        Args:
            input_tokens: Prompt token ids; only the length is used.
            mm_features: Multimodal features, passed to ``iter_mm_features``.

        Returns:
            Tuple of (position ids of shape [3, len(input_tokens)],
                      max position + 1 - len(input_tokens)).

        Raises:
            RuntimeError: If the number of generated positions differs from
                ``len(input_tokens)``.
        """
        seq_len = len(input_tokens)

        llm_pos_ids_list: list[np.ndarray] = []
        st = 0

        def next_index() -> int:
            # First free position after everything emitted so far.
            return int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0

        def column(value: int) -> np.ndarray:
            # One token carrying the same position on all three axes.
            return np.broadcast_to(np.array([value]), (3, 1))

        def run(length: int, start: int) -> np.ndarray:
            # ``length`` consecutive positions, identical on all three axes.
            return np.broadcast_to(np.arange(length), (3, length)) + start

        def grid(data: dict, start: int) -> tuple[np.ndarray, int]:
            # (t, h, w) coordinates of a vision grid and its token count.
            grid_t, grid_h, grid_w = data["grid_t"], data["grid_h"], data["grid_w"]
            t_factor = data["t_factor"]
            grid_indices = np.indices((grid_t, grid_h, grid_w))
            if t_factor != 1.0:
                grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
            return grid_indices.reshape(3, -1) + start, grid_t * grid_h * grid_w

        for offset, modality, data in self.iter_mm_features(mm_features):
            text_len = offset - st
            st_idx = next_index()

            if text_len > 0:
                llm_pos_ids_list.append(run(text_len, st_idx))
                st_idx += text_len

            llm_pos_ids_list.append(column(st_idx))
            st_idx += 1

            if modality == "video" and data["use_audio_in_video"]:
                # The second BOS reuses the position of the first one.
                llm_pos_ids_list.append(column(st_idx - 1))

                # The helper already returns the video + audio token total, so
                # the audio token count is not recomputed here.
                pos_ids, num_tokens = self._compute_interleaved_positions(
                    st_idx, data
                )
                llm_pos_ids_list.append(pos_ids)

                # Both EOS tokens share one position.
                eos_pos = column(int(pos_ids.max()) + 1)
                llm_pos_ids_list.extend([eos_pos, eos_pos])
                st = offset + 2 + num_tokens + 2
                continue

            if modality == "audio":
                num_tokens = self._compute_audio_token_count(
                    data["audio_feature_length"]
                )
                payload = run(num_tokens, st_idx)
            elif modality in ("image", "video"):
                payload, num_tokens = grid(data, st_idx)
            else:
                # Unrecognised modality: only the BOS position is emitted.
                continue

            llm_pos_ids_list.append(payload)
            llm_pos_ids_list.append(column(int(payload.max()) + 1))
            st = offset + 1 + num_tokens + 1

        if st < seq_len:
            llm_pos_ids_list.append(run(seq_len - st, next_index()))

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        if llm_positions.shape[1] != seq_len:
            raise RuntimeError("Position ids length mismatch with input ids length")

        mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
        return torch.from_numpy(llm_positions), mrope_position_delta

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Return the module-name prefixes of the multimodal model's parts:
        language model, connector and tower models.
        """
        tower_prefixes = ["visual.", "audio_tower."]
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger",
            tower_model=tower_prefixes,
        )

_compute_audio_token_count

_compute_audio_token_count(
    audio_feature_length: int,
) -> int

Compute audio tokens from feature length using Qwen3-Omni formula.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _compute_audio_token_count(self, audio_feature_length: int) -> int:
    """Compute audio tokens from feature length using Qwen3-Omni formula.

    The length is wrapped in a one-element tensor, passed through
    ``_get_feat_extract_output_lengths`` and unwrapped with ``.item()``.
    """
    feature_lengths = torch.tensor([audio_feature_length])
    token_counts = _get_feat_extract_output_lengths(feature_lengths)
    return token_counts.item()

_compute_interleaved_positions

_compute_interleaved_positions(
    start_idx: int, data: dict[str, Any]
) -> tuple[ndarray, int]

Compute positions for interleaved video+audio using Qwen3 token-by-token interleaving logic.

Returns: (position_ids [3, N], total_token_count)

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _compute_interleaved_positions(
    self, start_idx: int, data: dict[str, Any]
) -> tuple[np.ndarray, int]:
    """
    Compute positions for interleaved video+audio using Qwen3 token-by-token
    interleaving logic.

    Returns: (position_ids [3, N], total_token_count)
    """
    grid_t = data["grid_t"]
    grid_h = data["grid_h"]
    grid_w = data["grid_w"]
    t_factor = data["t_factor"]
    audio_feature_length = data["audio_feature_length"]

    audio_len = self._compute_audio_token_count(audio_feature_length)

    h_index = np.tile(
        np.arange(grid_h).reshape(1, -1, 1), (grid_t, 1, grid_w)
    ).flatten()
    w_index = np.tile(
        np.arange(grid_w).reshape(1, 1, -1), (grid_t, grid_h, 1)
    ).flatten()
    t_index_raw = np.arange(grid_t)
    t_index_scaled = (t_index_raw * t_factor).astype(np.int64)
    t_index = np.repeat(t_index_scaled, grid_h * grid_w)

    video_pos = np.stack([t_index, h_index, w_index]) + start_idx
    audio_pos = np.broadcast_to(np.arange(audio_len), (3, audio_len)) + start_idx

    video_t_values = video_pos[0]
    audio_t_values = audio_pos[0]

    pos_ids_list: list[np.ndarray] = []
    video_idx, audio_idx = 0, 0
    num_video = grid_t * grid_h * grid_w

    while video_idx < num_video and audio_idx < audio_len:
        if video_t_values[video_idx] <= audio_t_values[audio_idx]:
            pos_ids_list.append(video_pos[:, video_idx : video_idx + 1])
            video_idx += 1
        else:
            pos_ids_list.append(audio_pos[:, audio_idx : audio_idx + 1])
            audio_idx += 1

    if video_idx < num_video:
        pos_ids_list.append(video_pos[:, video_idx:])
    if audio_idx < audio_len:
        pos_ids_list.append(audio_pos[:, audio_idx:])

    total_tokens = num_video + audio_len
    return np.concatenate(pos_ids_list, axis=1), total_tokens

_get_audio_for_video_mapping

_get_audio_for_video_mapping(
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[dict[int, int], set[int]]

Map video offset -> paired audio_feature_length for use_audio_in_video.

When use_audio_in_video=True, audio is interleaved within video. The pairing is based on feature order in mm_features.

Returns:

| Type | Description |
| --- | --- |
| `tuple[dict[int, int], set[int]]` | Tuple of (video_offset -> audio_feature_length mapping, set of paired audio offsets to skip) |

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_audio_for_video_mapping(
    self, mm_features: list[MultiModalFeatureSpec]
) -> tuple[dict[int, int], set[int]]:
    """
    Map video offset -> paired audio_feature_length for use_audio_in_video.

    When use_audio_in_video=True, audio is interleaved within video.
    The pairing is based on feature order in mm_features.

    Returns:
        Tuple of (video_offset -> audio_feature_length mapping,
                  set of paired audio offsets to skip)
    """
    videos_with_audio = [
        f
        for f in mm_features
        if f.modality == "video"
        and f.data.get("use_audio_in_video")
        and f.data["use_audio_in_video"].data.item()
    ]
    audios = [f for f in mm_features if f.modality == "audio"]

    mapping: dict[int, int] = {}
    paired_audio_offsets: set[int] = set()
    for i, video_f in enumerate(videos_with_audio):
        if i < len(audios):
            audio_len = audios[i].data["audio_feature_lengths"].data.item()
            mapping[video_f.mm_position.offset] = audio_len
            paired_audio_offsets.add(audios[i].mm_position.offset)
    return mapping, paired_audio_offsets

get_generation_prompt classmethod

get_generation_prompt(
    audio: ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType

Construct a transcription/translation prompt for Qwen3-Omni.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """
    Construct a transcription/translation prompt for Qwen3-Omni.

    The instruction reads "Transcribe this audio [into <language>]." or
    "Translate this audio [from <language>] [into <to_language>].",
    followed by ``request_prompt`` when one is given. Language codes that
    are missing from ``cls.supported_languages`` are left out of the
    sentence.
    """
    # Default to_language to English for translation
    if task_type == "translate" and to_language is None:
        to_language = "en"

    # Full language names; codes not in the mapping resolve to "".
    source_name = cls.supported_languages.get(language, "")
    target_name = cls.supported_languages.get(to_language, "")

    pieces = [
        "Transcribe" if task_type == "transcribe" else "Translate",
        " this audio",
    ]
    if task_type == "transcribe":
        if source_name:
            pieces.append(f" into {source_name}")
    elif task_type == "translate":
        if source_name:
            pieces.append(f" from {source_name}")
        if target_name:
            pieces.append(f" into {target_name}")
    pieces.append(".")
    if request_prompt:
        pieces.append(f" {request_prompt}")
    instruction = "".join(pieces)

    # Audio placeholder format: <|audio_start|><|audio_pad|><|audio_end|>
    audio_placeholder = "<|audio_start|><|audio_pad|><|audio_end|>"

    processor = cached_processor_from_config(
        model_config, processor_cls=Qwen3OmniMoeProcessor
    )
    prompt = processor.apply_chat_template(
        [{"role": "user", "content": f"{audio_placeholder}{instruction}"}],
        tokenize=False,
        add_generation_prompt=True,
    )

    return cast(
        PromptType,
        {
            "multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
            "prompt": prompt,
        },
    )

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Return the module-name prefixes of the multimodal model's parts:
    language model, connector and tower models.
    """
    tower_prefixes = ["visual.", "audio_tower."]
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="visual.merger",
        tower_model=tower_prefixes,
    )

get_mrope_input_positions

get_mrope_input_positions(
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[Tensor, int]

Compute M-RoPE input positions using mm_features directly.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    mm_features: list[MultiModalFeatureSpec],
) -> tuple[torch.Tensor, int]:
    """Compute M-RoPE input positions using mm_features directly.

    Text tokens receive consecutive positions that are identical on all
    three axes. Every feature yielded by ``iter_mm_features`` contributes a
    BOS position, its payload positions and an EOS position; image and
    video payloads use (t, h, w) grid coordinates, audio payloads use
    consecutive positions. A video with ``use_audio_in_video`` set is
    wrapped in two BOS and two EOS positions around the interleaved
    video/audio payload.

    Args:
        input_tokens: Prompt token ids; only the length is used.
        mm_features: Multimodal features, passed to ``iter_mm_features``.

    Returns:
        Tuple of (position ids of shape [3, len(input_tokens)],
                  max position + 1 - len(input_tokens)).

    Raises:
        RuntimeError: If the number of generated positions differs from
            ``len(input_tokens)``.
    """
    seq_len = len(input_tokens)

    llm_pos_ids_list: list[np.ndarray] = []
    st = 0

    def next_index() -> int:
        # First free position after everything emitted so far.
        return int(llm_pos_ids_list[-1].max()) + 1 if llm_pos_ids_list else 0

    def column(value: int) -> np.ndarray:
        # One token carrying the same position on all three axes.
        return np.broadcast_to(np.array([value]), (3, 1))

    def run(length: int, start: int) -> np.ndarray:
        # ``length`` consecutive positions, identical on all three axes.
        return np.broadcast_to(np.arange(length), (3, length)) + start

    def grid(data: dict, start: int) -> tuple[np.ndarray, int]:
        # (t, h, w) coordinates of a vision grid and its token count.
        grid_t, grid_h, grid_w = data["grid_t"], data["grid_h"], data["grid_w"]
        t_factor = data["t_factor"]
        grid_indices = np.indices((grid_t, grid_h, grid_w))
        if t_factor != 1.0:
            grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
        return grid_indices.reshape(3, -1) + start, grid_t * grid_h * grid_w

    for offset, modality, data in self.iter_mm_features(mm_features):
        text_len = offset - st
        st_idx = next_index()

        if text_len > 0:
            llm_pos_ids_list.append(run(text_len, st_idx))
            st_idx += text_len

        llm_pos_ids_list.append(column(st_idx))
        st_idx += 1

        if modality == "video" and data["use_audio_in_video"]:
            # The second BOS reuses the position of the first one.
            llm_pos_ids_list.append(column(st_idx - 1))

            # The helper already returns the video + audio token total, so
            # the audio token count is not recomputed here.
            pos_ids, num_tokens = self._compute_interleaved_positions(
                st_idx, data
            )
            llm_pos_ids_list.append(pos_ids)

            # Both EOS tokens share one position.
            eos_pos = column(int(pos_ids.max()) + 1)
            llm_pos_ids_list.extend([eos_pos, eos_pos])
            st = offset + 2 + num_tokens + 2
            continue

        if modality == "audio":
            num_tokens = self._compute_audio_token_count(
                data["audio_feature_length"]
            )
            payload = run(num_tokens, st_idx)
        elif modality in ("image", "video"):
            payload, num_tokens = grid(data, st_idx)
        else:
            # Unrecognised modality: only the BOS position is emitted.
            continue

        llm_pos_ids_list.append(payload)
        llm_pos_ids_list.append(column(int(payload.max()) + 1))
        st = offset + 1 + num_tokens + 1

    if st < seq_len:
        llm_pos_ids_list.append(run(seq_len - st, next_index()))

    llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
    if llm_positions.shape[1] != seq_len:
        raise RuntimeError("Position ids length mismatch with input ids length")

    mrope_position_delta = int(llm_positions.max()) + 1 - seq_len
    return torch.from_numpy(llm_positions), mrope_position_delta

iter_mm_features

iter_mm_features(
    mm_features: list[MultiModalFeatureSpec],
) -> Iterator[tuple[int, str, dict[str, Any]]]

Iterate over multimodal features sorted by position offset.

Yields: (offset, modality, feature_data), where feature_data contains:

- image: {"grid_t", "grid_h", "grid_w", "t_factor"}
- video: {"grid_t", "grid_h", "grid_w", "t_factor", "use_audio_in_video", "audio_feature_length"}
- audio: {"audio_feature_length"}

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def iter_mm_features(
    self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, str, dict[str, Any]]]:
    """
    Walk the multimodal features in ascending order of prompt offset.

    Each yielded item is ``(offset, modality, feature_data)``; the keys of
    ``feature_data`` depend on the modality:
    - image: {"grid_t", "grid_h", "grid_w", "t_factor"}
    - video: {"grid_t", "grid_h", "grid_w", "t_factor",
              "use_audio_in_video", "audio_feature_length"}
    - audio: {"audio_feature_length"}

    Audio features whose offset is reported as paired by
    ``_get_audio_for_video_mapping`` are not yielded on their own; features
    with any other modality are silently skipped.
    """
    cfg = self.config
    merge_size = cfg.vision_config.spatial_merge_size
    ids_per_second = cfg.position_id_per_seconds

    ordered = sorted(mm_features, key=lambda feat: feat.mm_position.offset)
    video_audio_lengths, paired_offsets = self._get_audio_for_video_mapping(ordered)

    for feat in ordered:
        start = feat.mm_position.offset
        kind = feat.modality

        if kind == "image":
            grid_t, grid_h, grid_w = feat.data["image_grid_thw"].data.tolist()
            image_info = {
                "grid_t": grid_t,
                "grid_h": grid_h // merge_size,
                "grid_w": grid_w // merge_size,
                "t_factor": ids_per_second,
            }
            yield start, "image", image_info
            continue

        if kind == "audio":
            # Audio that belongs to a video is reported with that video.
            if start in paired_offsets:
                continue
            length = feat.data["audio_feature_lengths"].data.item()
            yield start, "audio", {"audio_feature_length": length}
            continue

        if kind != "video":
            continue

        grid_t, grid_h, grid_w = feat.data["video_grid_thw"].data.tolist()

        # Fall back to 2.0 seconds per temporal grid step when none is given.
        seconds_per_grid = 2.0
        if feat.data.get("second_per_grid_ts"):
            seconds_per_grid = feat.data["second_per_grid_ts"].data.item()

        audio_flag = feat.data.get("use_audio_in_video")
        with_audio = bool(audio_flag.data.item()) if audio_flag else False

        video_info = {
            "grid_t": grid_t,
            "grid_h": grid_h // merge_size,
            "grid_w": grid_w // merge_size,
            "t_factor": seconds_per_grid * ids_per_second,
            "use_audio_in_video": with_audio,
            "audio_feature_length": video_audio_lengths.get(start),
        }
        yield start, "video", video_info

Qwen3OmniMoeThinkerMultiModalProcessor

Bases: Qwen2_5OmniThinkerMultiModalProcessor

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
class Qwen3OmniMoeThinkerMultiModalProcessor(
    Qwen2_5OmniThinkerMultiModalProcessor,
):
    """
    Multimodal processor for the Qwen3-Omni-MoE thinker.

    Extends the Qwen2.5-Omni thinker processor with audio pre-padding around
    the HF processor call and with prompt-update handling for
    ``use_audio_in_video``, where audio tokens are interleaved with video
    tokens inside a single placeholder.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """
        Run the HF processor after pre-padding audio to the hop length.

        Audio inputs are popped from ``mm_data["audios"]``, zero-padded on the
        right to a multiple of the feature extractor's ``hop_length`` and
        passed on under the ``"audio"`` key. After the parent call, when the
        output carries both ``audio_feature_lengths`` and
        ``feature_attention_mask``, those two entries are recomputed from the
        padded audio lengths.
        """
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        def pad_to_hop_length(x: np.ndarray, hop_length: int) -> np.ndarray:
            """Right-pad ``x`` with zeros up to a multiple of ``hop_length``."""
            length = x.shape[-1]
            if length % hop_length != 0:
                pad_length = hop_length - (length % hop_length)
                x = np.pad(x, (0, pad_length), mode="constant", constant_values=0)
            return x

        # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        hop_length = feature_extractor.hop_length
        if audios:
            # NOTE: Qwen3-Omni processor accept "audio"
            # To make sure the cache works with padding=True, we pre-padded
            # the audio to multiple of hop_length.
            # Non-ndarray items are treated as (samples, extra) pairs and only
            # the samples are padded.
            mm_data["audio"] = [
                pad_to_hop_length(audio, hop_length)
                if isinstance(audio, np.ndarray)
                else (pad_to_hop_length(audio[0], hop_length), audio[1])
                for audio in audios
            ]

            # TODO(Isotr0py): Remove this patch after upstream fix PR
            # released and Transformers version update:
            # https://github.com/huggingface/transformers/pull/41473
            mm_kwargs = dict(mm_kwargs)
            tok_kwargs = dict(tok_kwargs)
            mm_kwargs["audio_kwargs"] = dict(mm_kwargs.get("audio_kwargs") or {})
            mm_kwargs["text_kwargs"] = dict(mm_kwargs.get("text_kwargs") or {})
            if Version(TRANSFORMERS_VERSION) < Version("4.58.0"):
                # Extract audio_sample_rate before restructuring
                audio_sample_rate = mm_kwargs.pop("audio_sample_rate", None)

                # move truncation to audio_kwargs level to avoid conflict
                # with tok_kwargs
                mm_kwargs["audio_kwargs"].setdefault(
                    "truncation", mm_kwargs.pop("truncation", False)
                )
                mm_kwargs["text_kwargs"].setdefault(
                    "truncation", tok_kwargs.pop("truncation", False)
                )

                # Validate and conditionally pass audio_sample_rate
                # WhisperFeatureExtractor has a fixed sampling rate, and vLLM's
                # audio loader already resamples audio to the target rate.
                # Only pass the value if it matches to avoid unexpected behavior.
                if audio_sample_rate is not None:
                    expected_sr = feature_extractor.sampling_rate
                    if audio_sample_rate != expected_sr:
                        logger.warning(
                            "[%s] audio_sample_rate mismatch: user provided %dHz "
                            "but model expects %dHz. Ignoring user value. "
                            "vLLM's audio loader already resampled to %dHz.",
                            self.__class__.__name__,
                            audio_sample_rate,
                            expected_sr,
                            expected_sr,
                        )
                    else:
                        # Sample rate matches, safe to pass
                        mm_kwargs["audio_kwargs"]["audio_sample_rate"] = (
                            audio_sample_rate
                        )

        hf_inputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        if (
            "audio_feature_lengths" in hf_inputs
            and "feature_attention_mask" in hf_inputs
            and (audios := mm_data.get("audio", []))
        ):
            # Recompute per-audio frame counts from the (padded) sample counts.
            audio_num_frames = []
            for audio in audios:
                audio_length = len(audio[0]) if isinstance(audio, tuple) else len(audio)
                num_frame = (
                    (audio_length // hop_length)
                    if audio_length % hop_length == 0
                    else (audio_length // hop_length - 1)
                )
                # NOTE(review): on Transformers < 4.58.0 the top-level
                # "truncation" key was popped above, so this lookup sees the
                # default there -- confirm that is intended.
                if mm_kwargs.get("truncation", False):
                    num_frame = min(
                        num_frame, feature_extractor.n_samples // hop_length
                    )
                audio_num_frames.append(num_frame)
            hf_inputs["feature_attention_mask"] = [
                torch.ones(num_frame) for num_frame in audio_num_frames
            ]
            hf_inputs["audio_feature_lengths"] = torch.tensor(audio_num_frames)
        return hf_inputs

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Qwen3-Omni reimplements this function to handle `use_audio_in_video`.

        Returns:
            ``(prompt_ids, mm_placeholders)``. When the updates were already
            applied, ``prompt_ids`` is returned unchanged and the placeholders
            are located in it; otherwise the updates are applied first.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        # The flag is overwritten for every video item, so the last item
        # decides the final value.
        # NOTE(review): confirm last-item-wins is intended for mixed inputs.
        use_audio_in_video = False
        if "video" in mm_kwargs:
            for item in mm_kwargs["video"]:
                use_audio_in_video = bool(item and item["use_audio_in_video"].data)

        if is_update_applied:
            # Placeholders are already expanded in the prompt; just find them.
            mm_placeholders = self._find_mm_placeholders(
                prompt_ids,
                mm_prompt_updates,
            )
        elif use_audio_in_video and "audio" in mm_prompt_updates:
            # Audio tokens live inside the video placeholder, so the audio
            # updates are not applied separately.
            filtered_updates = {
                k: v for k, v in mm_prompt_updates.items() if k != "audio"
            }
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                filtered_updates,
            )
            # Derive audio placeholders from video placeholders
            mm_placeholders = self._derive_audio_from_video_placeholders(
                mm_placeholders, mm_prompt_updates
            )
        else:
            # normal case with `use_audio_in_video=False`
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )

        self._validate_mm_placeholders(
            mm_placeholders,
            mm_item_counts,
        )

        return prompt_ids, mm_placeholders

    def get_updates_use_audio_in_video(
        self,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: list[int] | torch.Tensor,
        video_second_per_grid_t: float,
    ) -> list[int]:
        """
        Build the interleaved audio/video token sequence for one video.

        Every video token gets a time index derived from its temporal grid
        index, ``video_second_per_grid_t`` and
        ``thinker_config.position_id_per_seconds``; audio token ``i`` has time
        index ``i``. The two streams are merged in ascending time order (video
        first on ties) and wrapped in the audio start/end tokens.

        Args:
            thinker_config: Provides the token ids, ``position_id_per_seconds``
                and ``vision_config.spatial_merge_size``.
            audio_len: Number of audio tokens to emit.
            video_grid_thw: ``(t, h, w)`` grid of the video.
            video_second_per_grid_t: Seconds covered by one temporal grid step.

        Returns:
            Token ids: audio start, merged video/audio tokens, audio end.
        """
        audio_token_id = thinker_config.audio_token_id
        video_token_id = thinker_config.video_token_id
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        position_id_per_seconds = thinker_config.position_id_per_seconds

        audio_token_indices = np.arange(audio_len)

        height = video_grid_thw[1] // spatial_merge_size
        width = video_grid_thw[2] // spatial_merge_size
        # Repeat each temporal index once per merged spatial position.
        video_token_indices = np.arange(video_grid_thw[0]).reshape(-1, 1, 1)
        video_token_indices = np.broadcast_to(
            video_token_indices, (video_token_indices.shape[0], height, width)
        ).reshape(-1)
        video_token_indices = (
            video_token_indices * video_second_per_grid_t * position_id_per_seconds
        )

        num_video = len(video_token_indices)
        num_audio = len(audio_token_indices)
        video_data_index, audio_data_index = 0, 0
        updates = [audio_start_token_id]
        while video_data_index < num_video and audio_data_index < num_audio:
            if (
                video_token_indices[video_data_index]
                <= audio_token_indices[audio_data_index]
            ):
                updates.append(video_token_id)
                video_data_index += 1
            else:
                updates.append(audio_token_id)
                audio_data_index += 1
        # Flush whichever stream still has tokens left (at most one does).
        updates += [video_token_id] * (num_video - video_data_index)
        updates += [audio_token_id] * (num_audio - audio_data_index)
        updates.append(audio_end_token_id)
        return updates

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """
        Build the prompt replacements for audio, image and video tokens.

        Audio output lengths come from ``audio_feature_lengths`` when present,
        otherwise from the row sums of ``feature_attention_mask``. When
        ``use_audio_in_video`` is set in ``hf_processor_mm_kwargs``, the video
        replacement interleaves audio and video tokens.
        """
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        image_token = processor.image_token
        video_token = processor.video_token
        audio_token_id = vocab[audio_token]
        image_token_id = vocab[image_token]
        video_token_id = vocab[video_token]

        out_mm_data = out_mm_kwargs.get_data()
        audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if audio_feature_lengths is not None:
            audio_output_lens = _get_feat_extract_output_lengths(audio_feature_lengths)
            audio_output_lengths = audio_output_lens.tolist()
        elif feature_attention_mask is not None:
            assert isinstance(feature_attention_mask, torch.Tensor)
            audio_output_lens = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()
        else:
            audio_output_lengths = []

        # number of audios read from video.
        audio_in_video_item_idx = 0

        def get_replacement_qwen2_audio(item_idx: int):
            """Return ``num_features`` audio tokens for one audio item."""
            # Offset by the number of audios already consumed by videos.
            item_idx += audio_in_video_item_idx

            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        def get_replacement_qwen2_vision(item_idx: int, modality: str):
            """Return one token per merged grid cell of an image or video."""
            grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)
            merge_length = image_processor.merge_size**2

            token_id = image_token_id if modality == "image" else video_token_id
            return [token_id] * (int(grid_thw.prod()) // merge_length)

        use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
        thinker_config = self.info.get_hf_config()

        def get_replacement_qwen2_use_audio_in_video(item_idx: int):
            """Return the interleaved audio/video replacement for one video."""
            nonlocal audio_in_video_item_idx
            # NOTE(review): the lookup index adds the running counter to
            # item_idx and the counter is bumped on every call -- confirm this
            # matches the layout of audio_output_lengths for several videos.
            audio_num_features = audio_output_lengths[
                audio_in_video_item_idx + item_idx
            ]
            video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

            audio_in_video_item_idx += 1

            second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[item_idx]
            else:
                video_second_per_grid_t = 2.0

            placeholder = self.get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )
            return PromptUpdateDetails.select_token_id(
                placeholder, embed_token_id=video_token_id
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video
            if use_audio_in_video
            else partial(get_replacement_qwen2_vision, modality="video")
        )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            ),
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=partial(get_replacement_qwen2_vision, modality="image"),
            ),
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=video_replacement_fn,
            ),
        ]

    def _derive_audio_from_video_placeholders(
        self,
        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
        """
        Helper to derive audio placeholders from video placeholders when
        use_audio_in_video=True.

        Returns ``placeholders`` unchanged when it has no ``"video"`` entry;
        otherwise returns a copy whose ``"audio"`` entry holds one placeholder
        per video, spanning the same tokens with an ``is_embed`` mask that
        selects only the audio tokens.

        Raises:
            ValueError: If the number of audio updates differs from the number
                of video placeholders.
        """
        if "video" not in placeholders:
            return placeholders

        # Validate audio and video counts match
        num_videos = len(placeholders["video"])
        num_audios = len(mm_prompt_updates.get("audio", []))
        if num_audios != num_videos:
            raise ValueError(
                f"use_audio_in_video requires equal number of audio and video items, "
                f"got {num_audios=}, {num_videos=}"
            )

        tokenizer = self.info.get_tokenizer()
        processor = self.info.get_hf_processor()
        audio_token_id = tokenizer.get_vocab()[processor.audio_token]

        result_placeholders = dict(placeholders)
        audio_placeholders = []

        # Each video is paired with one audio
        for video_idx, video_placeholder in enumerate(placeholders["video"]):
            # Create is_embed mask selecting only audio tokens
            audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id

            audio_placeholder = PlaceholderFeaturesInfo(
                modality="audio",
                item_idx=video_idx,
                start_idx=video_placeholder.start_idx,
                tokens=video_placeholder.tokens,
                is_embed=audio_is_embed,
            )
            audio_placeholders.append(audio_placeholder)

        result_placeholders["audio"] = audio_placeholders
        return result_placeholders

    def _get_raw_input_ids(
        self,
        token_ids: list[int],
        use_audio_in_video: bool = False,
    ) -> list[int]:
        """
        Collapse expanded multimodal placeholders back to single tokens.

        With ``use_audio_in_video``, every span opened by
        ``[vision_bos, audio_bos]`` and closed by ``[audio_eos, vision_eos]``
        is first replaced by ``[vision_bos, video_pad, vision_eos]``. Then
        consecutive runs of the audio, image and video pad tokens are each
        reduced to one token. The input list is not modified.
        """
        tokenizer = self.info.get_tokenizer()
        vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0]
        vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0]
        audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0]
        audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0]
        audio_token = tokenizer.encode("<|audio_pad|>")[0]
        image_token = tokenizer.encode("<|image_pad|>")[0]
        video_token = tokenizer.encode("<|video_pad|>")[0]

        def find_pair(seq: list[int], pair: list[int], begin: int) -> int | None:
            """Index of the first occurrence of ``pair`` at or after ``begin``."""
            for i in range(begin, len(seq) - 1):
                if seq[i : i + 2] == pair:
                    return i
            return None

        result = token_ids[:]
        if use_audio_in_video:
            span_open = [vision_bos_token, audio_bos_token]
            span_close = [audio_eos_token, vision_eos_token]
            while True:
                start = find_pair(result, span_open, 0)
                if start is None:
                    break
                end = find_pair(result, span_close, start + 2)
                if end is None:
                    # Unterminated span: nothing can be rewritten any more.
                    # Without this exit the loop would find the same opener
                    # again and never terminate.
                    break
                result = (
                    result[:start]
                    + [vision_bos_token, video_token, vision_eos_token]
                    + result[end + 2 :]
                )

        # Squash each run of identical pad tokens down to a single token.
        for mm_token in [audio_token, image_token, video_token]:
            compressed = []
            for x in result:
                if x == mm_token and compressed and compressed[-1] == mm_token:
                    continue
                compressed.append(x)
            result = compressed

        return result

_derive_audio_from_video_placeholders

_derive_audio_from_video_placeholders(
    placeholders: Mapping[
        str, list[PlaceholderFeaturesInfo]
    ],
    mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]

Helper to derive audio placeholders from video placeholders when use_audio_in_video=True.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _derive_audio_from_video_placeholders(
    self,
    placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
    mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
    """
    Helper to derive audio placeholders from video placeholders when
    use_audio_in_video=True.
    """
    if "video" not in placeholders:
        return placeholders

    # Validate audio and video counts match
    num_videos = len(placeholders["video"])
    num_audios = len(mm_prompt_updates.get("audio", []))
    if num_audios != num_videos:
        raise ValueError(
            f"use_audio_in_video requires equal number of audio and video items, "
            f"got {num_audios=}, {num_videos=}"
        )

    tokenizer = self.info.get_tokenizer()
    processor = self.info.get_hf_processor()
    audio_token_id = tokenizer.get_vocab()[processor.audio_token]

    result_placeholders = dict(placeholders)
    audio_placeholders = []

    # Each video is paired with one audio
    for video_idx, video_placeholder in enumerate(placeholders["video"]):
        # Create is_embed mask selecting only audio tokens
        audio_is_embed = torch.tensor(video_placeholder.tokens) == audio_token_id

        audio_placeholder = PlaceholderFeaturesInfo(
            modality="audio",
            item_idx=video_idx,
            start_idx=video_placeholder.start_idx,
            tokens=video_placeholder.tokens,
            is_embed=audio_is_embed,
        )
        audio_placeholders.append(audio_placeholder)

    result_placeholders["audio"] = audio_placeholders
    return result_placeholders

_maybe_apply_prompt_updates

_maybe_apply_prompt_updates(
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[
    list[int],
    str,
    Mapping[str, list[PlaceholderFeaturesInfo]],
]

Qwen3-Omni reimplements this function to handle use_audio_in_video.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _maybe_apply_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
    """
    Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
    """
    mm_item_counts = mm_items.get_all_counts()
    self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

    use_audio_in_video = False
    if "video" in mm_kwargs:
        for item in mm_kwargs["video"]:
            if item and item["use_audio_in_video"].data:
                use_audio_in_video = True
            else:
                use_audio_in_video = False

    # normal case with `use_audio_in_video=False`
    if is_update_applied:
        mm_placeholders = self._find_mm_placeholders(
            prompt_ids,
            mm_prompt_updates,
        )
        self._validate_mm_placeholders(
            mm_placeholders,
            mm_item_counts,
        )
    else:
        if use_audio_in_video and "audio" in mm_prompt_updates:
            filtered_updates = {
                k: v for k, v in mm_prompt_updates.items() if k != "audio"
            }
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                filtered_updates,
            )
            # Derive audio placeholders from video placeholders
            mm_placeholders = self._derive_audio_from_video_placeholders(
                mm_placeholders, mm_prompt_updates
            )
        else:
            prompt_ids, mm_placeholders = self._apply_prompt_updates(
                prompt_ids,
                mm_prompt_updates,
            )

        self._validate_mm_placeholders(
            mm_placeholders,
            mm_item_counts,
        )

    return prompt_ids, mm_placeholders

SinusoidsPositionEmbedding

Bases: Module

Sinusoidal position embedding for audio encoder.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class SinusoidsPositionEmbedding(nn.Module):
    """Sinusoidal position embedding for audio encoder."""

    def __init__(self, length: int, channels: int, max_timescale: int = 10000):
        super().__init__()
        self.length = length
        self.channels = channels
        self.max_timescale = max_timescale

        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")

        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(
            -log_timescale_increment * torch.arange(channels // 2).float()
        )
        scaled_time = (
            torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        )
        positional_embedding = torch.cat(
            [torch.sin(scaled_time), torch.cos(scaled_time)], dim=1
        )
        self.register_buffer(
            "positional_embedding", positional_embedding, persistent=False
        )

    def forward(self, seqlen: int) -> torch.Tensor:
        return self.positional_embedding[:seqlen, :]