vllm.model_executor.models.phi3v

Phi3HDImageEmbedding

Bases: Module

Phi3 Image embedding with HD transform.

Source code in vllm/model_executor/models/phi3v.py
class Phi3HDImageEmbedding(nn.Module):
    """Phi3 Image embedding with HD transform."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.img_processor",
        )

        image_dim_out = config.img_processor["image_dim_out"]
        self.num_img_tokens = config.img_processor["num_img_tokens"]

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
        self.with_learnable_separator = config.embd_layer.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(type_feature)

    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        image_features_proj = self.hd_feature_transform(img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1
        )
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd
        )

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1 : 1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop
            )
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd
            )

            # [sub features, separator, global features]
            image_embeddings = torch.cat(
                [
                    sub_image_features_hd_newline.squeeze(
                        0
                    ),  # (h_crop*12*(w_crop*12+1), 4096)
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]
            )
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype)
            )
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(
                num_images, h_crop, w_crop, H // 2, H // 2, -1
            )  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(
                num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
            )  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(
            num_images, h, -1, -1
        )  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings], dim=2
        ).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
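
The projection head built at the end of __init__ maps the 4-fold-merged CLIP patch features (image_dim_out * 4 = 4096) into the language model hidden size through a two-layer MLP. Below is a minimal standalone sketch of that head; the sizes (image_dim_out=1024, hidden_size=3072) and the token count are illustrative assumptions, not values read from a real config.

import torch
import torch.nn as nn

# Illustrative sizes only; in the real module they come from the model config.
image_dim_out = 1024
hidden_size = 3072
depth = 2

# Mirrors the img_projection construction: Linear, then (GELU, Linear) pairs.
layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, hidden_size)]
for _ in range(1, depth):
    layers.extend([nn.GELU(), nn.Linear(hidden_size, hidden_size)])
img_projection = nn.Sequential(*layers)

# One image worth of HD-transformed tokens: (num_img_tokens, 4096) -> (num_img_tokens, 3072)
tokens = torch.randn(757, image_dim_out * 4)
print(img_projection(tokens).shape)  # torch.Size([757, 3072])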

add_image_newline

add_image_newline(image_features_hd)

image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)

Source code in vllm/model_executor/models/phi3v.py
def add_image_newline(self, image_features_hd):
    """
    image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
    output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
    """
    num_images, h, w, hid_dim = image_features_hd.shape
    # add the newline token to the HD image feature patches
    newline_embeddings = self.sub_GN.expand(
        num_images, h, -1, -1
    )  # (n_img, h, 1, hid_dim)
    image_features_hd_newline = torch.cat(
        [image_features_hd, newline_embeddings], dim=2
    ).reshape(num_images, -1, hid_dim)
    return image_features_hd_newline
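
The newline embedding is appended as one extra "column" per row of HD patch tokens, so each row of w tokens becomes w + 1 tokens before the grid is flattened. A toy standalone sketch with tiny made-up dimensions (not the real 4096-dim features), where the separator is a tensor of ones standing in for the learned sub_GN parameter:

import torch

num_images, h, w, hid_dim = 1, 2, 3, 4      # toy sizes for illustration only
features = torch.zeros(num_images, h, w, hid_dim)
sub_GN = torch.ones(1, 1, 1, hid_dim)       # stand-in for the learned separator

newline = sub_GN.expand(num_images, h, -1, -1)         # (1, 2, 1, 4)
with_newline = torch.cat([features, newline], dim=2)   # (1, 2, 4, 4)
flat = with_newline.reshape(num_images, -1, hid_dim)   # (1, 2*(3+1), 4)
print(flat.shape)  # torch.Size([1, 8, 4]); the last token of each row of w+1 is the separator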

forward

forward(
    pixel_values: FloatTensor, image_sizes: Tensor
) -> FloatTensor

process image and return vision embeddings.

pixel_values: (num_images, num_crops, c, h, w)
output: (num_images, num_img_tokens, hidden_size)

Source code in vllm/model_executor/models/phi3v.py
def forward(
    self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
) -> torch.FloatTensor:
    """
    process image and return vision embeddings.

    pixel_values: (num_images, num_crops, c, h, w)
    output: (num_images, num_img_tokens, hidden_size)
    """
    num_images, num_crops, c, h, w = pixel_values.shape
    pixel_values = pixel_values.flatten(0, 1)
    img_features = self.get_img_features(pixel_values)
    img_features = img_features.reshape(
        num_images, num_crops, -1, self.image_dim_out
    )
    image_features_proj = self.hd_feature_transform(img_features, image_sizes)
    return image_features_proj
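
A shape-only walkthrough of forward, using a dummy stand-in for the CLIP tower so it runs without model weights; fake_get_img_features is hypothetical and only reproduces the output shape of the real get_img_features, which returns (N, 24*24, 1024) patch features per crop.

import torch

num_images, num_crops = 2, 5          # e.g. four sub-crops plus one global crop per image
pixel_values = torch.randn(num_images, num_crops, 3, 336, 336)

# Stand-in for get_img_features: the real path runs the CLIP tower and drops the CLS token.
def fake_get_img_features(x: torch.Tensor) -> torch.Tensor:
    return torch.randn(x.shape[0], 576, 1024)

flat = pixel_values.flatten(0, 1)                      # (num_images*num_crops, 3, 336, 336)
feats = fake_get_img_features(flat)                    # (10, 576, 1024)
feats = feats.reshape(num_images, num_crops, -1, 1024)
print(feats.shape)  # torch.Size([2, 5, 576, 1024]) -> input to hd_feature_transform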

hd_feature_transform

hd_feature_transform(image_features, image_sizes)

image_features: (num_images, num_crops+1, 24*24, 1024)

Source code in vllm/model_executor/models/phi3v.py
def hd_feature_transform(self, image_features, image_sizes):
    """
    image_features: (num_images, num_crops+1, 24*24, 1024)
    """
    assert self.hd_transform_order == "sub_glb", (
        f"hd_transform_order `{self.hd_transform_order}` not implemented"
    )
    if isinstance(self.img_projection, nn.Sequential):
        target_device = self.img_projection[0].bias.device
        target_dtype = self.img_projection[0].bias.dtype
    else:  # It's a single nn.Linear layer
        target_device = self.img_projection.bias.device
        target_dtype = self.img_projection.bias.dtype

    global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
    # global feature can be viewed as a special HD case with num_crops 1x1
    global_image_features_hd = self.reshape_hd_patches_2x2merge(
        global_image_features, 1, 1
    )
    global_image_features_hd_newline = self.add_image_newline(
        global_image_features_hd
    )

    batch_image_features_proj = []
    # need a for loop to process each image because of different image sizes
    # (patch arrangement is different for each image)
    for i, img_size in enumerate(image_sizes):
        h, w = img_size
        h_crop = h // 336
        w_crop = w // 336
        num_crops = h_crop * w_crop

        # NOTE: real num_crops is padded
        # (num_crops, 24*24, 1024)
        sub_image_features = image_features[i, 1 : 1 + num_crops]
        sub_image_features_hd = self.reshape_hd_patches_2x2merge(
            sub_image_features, h_crop, w_crop
        )
        sub_image_features_hd_newline = self.add_image_newline(
            sub_image_features_hd
        )

        # [sub features, separator, global features]
        image_embeddings = torch.cat(
            [
                sub_image_features_hd_newline.squeeze(
                    0
                ),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ]
        )
        img_proj = self.img_projection(
            image_embeddings.to(target_device, target_dtype)
        )
        batch_image_features_proj.append(img_proj)

    return batch_image_features_proj
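
The per-image loop derives the crop grid from the preprocessed image size (each crop is 336x336) and concatenates [sub features, glb_GN separator, global features]. The arithmetic sketch below, derived from the shapes in the code above, illustrates the resulting token count per image; the 672x1008 example size is purely illustrative.

def phi3v_hd_token_count(height: int, width: int) -> int:
    """Illustrative sketch of the tokens produced per image by the HD transform."""
    h_crop, w_crop = height // 336, width // 336
    sub_tokens = (h_crop * 12) * (w_crop * 12 + 1)   # sub-crop features plus one newline per row
    glb_tokens = 12 * (12 + 1)                        # global crop after the same 1x1 transform
    return sub_tokens + 1 + glb_tokens                # +1 for the glb_GN separator

# e.g. a 672x1008 padded image -> 2x3 crop grid
print(phi3v_hd_token_count(672, 1008))  # 24*37 + 1 + 156 = 1045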

reshape_hd_patches_2x2merge

reshape_hd_patches_2x2merge(image_features, h_crop, w_crop)

image_features: (num_images*num_crops, 24*24, 1024)
output: (num_images, h_crop*12, w_crop*12, 4096)
where h_crop*w_crop == num_crops

Source code in vllm/model_executor/models/phi3v.py
def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
    """
    image_features: (num_images*num_crops, 24*24, 1024)
    output: (num_images, h_crop*12, w_crop*12, 4096)
    where h_crop*w_crop == num_crops
    """
    N, L, C = image_features.shape
    assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
    num_images = N // (h_crop * w_crop)
    H = int(L**0.5)
    image_features_hd = (
        image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
        .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
        .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
        .reshape(N, -1, 4 * C)  # N, 144, 4096
        .reshape(
            num_images, h_crop, w_crop, H // 2, H // 2, -1
        )  # n_img, h_crop, w_crop, 12, 12, 4096
        .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
        .reshape(
            num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
        )  # n_img, h_crop*12, w_crop*12, 4096
    )
    return image_features_hd
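
The 2x2 merge trades spatial resolution for channel width: each 2x2 block of neighbouring patches is flattened into a single 4*C-dimensional token, and the per-crop grids are then stitched back into one h_crop*12 x w_crop*12 grid. The toy re-implementation below uses tiny made-up dimensions (a 4x4 patch grid with C=2) and drops the 576/1024 asserts so the reshapes are easy to inspect; it is a sketch of the same permute/reshape pipeline, not the production code path.

import torch

def hd_2x2_merge(x: torch.Tensor, h_crop: int, w_crop: int) -> torch.Tensor:
    """Toy version of reshape_hd_patches_2x2merge without the 576/1024 asserts."""
    N, L, C = x.shape
    H = int(L ** 0.5)
    num_images = N // (h_crop * w_crop)
    return (
        x.reshape(N, H, H, C)                                   # per-crop patch grid
        .reshape(N, H // 2, 2, H // 2, 2, C)                    # split into 2x2 blocks
        .permute(0, 1, 3, 2, 4, 5)                              # gather each block together
        .reshape(N, -1, 4 * C)                                  # merge block into channels
        .reshape(num_images, h_crop, w_crop, H // 2, H // 2, 4 * C)
        .permute(0, 1, 3, 2, 4, 5)                              # interleave crop rows/cols
        .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)
    )

# 1 image split into a 2x3 grid of crops, each crop a 4x4 patch grid with C=2
x = torch.randn(6, 16, 2)
print(hd_2x2_merge(x, h_crop=2, w_crop=3).shape)  # torch.Size([1, 4, 6, 8])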

Phi3VImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • n: Number of images
  • f: Image feature size (e.g., number of tokens per image)
  • h: Hidden size (must match language model backbone)
Source code in vllm/model_executor/models/phi3v.py
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - n: Number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", "h"),
    ]

Phi3VImagePixelInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • n: Number of images
  • p: Number of patches
  • h: Height of each patch
  • w: Width of each patch
Source code in vllm/model_executor/models/phi3v.py
class Phi3VImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """

    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "bn", "p", 3, "h", "w", dynamic_dims={"p"}
        ),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
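
Because the number of crops p is a dynamic dimension, pixel_values may be passed as a per-image list of (p, 3, h, w) tensors rather than one stacked tensor, while image_sizes carries the padded (height, width) from which hd_feature_transform recovers the crop grid. The tensors below are an illustrative sketch of inputs matching these shapes; the crop counts and image sizes are assumptions, not produced by a real preprocessor.

import torch

# Two images with different numbers of 336x336 crops, so a list is used
# instead of a single stacked tensor.
pixel_values = [
    torch.randn(5, 3, 336, 336),   # 2x2 sub-crop grid + 1 global crop
    torch.randn(7, 3, 336, 336),   # 2x3 sub-crop grid + 1 global crop
]

# Padded (height, width) per image, used to derive h_crop = h // 336 and w_crop = w // 336.
image_sizes = torch.tensor([[672, 672], [672, 1008]])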