
vllm.model_executor.models.granite_speech

Inference-only IBM Granite speech model.

GraniteSpeechAudioInputs

Bases: TensorSchema

Audio input features for Granite Speech model.

Dimensions
  • b: Batch size
  • fi: Number of input features from the Mel spectrogram.
  • fo: Number of output features, i.e. the embedding size.
  • 160: Fixed feature dimension for Mel spectrogram features
Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechAudioInputs(TensorSchema):
    """
    Audio input features for Granite Speech model.

    Dimensions:
        - b: Batch size
        - fi: Number of input features from the Mel spectrogram.
        - fo: Number of output features, i.e. the embedding size.
        - 160: Fixed feature dimension for Mel spectrogram features
    """

    input_features: Annotated[torch.Tensor, TensorShape("b", "fi", 160)]
    """Audio input features."""

    input_features_mask: Annotated[torch.Tensor, TensorShape("b", "fo")]
    """Mask for variable length audio features."""

    audio_embed_sizes: Annotated[list[int], TensorShape("b")]
    """List of audio embedding sizes for each item in batch."""

audio_embed_sizes instance-attribute

audio_embed_sizes: Annotated[list[int], TensorShape(b)]

List of audio embedding sizes for each item in batch.

input_features instance-attribute

input_features: Annotated[Tensor, TensorShape(b, fi, 160)]

Audio input features.

input_features_mask instance-attribute

input_features_mask: Annotated[Tensor, TensorShape(b, fo)]

Mask for variable length audio features.
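
A minimal, shape-only sketch of constructing this schema by hand with dummy tensors; in practice the multimodal processor builds these inputs, and the sizes below (batch of 2, 12 Mel frames, 5 audio embeddings) are made up for illustration.

import torch

from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs

b, fi, fo = 2, 12, 5
audio_inputs = GraniteSpeechAudioInputs(
    input_features=torch.zeros(b, fi, 160),                   # (b, fi, 160)
    input_features_mask=torch.ones(b, fo, dtype=torch.bool),  # (b, fo)
    audio_embed_sizes=[fo, fo],                               # one size per batch item
)
print(audio_inputs["input_features"].shape)  # torch.Size([2, 12, 160])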

GraniteSpeechCTCEncoder

Bases: Module

CTC Encoder comprising conformer blocks and additional linear layers.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechCTCEncoder(nn.Module):
    """CTC Encoder comprising conformer blocks and additional linear layers."""

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        self.config = config

        # Precompute clamped relative positional encoding distances
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        self.attention_dists = (
            torch.clamp(relpos_dist, -config.context_size, config.context_size)
            + config.max_pos_emb
        )

        self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
        self.layers = nn.ModuleList(
            [
                GraniteSpeechConformerBlock(
                    config,
                    prefix=f"{prefix}.layers.{idx}",
                )
                for idx in range(config.num_layers)
            ]
        )

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )

        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states, attention_dists=self.attention_dists)

            if idx == self.num_layers // 2:
                hidden_states_mid = hidden_states.clone()
                hidden_states_mid, _ = self.out(hidden_states_mid)
                hidden_states_mid = self.softmax(hidden_states_mid)
                hidden_states_mid, _ = self.out_mid(hidden_states_mid)
                hidden_states += hidden_states_mid
        return hidden_states
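
A standalone sketch of the clamped relative-position index table precomputed in __init__ above; context_size and max_pos_emb are small illustrative values rather than real config entries. Clamping to ±context_size and shifting by max_pos_emb keeps every index inside the rel_pos_emb table of size 2 * max_pos_emb + 1.

import torch

context_size, max_pos_emb = 4, 512
seq = torch.arange(context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)  # pairwise offsets i - j
attention_dists = (
    torch.clamp(relpos_dist, -context_size, context_size) + max_pos_emb
)
print(attention_dists[0])  # tensor([512, 511, 510, 509]) -> indices into rel_pos_emb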

GraniteSpeechConformerAttention

Bases: Module

Attention for conformer blocks using Shaw's relative positional embeddings. See Shaw et al. (2018), https://arxiv.org/pdf/1803.02155, for more details.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional
    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
    for more details.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError(
                f"Context size should be > 0 and "
                f"<= max_pos_emb ({self.max_pos_emb}), "
                f"got {self.context_size}."
            )

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder)
            )

        # NOTE: would be nice to try to use qkvparallellinear
        # here for this block attention implementation if possible
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

        query_states = query_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        key_states = key_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)
        value_states = value_states.reshape(
            bsz, num_blocks, self.context_size, self.num_heads, -1
        ).transpose(2, 3)

        # shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
        pos_attn = (
            torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
            * self.scale
        )

        if remainder > 0:
            # masked attention in the extended block
            mask = torch.ones(
                self.context_size,
                self.context_size,
                dtype=bool,
                device=hidden_states.device,
            )
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=pos_attn,
                scale=self.scale,
            )
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        return self.to_out(out[:, :num_features, :])
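
The forward pass above attends within fixed-size blocks; this is a small sketch of just the padding and reshape bookkeeping (illustrative sizes, no heads or attention). The padded positions in the final block are the ones masked out with mask_value before scaled_dot_product_attention.

import math

import torch
import torch.nn.functional as F

bsz, num_features, hidden, context_size = 2, 10, 8, 4
x = torch.randn(bsz, num_features, hidden)

num_blocks = math.ceil(num_features / context_size)    # 3
remainder = num_features % context_size                # 2
if remainder > 0:
    x = F.pad(x, (0, 0, 0, context_size - remainder))  # right-pad the time axis

blocks = x.reshape(bsz, num_blocks, context_size, hidden)
print(blocks.shape)  # torch.Size([2, 3, 4, 8]); only the first 10 time steps are real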

GraniteSpeechConformerBlock

Bases: Module

Conformer block, consisting largely of linear layers, attention, and convolutional layers.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block, consisting largely of linear layers,
    attention, and convolutional layers."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(
        self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
    ) -> torch.Tensor:
        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
        hidden_states = (
            self.attn(hidden_states, attention_dists=attention_dists) + hidden_states
        )
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
        hidden_states = self.post_norm(hidden_states)
        return hidden_states

GraniteSpeechConformerConvModule

Bases: Module

Conformer conv module consisting of several 1D/depthwise 1D convolutional layers.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module consisting of several 1D/depthwise 1D
    convolutional layers.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
        hidden_states = self.glu(hidden_states)
        hidden_states = self.depth_conv(hidden_states)
        hidden_states = self.silu(self.batch_norm(hidden_states))
        hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
        return hidden_states
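
A small sketch of the layout juggling in the module above: nn.Conv1d expects (batch, channels, time), so features are permuted into that layout, the pointwise up-projection doubles the channels for the GLU, and the result is later permuted back. All sizes here are illustrative.

import torch
import torch.nn as nn

hidden_dim, expansion, time_steps = 8, 2, 20
inner_dim = hidden_dim * expansion

x = torch.randn(2, time_steps, hidden_dim)         # (batch, time, hidden)
up_conv = nn.Conv1d(hidden_dim, inner_dim * 2, 1)  # pointwise, 2x channels for the GLU
glu = nn.GLU(dim=1)

y = glu(up_conv(x.permute(0, 2, 1)))               # (batch, inner_dim, time)
print(y.shape)                                     # torch.Size([2, 16, 20])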

GraniteSpeechConformerDepthWiseConv1d

Bases: Module

Wrapper for padded 1D depthwise convolution.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for padded 1D pointwise convolution."""

    def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""):
        super().__init__()
        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(
            chan_in, chan_out, kernel_size, groups=chan_in, bias=False
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)
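
A quick check of the padding arithmetic above: for both odd and even kernel sizes, the (pad, pad - pad_offset) pair keeps the convolution length-preserving. The channel count and sequence length are arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

for kernel_size in (3, 4, 15):
    pad = kernel_size // 2
    pad_offset = (kernel_size + 1) % 2
    conv = nn.Conv1d(8, 8, kernel_size, groups=8, bias=False)  # depthwise
    x = torch.randn(1, 8, 20)                                  # (batch, channels, time)
    out = conv(F.pad(x, (pad, pad - pad_offset)))
    print(kernel_size, out.shape[-1])                          # 20 in every case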

GraniteSpeechConformerFeedForward

Bases: Module

Feedforward module for conformer encoder blocks.

Source code in vllm/model_executor/models/granite_speech.py
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.hidden_dim)

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.hidden_dim * config.feedforward_mult,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()

        self.down_proj = RowParallelLinear(
            input_size=config.hidden_dim * config.feedforward_mult,
            output_size=config.hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        hidden_states, _ = self.up_proj(hidden_states)
        hidden_states = self.silu(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states

GraniteSpeechForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription

Source code in vllm/model_executor/models/granite_speech.py
@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
    dummy_inputs=GraniteSpeechDummyInputsBuilder,
)
class GraniteSpeechForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    SupportsLoRA,
    SupportsTranscription,
):
    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return "<|audio|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.quant_config = quant_config
        self.cache_config = cache_config

        with self._mark_language_model(vllm_config):
            # The language model is typically a Granite LLM
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "audio"):
            # Conformer encoder
            self.encoder = GraniteSpeechCTCEncoder(
                config=config.encoder_config,
                quant_config=quant_config,
                prefix=f"{prefix}.encoder",
            )

            # Blip2 QFormer
            self.projector = GraniteSpeechEncoderProjector(
                config=config,
                quant_config=quant_config,
                cache_config=cache_config,
                prefix=f"{prefix}.projector",
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> GraniteSpeechAudioInputs | None:
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)

        if input_features is None:
            return None

        # If we have a batch of variable feature length audio clips, we need
        # to mask the features; usually we would get an input_features_mask
        # from the processor, but we handle rebuilding it here since
        # vLLM generally processes everything independently + batches.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(audio_embed_sizes)

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError(
                "Incorrect type of audio input features. "
                f"Got type: {type(input_features)}"
            )

        if input_features_mask is not None and not isinstance(
            input_features_mask, torch.Tensor
        ):
            raise ValueError(
                "Incorrect type of audio input features mask. "
                f"Got type: {type(input_features_mask)}"
            )

        if isinstance(input_features, torch.Tensor):
            # Granite speech currently only allows one audio token per instance
            # and features are already unsqueezed in the processor, so one
            # instance will have shape [1, {num_features}, 160]. As such,
            # input features will usually be of shape
            # [bsz, 1, num_features, 160], which we squeeze to be 3D here.
            if len(input_features.shape) == 4:
                input_features = input_features.squeeze(1)
            if len(input_features.shape) != 3:
                raise ValueError(
                    "Squeezed input features should be 3D but are of shape "
                    f"{input_features.shape}"
                )
            input_features = input_features.to(self.encoder.input_linear.weight.dtype)

        else:
            # Otherwise we have a list of tensors, which are almost certainly
            # differing in their respective numbers of audio features; when
            # passed as a batch, we expect a list of 2D var len input features
            # so unsqueeze them.
            input_features = [
                feat.unsqueeze(dim=0) for feat in input_features if feat.ndim == 2
            ]

            # stack them into a 3D tensor of size [bsz, most_num_features, 160].
            input_features = self._pad_and_stack_input_features(
                input_features,
            ).to(self.encoder.input_linear.weight.dtype)

        return GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
        )

    def _build_input_features_mask(
        self,
        audio_embed_sizes: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the input features mask, which will generally be used
        to mask the padded features for all entries in the batch except
        for those with the most audio features.

        Args:
            audio_embed_sizes: torch.Tensor
                Tensor of num features in each seq in the batch.
        Returns:
            torch.Tensor: Mask of shape (bsz, num_features) to be applied to
            the audio features prior to splitting the audio embeddings.
        """
        most_audio_features = torch.max(audio_embed_sizes).item()
        mask_indices = torch.arange(
            most_audio_features,
            device=audio_embed_sizes.device,
        ).view(1, -1)
        input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
        return input_features_mask

    def _pad_and_stack_input_features(
        self,
        input_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """Given a list of input features of varying length, pad them to the
        same length and stack them into a torch.Tensor.

        NOTE: Usually, padding is done in the input processor/feature extractor
        and zero padded prior to the computation of the Mel features; the
        resulting values are only constant within a batch and generally nonzero
        (i.e., slightly negative nums); we should validate that this is okay
        since we don't use a feature attention mask, but the more important
        thing is that we apply the input_features_mask with variable len
        batches.

        Args:
            input_features: list[torch.Tensor]
                3D Input features to be coerced into a tensor.
        Returns:
            torch.Tensor: Tensor of shape [bsz, num_features, 160], where
            num_features is the max number of features of any entry in the
            batch.
        """
        feat_lens = [feats.shape[1] for feats in input_features]
        padding = [max(feat_lens) - length for length in feat_lens]
        # TODO (Alex) - Validate that it's okay to zero pad like this;
        # in transformers we zero pad prior to calculating the speech features,
        # so the value is not zero and is dependent on the batched features.
        padded = [
            torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
            for feats, pad in zip(input_features, padding)
        ]
        stacked_features = torch.cat(padded, dim=0).to(input_features[0])
        return stacked_features

    def _process_audio_input(
        self,
        audio_input: GraniteSpeechAudioInputs,
    ) -> tuple[torch.Tensor]:
        """Compute the audio features to be merged into the LLM embeddings.

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.
        Returns:
            tuple[torch.Tensor]: List of length bsz.
        """
        # TODO (Alex) - support embedding inputs
        encoder_embeds = self.encoder(audio_input["input_features"])
        # [bsz, <max feature size>, 4096]
        projected_embeds = self.projector(encoder_embeds)
        # Apply mask on variable length audio features
        masked_embeds = projected_embeds[audio_input["input_features_mask"]]
        # Split variable length features into a tuple
        return torch.split(masked_embeds, audio_input["audio_embed_sizes"])

    def embed_multimodal(
        self,
        **kwargs: object,
    ) -> MultiModalEmbeddings:
        """Compute the audio embeddings if audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return []

        audio_features = self._process_audio_input(audio_input)
        return audio_features

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        # Multi-modal token ID may exceed vocab size
        handle_oov_mm_token: bool = True,
    ) -> torch.Tensor:
        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        model_output = self.language_model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )

    ### Support for speech-to-text Transcription
    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the generation prompt to be used for transcription requests."""
        # Audio placeholders don't use an index, so value doesn't matter
        audio_tok = cls.get_placeholder_str("audio", 0)

        if task_type == "translate":
            full_lang_name_to = cls.supported_languages.get(to_language, to_language)
            user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"  # noqa: E501
        elif task_type == "transcribe":
            user_prompt = (
                f"{audio_tok}can you transcribe the speech into a written format?"  # noqa: E501
            )
        else:
            raise ValueError(f"Unsupported task type {task_type}")

        tokenizer = cached_tokenizer_from_config(model_config)
        chat = [dict(role="user", content=user_prompt)]
        prompt = tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )

        prompt_token_ids = tokenizer.encode(prompt)

        return TokensPrompt(
            prompt_token_ids=prompt_token_ids,
            multi_modal_data={"audio": audio},
        )

    # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Get the number of audio tokens for an audio duration in sec."""
        processor = cached_processor_from_config(model_config)
        hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
        proj_win_size = processor.audio_processor.projector_window_size
        ds_rate = processor.audio_processor.projector_downsample_rate
        effective_window_size = proj_win_size // ds_rate

        raw_length = audio_duration_s * stt_config.sample_rate

        # mel sequence length computation
        mel_length = raw_length // hop_length + 1
        # encoder frame takes two mel features
        encoder_length = mel_length // 2
        nblocks = math.ceil(encoder_length / proj_win_size)
        # projector output length
        return nblocks * effective_window_size

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Get the stt config for this model."""
        # Default settings are reasonable for this model and we don't currently
        # expose this information in the model configs, but this may change in
        # the future
        return SpeechToTextConfig()

_build_input_features_mask

_build_input_features_mask(
    audio_embed_sizes: Tensor,
) -> Tensor

Calculate the input features mask, which will generally be used to mask the padded features for all entries in the batch except for those with the most audio features.

Parameters:

  • audio_embed_sizes (Tensor, required): Tensor of num features in each seq in the batch.

Returns:

  • torch.Tensor: Mask of shape (bsz, num_features) to be applied to the audio features prior to splitting the audio embeddings.

Source code in vllm/model_executor/models/granite_speech.py
def _build_input_features_mask(
    self,
    audio_embed_sizes: torch.Tensor,
) -> torch.Tensor:
    """Calculate the input features mask, which will generally be used
    to mask the padded features for all entries in the batch except
    for those with the most audio features.

    Args:
        audio_embed_sizes: torch.Tensor
            Tensor of num features in each seq in the batch.
    Returns:
        torch.Tensor: Mask of shape (bsz, num_features) to be applied to
        the audio features prior to splitting the audio embeddings.
    """
    most_audio_features = torch.max(audio_embed_sizes).item()
    mask_indices = torch.arange(
        most_audio_features,
        device=audio_embed_sizes.device,
    ).view(1, -1)
    input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
    return input_features_mask
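
A worked example of the mask construction above with three clips that produce 2, 4, and 3 audio embeddings respectively.

import torch

audio_embed_sizes = torch.tensor([2, 4, 3])
most_audio_features = torch.max(audio_embed_sizes).item()  # 4
mask_indices = torch.arange(most_audio_features).view(1, -1)
input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
print(input_features_mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True, False]])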

_pad_and_stack_input_features

_pad_and_stack_input_features(
    input_features: list[Tensor],
) -> Tensor

Given a list of input features of varying length, pad them to the same length and stack them into a torch.Tensor.

NOTE: Usually, padding is done in the input processor/feature extractor and zero padded prior to the computation of the Mel features; the resulting values are only constant within a batch and generally nonzero (i.e., slightly negative nums); we should validate that this is okay since we don't use a feature attention mask, but the more important thing is that we apply the input_features_mask with variable len batches.

Parameters:

  • input_features (list[Tensor], required): 3D input features to be coerced into a tensor.

Returns:

  • torch.Tensor: Tensor of shape [bsz, num_features, 160], where num_features is the max number of features of any entry in the batch.

Source code in vllm/model_executor/models/granite_speech.py
def _pad_and_stack_input_features(
    self,
    input_features: list[torch.Tensor],
) -> torch.Tensor:
    """Given a list of input features of varying length, pad them to the
    same length and stack them into a torch.Tensor.

    NOTE: Usually, padding is done in the input processor/feature extractor
    and zero padded prior to the computation of the Mel features; the
    resulting values are only constant within a batch and generally nonzero
    (i.e., slightly negative nums); we should validate that this is okay
    since we don't use a feature attention mask, but the more important
    thing is that we apply the input_features_mask with variable len
    batches.

    Args:
        input_features: list[torch.Tensor]
            3D Input features to be coerced into a tensor.
    Returns:
        torch.Tensor: Tensor of shape [bsz, num_features, 160], where
        num_features is the max number of features of any entry in the
        batch.
    """
    feat_lens = [feats.shape[1] for feats in input_features]
    padding = [max(feat_lens) - length for length in feat_lens]
    # TODO (Alex) - Validate that it's okay to zero pad like this;
    # in transformers we zero pad prior to calculating the speech features,
    # so the value is not zero and is dependent on the batched features.
    padded = [
        torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
        for feats, pad in zip(input_features, padding)
    ]
    stacked_features = torch.cat(padded, dim=0).to(input_features[0])
    return stacked_features
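
A worked example with two clips of 5 and 8 Mel frames; the feature dimension is shrunk from 160 to 4 purely to keep the sketch small.

import torch
import torch.nn.functional as F

input_features = [torch.randn(1, 5, 4), torch.randn(1, 8, 4)]  # [1, num_features, feat_dim]
feat_lens = [feats.shape[1] for feats in input_features]
padding = [max(feat_lens) - length for length in feat_lens]    # [3, 0]
padded = [
    F.pad(feats, (0, 0, 0, pad, 0, 0))
    for feats, pad in zip(input_features, padding)
]
stacked_features = torch.cat(padded, dim=0)
print(stacked_features.shape)  # torch.Size([2, 8, 4])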

_process_audio_input

_process_audio_input(
    audio_input: GraniteSpeechAudioInputs,
) -> tuple[Tensor]

Compute the audio features to be merged into the LLM embeddings.

Parameters:

  • audio_input (GraniteSpeechAudioInputs, required): Audio inputs object containing Mel features, an input features mask, and the (flattened) number of audio tokens per instance.

Returns:

  • tuple[torch.Tensor]: Tuple of length bsz, one tensor of audio embeddings per batch item.

Source code in vllm/model_executor/models/granite_speech.py
def _process_audio_input(
    self,
    audio_input: GraniteSpeechAudioInputs,
) -> tuple[torch.Tensor]:
    """Compute the audio features to be merged into the LLM embeddings.

    Args:
        audio_input: GraniteSpeechAudioInputs
            Audio inputs object containing Mel features, an input features
            mask, and the (flattened) number of audio tokens per instance.
    Returns:
        tuple[torch.Tensor]: List of length bsz.
    """
    # TODO (Alex) - support embedding inputs
    encoder_embeds = self.encoder(audio_input["input_features"])
    # [bsz, <max feature size>, 4096]
    projected_embeds = self.projector(encoder_embeds)
    # Apply mask on variable length audio features
    masked_embeds = projected_embeds[audio_input["input_features_mask"]]
    # Split variable length features into a tuple
    return torch.split(masked_embeds, audio_input["audio_embed_sizes"])
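
A shape-only sketch of the mask-and-split step above with dummy projected embeddings: two clips that produced 2 and 3 audio embeddings, padded to a common length of 3 (the hidden size of 8 is arbitrary).

import torch

projected_embeds = torch.randn(2, 3, 8)  # [bsz, max_embeds, hidden]
input_features_mask = torch.tensor(
    [[True, True, False], [True, True, True]]
)
audio_embed_sizes = [2, 3]

masked_embeds = projected_embeds[input_features_mask]  # [5, 8] valid rows, flattened
per_clip = torch.split(masked_embeds, audio_embed_sizes)
print([t.shape for t in per_clip])  # [torch.Size([2, 8]), torch.Size([3, 8])]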

embed_multimodal

embed_multimodal(**kwargs: object) -> MultiModalEmbeddings

Compute the audio embeddings if audio inputs are present.

Source code in vllm/model_executor/models/granite_speech.py
def embed_multimodal(
    self,
    **kwargs: object,
) -> MultiModalEmbeddings:
    """Compute the audio embeddings if audio inputs are present."""
    audio_input = self._parse_and_validate_audio_input(**kwargs)
    if audio_input is None:
        return []

    audio_features = self._process_audio_input(audio_input)
    return audio_features

get_generation_prompt classmethod

get_generation_prompt(
    audio: ndarray,
    model_config: ModelConfig,
    stt_config: SpeechToTextConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType

Get the generation prompt to be used for transcription requests.

Source code in vllm/model_executor/models/granite_speech.py
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    model_config: ModelConfig,
    stt_config: SpeechToTextConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """Get the generation prompt to be used for transcription requests."""
    # Audio placeholders don't use an index, so value doesn't matter
    audio_tok = cls.get_placeholder_str("audio", 0)

    if task_type == "translate":
        full_lang_name_to = cls.supported_languages.get(to_language, to_language)
        user_prompt = f"{audio_tok}translate the speech to {full_lang_name_to}"  # noqa: E501
    elif task_type == "transcribe":
        user_prompt = (
            f"{audio_tok}can you transcribe the speech into a written format?"  # noqa: E501
        )
    else:
        raise ValueError(f"Unsupported task type {task_type}")

    tokenizer = cached_tokenizer_from_config(model_config)
    chat = [dict(role="user", content=user_prompt)]
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )

    prompt_token_ids = tokenizer.encode(prompt)

    return TokensPrompt(
        prompt_token_ids=prompt_token_ids,
        multi_modal_data={"audio": audio},
    )
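
Hedged usage sketch: once a Granite Speech checkpoint is served through vLLM's OpenAI-compatible API, transcription requests that flow through this prompt builder can be issued with the standard audio transcriptions endpoint. The model name, server URL, and audio file below are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
with open("sample.wav", "rb") as audio_file:
    transcription = client.audio.transcriptions.create(
        model="ibm-granite/granite-speech-3.3-8b",  # placeholder model name
        file=audio_file,
        language="en",
    )
print(transcription.text)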

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models.

Source code in vllm/model_executor/models/granite_speech.py
def get_mm_mapping(self) -> MultiModelKeys:
    """Get the module prefix in multimodal models."""
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="projector",
        tower_model="encoder",
    )

get_num_audio_tokens classmethod

get_num_audio_tokens(
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None

Get the number of audio tokens for an audio duration in sec.

Source code in vllm/model_executor/models/granite_speech.py
@classmethod
def get_num_audio_tokens(
    cls,
    audio_duration_s: float,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
) -> int | None:
    """Get the number of audio tokens for an audio duration in sec."""
    processor = cached_processor_from_config(model_config)
    hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
    proj_win_size = processor.audio_processor.projector_window_size
    ds_rate = processor.audio_processor.projector_downsample_rate
    effective_window_size = proj_win_size // ds_rate

    raw_length = audio_duration_s * stt_config.sample_rate

    # mel sequence length computation
    mel_length = raw_length // hop_length + 1
    # encoder frame takes two mel features
    encoder_length = mel_length // 2
    nblocks = math.ceil(encoder_length / proj_win_size)
    # projector output length
    return nblocks * effective_window_size
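
A worked example of the computation above with assumed processor values (hop_length=160, projector_window_size=15, projector_downsample_rate=5) and a 16 kHz sample rate; the real values come from the model's audio processor config.

import math

audio_duration_s = 10.0
sample_rate, hop_length = 16_000, 160
proj_win_size, ds_rate = 15, 5
effective_window_size = proj_win_size // ds_rate     # 3

raw_length = audio_duration_s * sample_rate          # 160000.0
mel_length = raw_length // hop_length + 1            # 1001.0
encoder_length = mel_length // 2                     # 500.0
nblocks = math.ceil(encoder_length / proj_win_size)  # 34
print(nblocks * effective_window_size)               # 102 audio tokens for 10 s of audio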

get_speech_to_text_config classmethod

get_speech_to_text_config(
    model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig

Get the stt config for this model.

Source code in vllm/model_executor/models/granite_speech.py
@classmethod
def get_speech_to_text_config(
    cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
    """Get the stt config for this model."""
    # Default settings are reasonable for this model and we don't currently
    # expose this information in the model configs, but this may change in
    # the future
    return SpeechToTextConfig()