
vllm.model_executor.models.voxtral

VoxtralForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription

Source code in vllm/model_executor/models/voxtral.py
@MULTIMODAL_REGISTRY.register_processor(
    VoxtralMultiModalProcessor,
    info=VoxtralProcessingInfo,
    dummy_inputs=VoxtralDummyInputsBuilder,
)
class VoxtralForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
):
    supported_languages = ISO639_1_SUPPORTED_LANGS
    # transformers currently has limited support for the MistralCommon backend
    # and cached_get_processor. Skip audio preprocessing warmup until this is fixed.
    skip_warmup_audio_preprocessing = True

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # update quant config so that ignored module and target module names
        # match the vLLM model names
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config
            )

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.whisper_encoder = VoxtralEncoderModel(
                vllm_config.with_hf_config(config.audio_config),
                prefix=maybe_prefix(prefix, "whisper_encoder"),
            )
            self.audio_language_adapter = AudioLanguageAdapter(
                hidden_size=config.audio_config.d_model * self.downsample_factor,
                dim=config.text_config.hidden_size,
            )

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get module prefix for multimodal models to filter LoRA modules."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="audio_language_adapter",
            tower_model=["whisper_encoder"],
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def embed_multimodal(
        self, **kwargs
    ) -> list[torch.Tensor] | torch.Tensor | tuple[torch.Tensor, ...] | None:
        audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
        if audio_inputs is None:
            return None

        audio_embeddings = self.whisper_encoder(audio_inputs)

        for i, audio_embedding in enumerate(audio_embeddings):
            seq_len, dim = audio_embedding.shape
            # Pad such that seq_len is divisible by downsample_factor
            target_seq_len = self.downsample_factor * math.ceil(
                seq_len / self.downsample_factor
            )
            audio_embedding = torch.nn.functional.pad(
                audio_embedding,
                (0, 0, 0, target_seq_len - seq_len),
            )
            audio_embeddings[i] = audio_embedding.reshape(
                target_seq_len // self.downsample_factor, dim * self.downsample_factor
            )

        # Concat, project and resplit
        audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
        audio_embeddings = torch.split(
            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
        )

        return audio_embeddings

    def _parse_and_validate_audio_arrays(
        self, **kwargs: object
    ) -> list[torch.Tensor] | None:
        audio_arrays = kwargs.pop("audio_arrays", None)
        if audio_arrays is None:
            return None

        if not isinstance(audio_arrays, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
            )

        if isinstance(audio_arrays, torch.Tensor):
            audio_arrays = list(audio_arrays.unbind(0))
        return audio_arrays

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio_config = tokenizer.instruct.audio_encoder.audio_config
        max_audio_clip_s = audio_config.chunk_length_s
        sample_rate = audio_config.sampling_rate
        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            sample_rate=sample_rate,
            # mistral_common and whisper encoder take care of chunking
            min_energy_split_window_size=None,
        )

    # for speech-to-text transcription
    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        tokenizer = cached_tokenizer_from_config(model_config)
        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
        req = TranscriptionRequest(
            model=model_config.model,
            audio=RawAudio.from_audio(audio),
            language=language,
        )

        tokenized = tokenizer.instruct.encode_transcription(req)

        return TokensPrompt(
            prompt_token_ids=tokenized.tokens,
            multi_modal_data={
                "audio": [
                    (audio.audio_array, stt_config.sample_rate)
                    for audio in tokenized.audios
                ],
            },
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        """
        tokenizer = cached_tokenizer_from_config(model_config)
        adapter = VoxtralProcessorAdapter(tokenizer)
        return adapter.get_num_audio_tokens(
            int(audio_duration_s * stt_config.sample_rate)
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        remapping_rules = [
            (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (
                r"audio_language_adapter\.0\.weight",
                r"audio_language_adapter.w_in.weight",
            ),
            (
                r"audio_language_adapter\.2\.weight",
                r"audio_language_adapter.w_out.weight",
            ),
        ]

        audio_params = dict(
            nn.ModuleDict(
                {
                    "audio_language_adapter": self.audio_language_adapter,
                }
            ).named_parameters()
        )
        weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")

        loaded_weights = set()

        def llm_weights_generator():
            nonlocal loaded_weights
            for name, w in weights:
                is_encoder = False
                for k in [
                    "mm_whisper_embeddings",
                    "mm_streams_embeddings.embedding_module",
                ]:
                    is_encoder |= (
                        name.startswith(k)
                        and not name.startswith(f"{k}.tok_embeddings")
                        and not name.startswith(f"{k}.audio_language_projection")
                    )

                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        name = re.sub(pattern, repl, name)

                if is_encoder:
                    name = self.whisper_encoder.load_weight((name, w))
                    loaded_weights.add(f"whisper_encoder.{name}")
                    continue

                if name in audio_params:
                    param = audio_params[name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                    loaded_weights.add(name)
                else:
                    yield (name, w)

        for name in self.language_model.load_weights(llm_weights_generator()):
            loaded_weights.add(f"language_model.{name}")

        # the position embeddings may be computed on the fly rather than stored
        # in the checkpoint, so they can be absent from the loaded weights
        sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight"
        if sin_key not in loaded_weights:
            # mark them as loaded so the weight-loading check does not fail
            loaded_weights.add(sin_key)

        return loaded_weights

    def maybe_update_quant_config(
        self, quant_config: QuantizationConfig
    ) -> QuantizationConfig:
        """
        Update quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific to the compressed-tensors format and
        load_format mistral.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (
                r"layers\.(\d+)\.attention\.wo",
                r"language_model.model.layers.\1.self_attn.out_proj",
            ),
            (
                r"layers\.(\d+)\.attention\.w(.*)",
                r"language_model.model.layers.\1.self_attn.\2_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w1",
                r"language_model.model.layers.\1.mlp.gate_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w2",
                r"language_model.model.layers.\1.mlp.down_proj",
            ),
            (
                r"layers\.(\d+)\.feed_forward\.w3",
                r"language_model.model.layers.\1.mlp.up_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
                r"whisper_encoder.whisper_encoder.conv1",
            ),
            (
                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
                r"whisper_encoder.whisper_encoder.conv2",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.0",
                r"audio_language_adapter.w_in",
            ),
            (
                r"mm_whisper_embeddings\.audio_language_projection\.2",
                r"audio_language_adapter.w_out",
            ),
        ]

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            mistral_ignore = []
            for name in quant_config.ignore:
                mistral_name = name
                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        mistral_name = re.sub(pattern, repl, name)
                mistral_ignore.append(mistral_name)
            quant_config.ignore = mistral_ignore

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                if "targets" in config_groups[group_name]:
                    targets = []
                    for name in config_groups[group_name]["targets"]:
                        mistral_name = name
                        for pattern, repl in remapping_rules:
                            if re.fullmatch(pattern, name):
                                mistral_name = re.sub(pattern, repl, name)
                        targets.append(mistral_name)
                    config_groups[group_name]["targets"] = targets
            quant_config.config_groups = config_groups

        return quant_config
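
The padding and reshape in embed_multimodal above fold downsample_factor consecutive encoder frames into one wider vector before the AudioLanguageAdapter projects it to the text hidden size. A minimal standalone sketch of that arithmetic, using assumed illustrative values for downsample_factor and the encoder width:

import math

import torch

downsample_factor = 4      # assumed value of config.audio_config.downsample_factor
seq_len, dim = 1498, 1280  # assumed encoder output shape for one audio clip

audio_embedding = torch.randn(seq_len, dim)

# Pad so that seq_len is divisible by downsample_factor
target_seq_len = downsample_factor * math.ceil(seq_len / downsample_factor)
audio_embedding = torch.nn.functional.pad(
    audio_embedding, (0, 0, 0, target_seq_len - seq_len)
)

# Fold downsample_factor consecutive frames into one wider vector
folded = audio_embedding.reshape(
    target_seq_len // downsample_factor, dim * downsample_factor
)
print(folded.shape)  # torch.Size([375, 5120])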

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get module prefix for multimodal models to filter LoRA modules.

Source code in vllm/model_executor/models/voxtral.py
def get_mm_mapping(self) -> MultiModelKeys:
    """Get module prefix for multimodal models to filter LoRA modules."""
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="audio_language_adapter",
        tower_model=["whisper_encoder"],
    )
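
For orientation only, a hedged sketch of how such a prefix mapping is typically used when deciding which submodules receive LoRA adapters; the helper below is an illustration, not vLLM API:

# Illustrative helper mirroring the prefixes above: apply LoRA to the decoder,
# skip the audio tower ("whisper_encoder") and the connector
# ("audio_language_adapter").
def lora_applies_to(module_name: str) -> bool:
    return module_name.startswith("language_model")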

get_num_audio_tokens classmethod

get_num_audio_tokens(
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None

Map from audio duration to number of audio tokens produced by the ASR model, without running a forward pass. This is used for estimating the amount of processing for this audio.

Source code in vllm/model_executor/models/voxtral.py
@classmethod
def get_num_audio_tokens(
    cls,
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None:
    """
    Map from audio duration to number of audio tokens produced by the ASR
    model, without running a forward pass.
    This is used for estimating the amount of processing for this audio.
    """
    tokenizer = cached_tokenizer_from_config(model_config)
    adapter = VoxtralProcessorAdapter(tokenizer)
    return adapter.get_num_audio_tokens(
        int(audio_duration_s * stt_config.sample_rate)
    )
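
For intuition, the token count follows directly from the adapter's formula ceil(audio_length / (sampling_rate // frame_rate)). A small worked sketch with assumed illustrative values (the real ones come from the tokenizer's audio_config):

from math import ceil

sampling_rate = 16000    # Hz (assumed)
frame_rate = 12.5        # audio tokens per second after downsampling (assumed)
audio_duration_s = 30.0

audio_length = int(audio_duration_s * sampling_rate)             # 480_000 samples
num_tokens = ceil(audio_length / (sampling_rate // frame_rate))  # ceil(480000 / 1280.0)
print(num_tokens)  # 375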

maybe_update_quant_config

maybe_update_quant_config(
    quant_config: QuantizationConfig,
) -> QuantizationConfig

Update quant config so that ignored module and target module names match the vLLM model names. Right now this is specific to the compressed-tensors format and load_format mistral.

Source code in vllm/model_executor/models/voxtral.py
def maybe_update_quant_config(
    self, quant_config: QuantizationConfig
) -> QuantizationConfig:
    """
    Update quant config so that ignored module and target module names
    match the vLLM model names.
    Right now this is specific to the compressed-tensors format and
    load_format mistral.
    """
    remapping_rules = [
        (r"output", r"language_model.lm_head"),
        (
            r"layers\.(\d+)\.attention\.wo",
            r"language_model.model.layers.\1.self_attn.out_proj",
        ),
        (
            r"layers\.(\d+)\.attention\.w(.*)",
            r"language_model.model.layers.\1.self_attn.\2_proj",
        ),
        (
            r"layers\.(\d+)\.feed_forward\.w1",
            r"language_model.model.layers.\1.mlp.gate_proj",
        ),
        (
            r"layers\.(\d+)\.feed_forward\.w2",
            r"language_model.model.layers.\1.mlp.down_proj",
        ),
        (
            r"layers\.(\d+)\.feed_forward\.w3",
            r"language_model.model.layers.\1.mlp.up_proj",
        ),
        (
            r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
            r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
        ),
        (
            r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
            r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
        ),
        (
            r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
            r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
        ),
        (
            r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
            r"whisper_encoder.whisper_encoder.conv1",
        ),
        (
            r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
            r"whisper_encoder.whisper_encoder.conv2",
        ),
        (
            r"mm_whisper_embeddings\.audio_language_projection\.0",
            r"audio_language_adapter.w_in",
        ),
        (
            r"mm_whisper_embeddings\.audio_language_projection\.2",
            r"audio_language_adapter.w_out",
        ),
    ]

    # Update ignore list
    if hasattr(quant_config, "ignore"):
        mistral_ignore = []
        for name in quant_config.ignore:
            mistral_name = name
            for pattern, repl in remapping_rules:
                if re.fullmatch(pattern, name):
                    mistral_name = re.sub(pattern, repl, name)
            mistral_ignore.append(mistral_name)
        quant_config.ignore = mistral_ignore

    # Update target list
    if hasattr(quant_config, "config_groups"):
        config_groups = quant_config.config_groups
        for group_name in config_groups:
            if "targets" in config_groups[group_name]:
                targets = []
                for name in config_groups[group_name]["targets"]:
                    mistral_name = name
                    for pattern, repl in remapping_rules:
                        if re.fullmatch(pattern, name):
                            mistral_name = re.sub(pattern, repl, name)
                    targets.append(mistral_name)
                config_groups[group_name]["targets"] = targets
        quant_config.config_groups = config_groups

    return quant_config
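
The remapping itself is plain re.fullmatch / re.sub over each rule. A standalone sketch applying two of the rules above to assumed checkpoint-style names:

import re

remapping_rules = [
    (
        r"layers\.(\d+)\.attention\.wo",
        r"language_model.model.layers.\1.self_attn.out_proj",
    ),
    (
        r"layers\.(\d+)\.feed_forward\.w1",
        r"language_model.model.layers.\1.mlp.gate_proj",
    ),
]

def remap(name: str) -> str:
    # Mirrors the loop in maybe_update_quant_config: every rule is checked
    # against the original name and the last matching rule wins.
    remapped = name
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            remapped = re.sub(pattern, repl, name)
    return remapped

print(remap("layers.0.attention.wo"))     # language_model.model.layers.0.self_attn.out_proj
print(remap("layers.7.feed_forward.w1"))  # language_model.model.layers.7.mlp.gate_proj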

VoxtralProcessorAdapter

Provide a HF-compatible interface for mistral_common.tokens.tokenizers.multimodal.AudioEncoder.

Source code in vllm/model_executor/models/voxtral.py
class VoxtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    :class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @cached_property
    def _audio_processor(self) -> AudioEncoder:
        audio_encoder = self.tokenizer.instruct.audio_encoder
        assert isinstance(audio_encoder, AudioEncoder)
        return audio_encoder

    @cached_property
    def audio_token_id(self) -> int:
        return self._audio_processor.special_ids.audio

    @cached_property
    def begin_audio_token_id(self) -> int:
        return self._audio_processor.special_ids.begin_audio

    @cached_property
    def sampling_rate(self) -> int:
        return self._audio_processor.audio_config.sampling_rate

    @cached_property
    def frame_rate(self) -> float:
        return self._audio_processor.audio_config.frame_rate

    def get_num_audio_tokens(
        self,
        audio_length: int,
    ) -> int:
        return ceil(audio_length / (self.sampling_rate // self.frame_rate))

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        audios: np.ndarray | list[np.ndarray] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if audios is None:
            audios = []
        if not isinstance(audios, list):
            audios = [audios]

        if not audios:
            input_ids = self.tokenizer(text).input_ids
            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        audios_tokens = list[torch.Tensor]()
        audios_processed = list[torch.Tensor]()
        for audio in audios:
            assert isinstance(audio, np.ndarray)
            assert audio.ndim == 1

            if not self._audio_processor.audio_config.is_streaming:
                audio = self._audio_processor.pad(
                    audio, self.sampling_rate, is_online_streaming=False
                )

            audio_tokens = [self.begin_audio_token_id] + [
                self.audio_token_id
            ] * self.get_num_audio_tokens(len(audio))

            audios_tokens.append(torch.tensor(audio_tokens))
            audios_processed.append(torch.tensor(audio))

        return BatchFeature(
            {
                "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
                "audio_arrays": audios_processed,
            }
        )
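
Putting the adapter together, each audio clip contributes one begin-audio token followed by get_num_audio_tokens(len(audio)) audio placeholder tokens. A schematic sketch with assumed token ids and rates (the real values come from the mistral_common tokenizer):

from math import ceil

BEGIN_AUDIO_ID = 24      # assumed
AUDIO_ID = 25            # assumed
sampling_rate = 16000    # assumed
frame_rate = 12.5        # assumed

def placeholder_tokens(num_samples: int) -> list[int]:
    # One begin-audio token, then one audio token per encoder output frame,
    # matching the loop in VoxtralProcessorAdapter.__call__.
    num_audio_tokens = ceil(num_samples / (sampling_rate // frame_rate))
    return [BEGIN_AUDIO_ID] + [AUDIO_ID] * num_audio_tokens

print(len(placeholder_tokens(16000 * 30)))  # 376 tokens for a 30 s clip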