vllm.model_executor.models.glm4v

Inference-only GLM-4V model compatible with THUDM weights.

EVA2CLIPGLU

Bases: Module

Source code in vllm/model_executor/models/glm4v.py
class EVA2CLIPGLU(nn.Module):
    def __init__(
        self,
        config,
        in_features,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        The original implementation is the same as:
        ```python
        self.dense_h_to_4h = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.gate_proj = ColumnParallelLinear(
            config.hidden_size,
            config.ffn_hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        ```
        ```
        gate_proj_output, _ = self.gate_proj(x)
        dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
        x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
        ```

        We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
        ```
        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        ```
        ```
        x, _ = self.merged_proj(x)
        ```
        """
        super().__init__()
        self.linear_proj = ReplicatedLinear(
            in_features,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.act1 = nn.GELU()
        self.act2 = SiluAndMul()

        self.merged_proj = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.merged_proj",
        )

        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x):
        x, _ = self.linear_proj(x)
        x = self.act1(self.norm1(x))
        x, _ = self.merged_proj(x)
        x = self.act2(x)
        x, _ = self.dense_4h_to_h(x)
        return x

__init__

__init__(
    config,
    in_features,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
)

The original implementation is the same as:

self.dense_h_to_4h = ColumnParallelLinear(
    config.hidden_size,
    config.ffn_hidden_size,
    bias=False,
    quant_config=quant_config,
)

self.gate_proj = ColumnParallelLinear(
    config.hidden_size,
    config.ffn_hidden_size,
    bias=False,
    quant_config=quant_config,
)

gate_proj_output, _ = self.gate_proj(x)
dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)

We merge two ColumnParallelLinear into one MergedColumnParallelLinear:

self.merged_proj = MergedColumnParallelLinear(
    config.hidden_size,
    [config.ffn_hidden_size] * 2,
    bias=False,
    quant_config=quant_config,
)

x, _ = self.merged_proj(x)
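
As a sanity check on this refactoring, here is a minimal standalone sketch (plain `nn.Linear` stands in for the tensor-parallel layers, and all sizes are made up) showing that stacking the two weight matrices along the output dimension reproduces the concatenated output, which `SiluAndMul` then consumes as gate and up halves:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 8, 16  # made-up sizes

# Two separate projections, as in the original implementation.
gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
dense_h_to_4h = nn.Linear(hidden_size, ffn_hidden_size, bias=False)

# One merged projection whose weight is the two weights stacked on the output dim.
merged_proj = nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)
with torch.no_grad():
    merged_proj.weight.copy_(
        torch.cat([gate_proj.weight, dense_h_to_4h.weight], dim=0)
    )

x = torch.randn(2, hidden_size)
separate = torch.cat([gate_proj(x), dense_h_to_4h(x)], dim=-1)
merged = merged_proj(x)
assert torch.allclose(separate, merged, atol=1e-6)

# SiluAndMul on the merged output is equivalent to silu(gate) * up on the halves.
gate, up = merged.chunk(2, dim=-1)
activated = F.silu(gate) * up
```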

Source code in vllm/model_executor/models/glm4v.py
def __init__(
    self,
    config,
    in_features,
    quant_config: QuantizationConfig | None = None,
    prefix: str = "",
):
    """
    The original implementation is the same as:
    ```python
    self.dense_h_to_4h = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config,
    )

    self.gate_proj = ColumnParallelLinear(
        config.hidden_size,
        config.ffn_hidden_size,
        bias=False,
        quant_config=quant_config,
    )
    ```
    ```
    gate_proj_output, _ = self.gate_proj(x)
    dense_h_to_4h_output, _ = self.dense_h_to_4h(x)
    x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1)
    ```

    We merge two ColumnParallelLinear into one MergedColumnParallelLinear:
    ```
    self.merged_proj = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=False,
        quant_config=quant_config,
    )
    ```
    ```
    x, _ = self.merged_proj(x)
    ```
    """
    super().__init__()
    self.linear_proj = ReplicatedLinear(
        in_features,
        config.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.linear_proj",
    )
    self.norm1 = nn.LayerNorm(config.hidden_size)
    self.act1 = nn.GELU()
    self.act2 = SiluAndMul()

    self.merged_proj = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.merged_proj",
    )

    self.dense_4h_to_h = RowParallelLinear(
        config.ffn_hidden_size,
        config.hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_4h_to_h",
    )

EVA2CLIPModel

Bases: Module

Source code in vllm/model_executor/models/glm4v.py
class EVA2CLIPModel(nn.Module):
    def __init__(
        self,
        config,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_config = Namespace(**config.vision_config)
        self.patch_embedding = EVA2CLIPPatchEmbedding(vision_config)
        self.transformer = EVA2CLIPTransformer(
            vision_config, quant_config=quant_config, prefix=f"{prefix}.transformer"
        )
        self.linear_proj = EVA2CLIPGLU(
            config,
            in_features=config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.linear_proj",
        )
        self.conv = Conv2dLayer(
            in_channels=vision_config.hidden_size,
            out_channels=config.hidden_size,
            kernel_size=2,
            stride=2,
        )
        self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.scaling_factor = vision_config.scaling_factor

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        x = self.patch_embedding(images)
        x = self.transformer(x)
        x = x[:, 1:]

        b, s, h = x.shape
        grid_size = int(s**0.5)
        x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
        x = self.conv(x)

        x = x.flatten(2).transpose(1, 2)
        x = self.linear_proj(x)
        boi = self.boi.expand(x.shape[0], -1, -1)
        eoi = self.eoi.expand(x.shape[0], -1, -1)
        x = torch.cat((boi, x, eoi), dim=1)
        x = x / self.scaling_factor
        return x

forward

forward(images: Tensor) -> Tensor

Parameters:

images : torch.Tensor
    Input image tensor with shape (B, C, H, W)

Returns:

torch.Tensor
    Transformed tensor with shape (B, L, D)

Source code in vllm/model_executor/models/glm4v.py
def forward(self, images: torch.Tensor) -> torch.Tensor:
    """
    Parameters:
    images : torch.Tensor
        Input image tensor with shape (B, C, H, W)

    Returns:
    torch.Tensor
        Transformed tensor with shape (B, L, D)
    """
    x = self.patch_embedding(images)
    x = self.transformer(x)
    x = x[:, 1:]

    b, s, h = x.shape
    grid_size = int(s**0.5)
    x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2)
    x = self.conv(x)

    x = x.flatten(2).transpose(1, 2)
    x = self.linear_proj(x)
    boi = self.boi.expand(x.shape[0], -1, -1)
    eoi = self.eoi.expand(x.shape[0], -1, -1)
    x = torch.cat((boi, x, eoi), dim=1)
    x = x / self.scaling_factor
    return x
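
For orientation, a back-of-the-envelope token count: the forward pass above drops the CLS token, reshapes the remaining tokens to a square grid, halves each side with the kernel-2, stride-2 conv, and wraps the projected result in `boi`/`eoi`. The sizes below are assumptions for illustration only; the real values come from `config.vision_config`.

```python
image_size, patch_size = 1120, 14        # assumed values, not read from any checkpoint

grid = image_size // patch_size          # 80 patches per side after patch_embedding
seq_len = grid * grid                    # 6400 tokens once the CLS token is sliced off
downsampled = grid // 2                  # kernel-2, stride-2 conv -> 40 per side
num_image_tokens = downsampled**2 + 2    # plus boi and eoi -> 1602
print(num_image_tokens)
```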

EVA2CLIPPatchEmbedding

Bases: Module

Source code in vllm/model_executor/models/glm4v.py
class EVA2CLIPPatchEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.proj = Conv2dLayer(
            config.in_channels,
            config.hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )
        self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size))
        self.position_embedding = nn.Embedding(config.num_positions, config.hidden_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Parameters:
        images : torch.Tensor
            Input image tensor with shape (B, C, H, W)

        Returns:
        torch.Tensor
            Transformed tensor with shape (B, L, D)
        """
        images = images.to(device=self.proj.weight.device, dtype=self.proj.weight.dtype)
        x = self.proj(images)
        x = x.flatten(2).transpose(1, 2)
        cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x += self.position_embedding.weight.unsqueeze(0)
        return x

forward

forward(images: Tensor) -> Tensor

Parameters:

images : torch.Tensor
    Input image tensor with shape (B, C, H, W)

Returns:

torch.Tensor
    Transformed tensor with shape (B, L, D)

Source code in vllm/model_executor/models/glm4v.py
def forward(self, images: torch.Tensor) -> torch.Tensor:
    """
    Parameters:
    images : torch.Tensor
        Input image tensor with shape (B, C, H, W)

    Returns:
    torch.Tensor
        Transformed tensor with shape (B, L, D)
    """
    images = images.to(device=self.proj.weight.device, dtype=self.proj.weight.dtype)
    x = self.proj(images)
    x = x.flatten(2).transpose(1, 2)
    cls_token = self.cls_embedding.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_token, x), dim=1)
    x += self.position_embedding.weight.unsqueeze(0)
    return x
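
The same patchify, prepend-CLS, add-positions flow can be sketched in isolation with plain PyTorch modules and made-up sizes (not the real GLM-4V vision config):

```python
import torch
import torch.nn as nn

B, C, H, W = 2, 3, 224, 224              # made-up batch and image size
patch_size, hidden = 14, 64              # made-up patch size and hidden dim
num_positions = (H // patch_size) * (W // patch_size) + 1   # patches + CLS = 257

proj = nn.Conv2d(C, hidden, kernel_size=patch_size, stride=patch_size)
cls_embedding = nn.Parameter(torch.zeros(1, hidden))
position_embedding = nn.Embedding(num_positions, hidden)

x = proj(torch.randn(B, C, H, W))             # (B, hidden, H/ps, W/ps)
x = x.flatten(2).transpose(1, 2)              # (B, num_patches, hidden)
cls_token = cls_embedding.expand(B, -1, -1)   # (B, 1, hidden)
x = torch.cat((cls_token, x), dim=1)          # (B, num_patches + 1, hidden)
x = x + position_embedding.weight.unsqueeze(0)
print(x.shape)                                # torch.Size([2, 257, 64])
```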

GLM4VForCausalLM

Bases: ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE

Source code in vllm/model_executor/models/glm4v.py
@MULTIMODAL_REGISTRY.register_processor(
    GLM4VMultiModalProcessor,
    info=GLM4VProcessingInfo,
    dummy_inputs=GLM4VDummyInputsBuilder,
)
class GLM4VForCausalLM(
    ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
        "merged_proj": ["gate_proj", "dense_h_to_4h"],
    }

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="transformer.encoder",
            connector="transformer.vision.linear_proj",
            tower_model="transformer.vision.transformer",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<|begin_of_image|><|endoftext|><|end_of_image|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        transformer_type: type[GLM4VModel] = GLM4VModel,
    ) -> None:
        with self._mark_composite_model(
            vllm_config,
            language_targets=GLMTransformer,
            tower_targets={"image": EVA2CLIPModel},
        ):
            super().__init__(
                vllm_config=vllm_config,
                prefix=prefix,
                transformer_type=transformer_type,
            )

        self.transformer: GLM4VModel

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> GLMVImagePixelInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is not None:
            expected_h = expected_w = self.config.vision_config["image_size"]
            return GLMVImagePixelInputs(
                type="pixel_values",
                data=pixel_values,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        return None

    def _process_image_input(self, image_input: GLMVImagePixelInputs) -> torch.Tensor:
        pixel_values = image_input["data"].to(dtype=self.config.dtype)

        return self.transformer.vision(pixel_values)

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int]]:
        hf_config = self.config
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, t, h // spatial_merge_size, w // spatial_merge_size
            else:
                # glm4v only supports image modality
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        llm_pos_ids_list: list = []
        st = 0
        for (
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
        ) in self.iter_mm_grid_thw(mm_features):
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )
            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(
                3, -1
            )
            llm_pos_ids_list.append(grid_indices + text_len + st_idx)
            # EVA2CLIPModel has embeddings for boi and eoi tokens as well
            st = offset + 1 + llm_grid_t * llm_grid_h * llm_grid_w + 1

        if st < len(input_tokens):
            text_len = len(input_tokens) - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return torch.from_numpy(llm_positions), mrope_position_delta

    embed_input_ids = SupportsMultiModal.embed_input_ids

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states
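
The core of `get_mrope_input_positions` is `np.indices`, which emits one coordinate row per axis (t, h, w) for every image token. A tiny illustration with a made-up 1x2x3 merged grid:

```python
import numpy as np

llm_grid_t, llm_grid_h, llm_grid_w = 1, 2, 3   # made-up grid sizes

grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w)).reshape(3, -1)
print(grid_indices)
# [[0 0 0 0 0 0]    <- temporal index of each of the 6 image tokens
#  [0 0 0 1 1 1]    <- row (h) index
#  [0 1 2 0 1 2]]   <- column (w) index

# In get_mrope_input_positions these rows are then shifted by text_len + st_idx
# so the image positions continue from the preceding text block.
```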

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/glm4v.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="transformer.encoder",
        connector="transformer.vision.linear_proj",
        tower_model="transformer.vision.transformer",
    )

GLM4VProcessor

This model doesn't define its own HF processor, so we implement our own here.

Source code in vllm/model_executor/models/glm4v.py
class GLM4VProcessor:
    """
    This model doesn't define its own HF processor,
    so we implement our own one here.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        tokenizer: PreTrainedTokenizer,
    ) -> None:
        super().__init__()

        self.config = config
        self.tokenizer = tokenizer

        vision_config = config.vision_config
        image_size = vision_config["image_size"]

        self.image_transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.48145466, 0.4578275, 0.40821073),
                    std=(0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | list[ImageInput] | None = None,
        return_tensors: str | TensorType | None = None,
    ) -> BatchFeature:
        if text is None:
            text = []
        if not isinstance(text, list):
            text = [text]
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        text_inputs = self.tokenizer(text)

        if len(images) == 0:
            image_inputs = {}
        else:
            pixel_values = [self.image_transform(image) for image in images]
            image_inputs = {"pixel_values": torch.stack(pixel_values)}

        return BatchFeature(
            {
                **text_inputs,
                **image_inputs,
            },
            tensor_type=return_tensors,
        )
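
The preprocessing pipeline can be exercised on its own. The sketch below uses the same CLIP normalization constants with an assumed `image_size` (the real value comes from `config.vision_config["image_size"]`) and a blank placeholder image:

```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_size = 1120  # assumed; taken from config.vision_config["image_size"] in practice

image_transform = transforms.Compose(
    [
        transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

image = Image.new("RGB", (640, 480))                  # placeholder image
pixel_values = torch.stack([image_transform(image)])  # batch of one
print(pixel_values.shape)                             # torch.Size([1, 3, 1120, 1120])
```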

GLMVImagePixelInputs

Bases: TensorSchema

Dimensions
  • b: Batch size
  • c: Number of channels (3)
  • h: Height of image
  • w: Width of image

Source code in vllm/model_executor/models/glm4v.py
class GLMVImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - b: Batch size
        - c: Number of channels (3)
        - h: Height of image
        - w: Width of image
    """

    type: Literal["pixel_values"] = "pixel_values"
    data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")]
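
A hypothetical construction, mirroring `_parse_and_validate_image_input` above, with a dummy tensor and an assumed 1120x1120 resolution; the `TensorShape("b", 3, "h", "w")` annotation is intended to check that `data` is 4-D with 3 channels and that the `h`/`w` dimensions match the resolved bindings:

```python
import torch

from vllm.model_executor.models.glm4v import GLMVImagePixelInputs

expected_h = expected_w = 1120                     # assumed image size
dummy = torch.zeros(1, 3, expected_h, expected_w)  # (b, c, h, w)

image_input = GLMVImagePixelInputs(
    type="pixel_values",
    data=dummy,
    resolve_bindings={"h": expected_h, "w": expected_w},
)
```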