vllm.model_executor.models.pixtral

PatchMerger

Bases: Module

Learned merging of spatial_merge_size ** 2 patches

Source code in vllm/model_executor/models/pixtral.py
class PatchMerger(nn.Module):
    """
    Learned merging of spatial_merge_size ** 2 patches
    """

    def __init__(
        self,
        vision_encoder_dim: int,
        spatial_merge_size: int,
        use_mlp_bias: bool = False,
    ) -> None:
        super().__init__()

        mlp_input_dim = vision_encoder_dim * (spatial_merge_size**2)

        self.spatial_merge_size = spatial_merge_size
        self.mlp_input_dim = mlp_input_dim

        self.merging_layer = nn.Linear(
            mlp_input_dim,
            vision_encoder_dim,
            bias=use_mlp_bias,
        )

    def forward(
        self, x: torch.Tensor, image_sizes: list[tuple[int, int]]
    ) -> torch.Tensor:
        # image_sizes specified in tokens
        assert sum([h * w for h, w in image_sizes]) == len(x)

        # x is (N, vision_encoder_dim)
        x = self.permute(x, image_sizes)

        # x is (N / spatial_merge_size ** 2,
        #       vision_encoder_dim * spatial_merge_size ** 2)
        x = self.merging_layer(x)

        # x is (N / spatial_merge_size ** 2, vision_encoder_dim)
        return x

    def permute(
        self,
        x: torch.Tensor,
        image_sizes: list[tuple[int, int]],
    ) -> torch.Tensor:
        """
        Args:
            x: (N, D) where N is flattened and concatenated patch tokens
                for all images
            image_sizes: list of tuple of (height, width) in tokens for
                each image
        Returns:
            image_features: reorders patch tokens so each grid of
                (spatial_merge_size, spatial_merge_size) is contiguous.
                now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
        """

        sub_grids = get_sub_grids(
            x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
        )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
        permuted_tensor: list[torch.Tensor] = []
        for grid in sub_grids:
            n_patches = grid.shape[-1]
            permuted_tensor.append(
                grid.view(-1, n_patches).t()
            )  # n_patches x d * sub_grid_size * sub_grid_size
        return torch.cat(
            permuted_tensor, dim=0
        )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
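
For orientation, a minimal shape sketch of the merger's contract, assuming vLLM is installed; the 6 x 4 token grid and the 64-dim features are made-up values:

import torch
from vllm.model_executor.models.pixtral import PatchMerger

# One image whose encoder output forms a 6 x 4 grid of 64-dim patch tokens.
merger = PatchMerger(vision_encoder_dim=64, spatial_merge_size=2)
x = torch.randn(6 * 4, 64)
out = merger(x, image_sizes=[(6, 4)])
# 24 tokens collapse into 24 / 2**2 = 6 merged tokens of the original width.
print(out.shape)  # torch.Size([6, 64])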

permute

permute(
    x: Tensor, image_sizes: list[tuple[int, int]]
) -> Tensor

Parameters:

  • x (Tensor): (N, D), where N is the number of flattened and concatenated
    patch tokens across all images. Required.
  • image_sizes (list[tuple[int, int]]): list of (height, width) tuples, in
    tokens, for each image. Required.

Returns:

  • image_features: patch tokens reordered so that each
    (spatial_merge_size, spatial_merge_size) grid is contiguous; shape
    (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2).

Source code in vllm/model_executor/models/pixtral.py
def permute(
    self,
    x: torch.Tensor,
    image_sizes: list[tuple[int, int]],
) -> torch.Tensor:
    """
    Args:
        x: (N, D) where N is flattened and concatenated patch tokens
            for all images
        image_sizes: list of tuple of (height, width) in tokens for
            each image
    Returns:
        image_features: reorders patch tokens so each grid of
            (spatial_merge_size, spatial_merge_size) is contiguous.
            now (N / spatial_merge_size ** 2, D * spatial_merge_size ** 2)
    """

    sub_grids = get_sub_grids(
        x=x, image_sizes=image_sizes, spatial_merge_size=self.spatial_merge_size
    )  # list of [d x sub_grid_size x sub_grid_size x n_patches]
    permuted_tensor: list[torch.Tensor] = []
    for grid in sub_grids:
        n_patches = grid.shape[-1]
        permuted_tensor.append(
            grid.view(-1, n_patches).t()
        )  # n_patches x d * sub_grid_size * sub_grid_size
    return torch.cat(
        permuted_tensor, dim=0
    )  # (N / spatial_merge_size ** 2, d * spatial_merge_size ** 2)
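
The grouping can also be pictured without vLLM's helpers. The following is an illustrative re-implementation sketch (not the code above): it reshapes a single row-major (h * w, d) token grid so that each spatial_merge_size x spatial_merge_size neighbourhood becomes one row.

import torch

def toy_group_patches(x: torch.Tensor, h: int, w: int, merge: int) -> torch.Tensor:
    # x: (h * w, d) patch tokens for one image, laid out row-major.
    d = x.shape[-1]
    grid = x.view(h, w, d)
    # Carve the grid into (merge x merge) neighbourhoods and flatten each one.
    grid = grid.view(h // merge, merge, w // merge, merge, d)
    grid = grid.permute(0, 2, 1, 3, 4).reshape(-1, merge * merge * d)
    return grid  # ((h * w) / merge**2, d * merge**2)

x = torch.randn(16, 8)                      # 4 x 4 grid of 8-dim tokens
print(toy_group_patches(x, 4, 4, 2).shape)  # torch.Size([4, 32])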

PixtralForConditionalGeneration

Bases: Module, SupportsLoRA, SupportsMultiModal, SupportsPP

Source code in vllm/model_executor/models/pixtral.py
@MULTIMODAL_REGISTRY.register_processor(
    PixtralMultiModalProcessor,
    info=PixtralProcessingInfo,
    dummy_inputs=PixtralDummyInputsBuilder,
)
class PixtralForConditionalGeneration(
    nn.Module, SupportsLoRA, SupportsMultiModal, SupportsPP
):
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config

        dataclass_fields = {field.name for field in fields(VisionEncoderArgs)}
        vision_args = {
            key: value
            for key, value in self.config.vision_config.to_dict().items()
            if key in dataclass_fields
        }

        self.vision_args = VisionEncoderArgs(**vision_args)

        # init MistralForCausalLM
        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_encoder = VisionTransformer(self.vision_args)
            self.pre_mm_projector_norm = (
                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
                if self.vision_args.add_pre_mm_projector_layer_norm
                else None
            )
            self.patch_merger = (
                PatchMerger(
                    vision_encoder_dim=self.vision_args.hidden_size,
                    spatial_merge_size=self.vision_args.spatial_merge_size,
                    use_mlp_bias=False,
                )
                if self.vision_args.mm_projector_id == PATCH_MERGE
                else None
            )
            self.vision_language_adapter = VisionLanguageAdapter(
                self.vision_args, dim=config.text_config.hidden_size
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> PixtralImagePixelInputs | None:
        images = kwargs.pop("images", None)
        if images is None:
            return None

        return PixtralImagePixelInputs(
            type="pixel_values",
            images=images,
        )

    def _process_image_input(
        self,
        image_input: PixtralImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        images = image_input["images"]
        image_features = self.vision_encoder(images)
        feature_sizes = [image_feature.shape[0] for image_feature in image_features]
        image_features = torch.cat(image_features)
        if self.pre_mm_projector_norm is not None:
            image_features = self.pre_mm_projector_norm(image_features)
        if self.patch_merger is not None:
            patch_size = self.vision_args.patch_size
            spatial_merge_size_square = self.vision_args.spatial_merge_size**2
            img_patch_dims = [
                (img.shape[1] // patch_size, img.shape[2] // patch_size)
                for img in images
            ]
            feature_sizes = [
                feature_size // spatial_merge_size_square
                for feature_size in feature_sizes
            ]
            image_features = self.patch_merger(
                image_features, image_sizes=img_patch_dims
            )
        image_embeds = self.vision_language_adapter(image_features)
        image_embeds = torch.split(image_embeds, feature_sizes)
        return image_embeds

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for pixtral."""
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith(("vision_encoder", "vision_tower"))

        def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith(
                ("vision_language_adapter", "multi_modal_projector")
            )

        def is_patch_merger(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("patch_merger")

        def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]):
            return weight[0].startswith("pre_mm_projector_norm")

        # Get references to parameters for direct loading
        vision_encoder_dict = (
            dict(self.vision_encoder.named_parameters())
            if self.vision_encoder is not None
            else {}
        )
        patch_merger_dict = (
            dict(self.patch_merger.named_parameters())
            if self.patch_merger is not None
            else {}
        )
        pre_mm_projector_norm_dict = (
            dict(self.pre_mm_projector_norm.named_parameters())
            if self.pre_mm_projector_norm is not None
            else {}
        )
        vision_lang_adapter_dict = (
            dict(self.vision_language_adapter.named_parameters())
            if self.vision_language_adapter is not None
            else {}
        )

        def llm_weights_generator():
            # Single pass over weights
            for name, w in weights:
                if is_vision_encoder_weights((name, w)):
                    if _is_layer_none_or_staged(self.vision_encoder):
                        continue
                    # Load vision encoder weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = vision_encoder_dict.get(trimmed_name)
                    if param is not None:
                        with torch.no_grad():
                            default_weight_loader(param, w)
                elif is_patch_merger((name, w)):
                    if _is_layer_none_or_staged(self.patch_merger):
                        continue
                    # Load vision patch merger weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = patch_merger_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_pre_mm_projector_norm((name, w)):
                    if _is_layer_none_or_staged(self.pre_mm_projector_norm):
                        continue
                    # Load vision pre_mm_projector_norm weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = pre_mm_projector_norm_dict[trimmed_name]
                    with torch.no_grad():
                        default_weight_loader(param, w)
                elif is_vision_lang_adapter_weights((name, w)):
                    if _is_layer_none_or_staged(self.vision_language_adapter):
                        continue
                    # Load vision-language adapter weights directly
                    trimmed_name = ".".join(name.split(".")[1:])
                    param = vision_lang_adapter_dict.get(trimmed_name)
                    if param is not None:
                        with torch.no_grad():
                            default_weight_loader(param, w)
                else:
                    # LLM weights: yield them to be loaded
                    # by language_model.load_weights
                    # Strip "language_model." prefix if present (HF sharded format)
                    name = name.removeprefix("language_model.")
                    yield (name, w)

        # Now we call the language model load with the generator
        self.language_model.load_weights(llm_weights_generator())

    def get_mm_mapping(self) -> MultiModelKeys:
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_language_adapter",
            tower_model="vision_encoder",
        )

    def get_num_mm_encoder_tokens(self, num_image_tokens: int) -> int:
        if getattr(self, "patch_merger", None) is None:
            return num_image_tokens
        merge_size = self.vision_args.spatial_merge_size
        return num_image_tokens * (merge_size**2)

    def get_num_mm_connector_tokens(self, num_vision_tokens: int) -> int:
        if getattr(self, "patch_merger", None) is None:
            return num_vision_tokens
        merge_size = self.vision_args.spatial_merge_size
        return num_vision_tokens // (merge_size**2)
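
A small arithmetic sketch of the two token-count helpers above, with made-up numbers and spatial_merge_size = 2: the vision encoder emits merge_size ** 2 patch tokens per prompt placeholder token, and the patch merger folds them back down before the connector.

merge_size = 2
num_image_tokens = 256                               # prompt placeholder tokens (made up)
encoder_tokens = num_image_tokens * merge_size**2    # 1024 patch tokens out of the ViT
connector_tokens = encoder_tokens // merge_size**2   # 256 tokens after the patch merger
print(encoder_tokens, connector_tokens)              # 1024 256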

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor | IntermediateTensors

Run forward pass for pixtral.

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run forward pass for pixtral."""
    if intermediate_tensors is not None:
        inputs_embeds = None

    hidden_states = self.language_model.model(
        input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
    )

    return hidden_states

PixtralHFVisionModel

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class PixtralHFVisionModel(nn.Module):
    def __init__(
        self,
        config: PixtralVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        self.patch_conv = Conv2dLayer(
            in_channels=config.num_channels,
            out_channels=config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
        self.transformer = PixtralHFTransformer(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.transformer",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.transformer.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.transformer.layers)} "
                "layers."
            )

        if require_post_norm is True:
            msg = "PixtralHFVisionModel does not have post-layernorm"
            raise ValueError(msg)

        self.dtype = next(self.parameters()).dtype
        self.device = next(self.parameters()).device
        self.patch_positional_embedding = PixtralRotaryEmbedding(config, self.device)

    def forward(
        self,
        pixel_values: list[torch.Tensor],
        *,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> tuple[torch.Tensor, ...]:
        """
        Args:
            pixel_values: Each image to be processed will be a separate tensor
                in pixel_values. This means it will be a list of tensors
                because multiple requests batched can have multiple images,
                each with their own shape potentially
            select_layers: Layer indices whose features should be
                concatenated and used as the visual encoder output. If none
                are provided, the last layer is used.

        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
        ]

        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        position_ids = position_ids_in_meshgrid(
            patch_embeds_list,
            max_width=self.config.image_size // self.config.patch_size,
        ).to(self.device)
        position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

        if USE_XFORMERS_OPS:
            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            attention_mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
            )

        out = self.transformer(
            patch_embeds,
            attention_mask,
            position_embedding,
            return_all_hidden_states=select_layers is not None,
        )

        out = resolve_visual_encoder_outputs(
            out,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            feature_select_strategy=feature_select_strategy,
        )

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.transformer.layers)

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if name.startswith("transformer.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

forward

forward(
    pixel_values: list[Tensor],
    *,
    select_layers: list[int] | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy
    | None = None,
) -> tuple[Tensor, ...]

Parameters:

  • pixel_values (list[Tensor]): each image to be processed is a separate
    tensor in pixel_values, i.e. a list of tensors, because a batch of
    requests can contain multiple images, each with its own shape. Required.
  • select_layers (list[int] | None): layer indices whose features should be
    concatenated and used as the visual encoder output. If none are
    provided, the last layer is used. Default: None.

Returns:

  • image_features (tuple[Tensor, ...]): token features for all tokens of
    all images, of shape (N_toks, D).

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    pixel_values: list[torch.Tensor],
    *,
    select_layers: list[int] | None = None,
    feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> tuple[torch.Tensor, ...]:
    """
    Args:
        pixel_values: Each image to be processed will be a separate tensor
            in pixel_values. This means it will be a list of tensors
            because multiple requests batched can have multiple images,
            each with their own shape potentially
        select_layers: Layer indices whose features should be
            concatenated and used as the visual encoder output. If none
            are provided, the last layer is used.

    Returns:
        image_features: tensor of token features for
            all tokens of all images of shape (N_toks, D)
    """
    # pass images through initial convolution independently
    patch_embeds_list = [
        self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in pixel_values
    ]

    patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
    embed_sizes = [p.shape[1] for p in patch_embeds]

    # flatten to a single sequence
    patch_embeds = torch.cat(patch_embeds, dim=1)
    patch_embeds = self.ln_pre(patch_embeds)

    # positional embeddings
    position_ids = position_ids_in_meshgrid(
        patch_embeds_list,
        max_width=self.config.image_size // self.config.patch_size,
    ).to(self.device)
    position_embedding = self.patch_positional_embedding(patch_embeds, position_ids)

    if USE_XFORMERS_OPS:
        attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
    else:
        from transformers.models.pixtral.modeling_pixtral import (
            generate_block_attention_mask,
        )

        attention_mask = generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
        )

    out = self.transformer(
        patch_embeds,
        attention_mask,
        position_embedding,
        return_all_hidden_states=select_layers is not None,
    )

    out = resolve_visual_encoder_outputs(
        out,
        None,
        select_layers=select_layers,
        max_possible_layers=self.config.num_hidden_layers,
        feature_select_strategy=feature_select_strategy,
    )

    # squeeze dim 0 and split into separate tensors for each image
    return torch.split(out.squeeze(0), embed_sizes)
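
The per-image sequence lengths that delimit the block-diagonal attention mask come straight from the convolved patch grids. A standalone sketch with made-up image sizes and a 16-pixel patch size:

patch_size = 16
image_shapes = [(3, 64, 48), (3, 32, 32)]   # (C, H, W) per image, made up
seq_lens = [(h // patch_size) * (w // patch_size) for _, h, w in image_shapes]
print(seq_lens)  # [12, 4] -> tokens only attend within their own image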

PixtralImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height of each image
  • w: Width of each image

The result of stacking ImageEncoding.tokens from each prompt.

Source code in vllm/model_executor/models/pixtral.py
class PixtralImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image

    The result of stacking `ImageEncoding.tokens` from each prompt.
    """

    type: Literal["pixel_values"] = "pixel_values"

    images: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
    ]

PixtralProcessorAdapter

Provide a HF-compatible interface for mistral_common.tokens.tokenizers.multimodal.ImageEncoder.

Source code in vllm/model_executor/models/pixtral.py
class PixtralProcessorAdapter:
    """
    Provide a HF-compatible interface for
    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
    """

    def __init__(self, tokenizer: MistralTokenizer) -> None:
        super().__init__()

        self.tokenizer = tokenizer

    @property
    def image_processor(self) -> ImageEncoder:
        image_encoder = self.tokenizer.instruct.mm_encoder
        assert isinstance(image_encoder, ImageEncoder)
        return image_encoder

    @cached_property
    def image_break_id(self) -> int:
        return self.image_processor.special_ids.img_break

    @cached_property
    def image_token_id(self) -> int:
        return self.image_processor.special_ids.img

    @cached_property
    def image_end_id(self) -> int:
        return self.image_processor.special_ids.img_end

    @cached_property
    def image_size(self) -> int:
        return self.image_processor.mm_config.max_image_size

    @cached_property
    def patch_size(self) -> int:
        return self.image_processor.mm_config.image_patch_size

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs,
    ) -> Mapping[str, NestedTensors]:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if not images:
            input_ids = self.tokenizer(text).input_ids

            return {"input_ids": torch.tensor(input_ids)}

        # Allow dummy text, which is used for profiling as well as token inputs
        if any(len(t) > 0 for t in text):
            raise ValueError(
                "You've passed text inputs instead of token inputs. "
                "Make sure to process your input via `mistral_common`'s "
                "tokenizer or pass a chat completion request. "
                "For more info, see: "
                "https://github.com/vllm-project/vllm/issues/8411."
            )

        images_processed = list[torch.Tensor]()
        images_tokens = list[torch.Tensor]()

        for image in images:
            image_inputs = self.image_processor(ImageChunk(image=image))
            image_processed = torch.tensor(image_inputs.image)
            image_tokens = torch.tensor(image_inputs.tokens)

            images_processed.append(image_processed)
            images_tokens.append(image_tokens)

        return BatchFeature(
            {
                "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
                "images": images_processed,
            }
        )

VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/pixtral.py
class VisionTransformer(nn.Module):
    def __init__(self, args: VisionEncoderArgs):
        super().__init__()
        self.args = args
        self.patch_conv = Conv2dLayer(
            in_channels=args.num_channels,
            out_channels=args.hidden_size,
            kernel_size=args.patch_size,
            stride=args.patch_size,
            bias=False,
        )
        self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5)
        self.transformer = Transformer(args)

        head_dim = self.args.hidden_size // self.args.num_attention_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"
        self._freqs_cis: torch.Tensor | None = None

    @property
    def max_patches_per_side(self) -> int:
        return self.args.image_size // self.args.patch_size

    @property
    def device(self) -> torch.types.Device:
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args.hidden_size // self.args.num_attention_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        images: list[torch.Tensor],
    ) -> torch.Tensor:
        """
        Args:
            images: list of N_img images of variable sizes,
                each of shape (C, H, W)
        Returns:
            image_features: tensor of token features for
                all tokens of all images of shape (N_toks, D)
        """
        # pass images through initial convolution independently
        patch_embeds_list = [
            self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
        ]

        patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
        embed_sizes = [p.shape[1] for p in patch_embeds]

        # flatten to a single sequence
        patch_embeds = torch.cat(patch_embeds, dim=1)
        patch_embeds = self.ln_pre(patch_embeds)

        # positional embeddings
        positions = position_meshgrid(patch_embeds_list).to(self.device)
        freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

        # pass through Transformer with a block diagonal mask delimiting images
        if USE_XFORMERS_OPS:
            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
            )
        else:
            from transformers.models.pixtral.modeling_pixtral import (
                generate_block_attention_mask,
            )

            mask = generate_block_attention_mask(
                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
            )
        out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

        # squeeze dim 0 and split into separate tensors for each image
        return torch.split(out.squeeze(0), embed_sizes)

forward

forward(images: list[Tensor]) -> Tensor

Parameters:

  • images (list[Tensor]): list of N_img images of variable sizes, each of
    shape (C, H, W). Required.

Returns:

  • image_features: token features for all tokens of all images, of shape
    (N_toks, D).

Source code in vllm/model_executor/models/pixtral.py
def forward(
    self,
    images: list[torch.Tensor],
) -> torch.Tensor:
    """
    Args:
        images: list of N_img images of variable sizes,
            each of shape (C, H, W)
    Returns:
        image_features: tensor of token features for
            all tokens of all images of shape (N_toks, D)
    """
    # pass images through initial convolution independently
    patch_embeds_list = [
        self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images
    ]

    patch_embeds = [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list]
    embed_sizes = [p.shape[1] for p in patch_embeds]

    # flatten to a single sequence
    patch_embeds = torch.cat(patch_embeds, dim=1)
    patch_embeds = self.ln_pre(patch_embeds)

    # positional embeddings
    positions = position_meshgrid(patch_embeds_list).to(self.device)
    freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]

    # pass through Transformer with a block diagonal mask delimiting images
    if USE_XFORMERS_OPS:
        mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
        )
    else:
        from transformers.models.pixtral.modeling_pixtral import (
            generate_block_attention_mask,
        )

        mask = generate_block_attention_mask(
            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds
        )
    out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)

    # squeeze dim 0 and split into separate tensors for each image
    return torch.split(out.squeeze(0), embed_sizes)

_reshape_for_broadcast

_reshape_for_broadcast(
    freqs_cis: Tensor, x: Tensor
) -> Tensor

freqs_cis: complex - (seq_len, head_dim / 2)
x: complex - (bsz, seq_len, head_dim / 2)

Source code in vllm/model_executor/models/pixtral.py
def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    freqs_cis: complex - (seq_len, head_dim / 2)
    x: complex - (bsz, seq_len, head_dim / 2)
    """
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1]), (
        freqs_cis.shape,
        (x.shape[1], x.shape[-1]),
    )
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)
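
A quick shape sketch of the broadcast reshape, using the docstring's (bsz, seq_len, head_dim / 2) layout with made-up sizes and assuming vLLM is installed:

import torch
from vllm.model_executor.models.pixtral import _reshape_for_broadcast

freqs_cis = torch.randn(10, 4, dtype=torch.complex64)  # (seq_len, head_dim / 2)
x = torch.randn(2, 10, 4, dtype=torch.complex64)       # (bsz, seq_len, head_dim / 2)
# The batch dim becomes 1 so the rotary factors broadcast over it.
print(_reshape_for_broadcast(freqs_cis, x).shape)  # torch.Size([1, 10, 4])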

precompute_freqs_cis_2d

precompute_freqs_cis_2d(
    dim: int, height: int, width: int, theta: float
) -> Tensor

freqs_cis: 2D complex tensor of shape (height, width, dim // 2), to be
indexed by (height, width) position tuples.

Source code in vllm/model_executor/models/pixtral.py
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    freqs_cis: 2D complex tensor of shape (height, width, dim // 2)
        to be indexed by (height, width) position tuples
    """
    # (dim / 2) frequency bases
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    h = torch.arange(height, device=freqs.device)
    w = torch.arange(width, device=freqs.device)

    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
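
A shape check of the precomputed 2D rotary table, assuming vLLM is installed; dim, height, width, and theta are made-up values:

from vllm.model_executor.models.pixtral import precompute_freqs_cis_2d

freqs = precompute_freqs_cis_2d(dim=8, height=4, width=6, theta=10000.0)
print(freqs.shape, freqs.dtype)  # torch.Size([4, 6, 4]) torch.complex64
# A patch at grid position (row, col) selects its rotary factors via freqs[row, col].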