Skip to content

vllm.model_executor.models.kanana_v

CustomQwen2VLVE

Bases: Qwen2VisionTransformer

Thin wrapper around the Qwen2-VL vision transformer, used as a vision encoder.

This mirrors the original HF-based vision encoder used in Kanana-V, but reuses vLLM's optimized Qwen2VisionTransformer building blocks.

Source code in vllm/model_executor/models/kanana_v.py
class CustomQwen2VLVE(Qwen2VisionTransformer):
    """Thin wrapper around the Qwen2-VL vision transformer, used as a vision encoder.

    This mirrors the original HF-based vision encoder used in Kanana-V, but
    reuses vLLM's optimized `Qwen2VisionTransformer` building blocks.
    """

    def __init__(self, config: Qwen2VLVisionConfig) -> None:
        super().__init__(
            vision_config=config,
            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            quant_config=None,
            prefix="",
        )

        # Kanana-V uses its own projector/abstractor instead of the Qwen2
        # built-in patch merger, so we drop the merger module to keep the
        # parameter set compatible with the original checkpoint.
        if hasattr(self, "merger"):
            del self.merger

    @classmethod
    def _from_config(cls, config: Qwen2VLVisionConfig) -> "CustomQwen2VLVE":
        """Drop-in replacement for the HF `_from_config` constructor."""
        return cls(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
    ) -> tuple | BaseModelOutput:
        """Run the vision transformer and optionally return intermediate states.

        Unlike the base `Qwen2VisionTransformer`, this wrapper exposes the
        pre-merger patch-level representations and a HF-style `BaseModelOutput`
        so that the existing projector / abstractor code can be reused.

        Args:
            pixel_values: Patchified pixels; moved to the encoder's device and
                dtype before `patch_embed`.
            grid_thw: Per-image `(t, h, w)` patch grid, as a tensor or a
                nested list.
            output_hidden_states: If truthy, also collect the input of every
                block plus the final output (all shaped `(S, D)`).
            return_dict: `False` returns a plain tuple; `True` or `None` (the
                default) return a `BaseModelOutput`.

        Returns:
            `BaseModelOutput(last_hidden_state, hidden_states)`, or the tuple
            `(last_hidden_state[, hidden_states])` when `return_dict` is False.
        """
        # `None` selects the structured output, following the HF convention of
        # returning a ModelOutput by default; only an explicit False yields a
        # tuple (this used to be an assert that made the tuple path dead code).
        if return_dict is None:
            return_dict = True

        # Patchify
        x = pixel_values.to(device=self.device, dtype=self.dtype)
        x = self.patch_embed(x)  # (num_patches, embed_dim)

        # Prepare grid and rotary embeddings – mirror base implementation.
        if isinstance(grid_thw, list):
            grid_thw_list = grid_thw
            grid_thw_np = np.array(grid_thw, dtype=np.int32)
        else:
            grid_thw_list = grid_thw.tolist()
            grid_thw_np = grid_thw.cpu().numpy()

        rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

        # Compute cu_seqlens in numpy then move to device, same as base model.
        # Each of the t frames of an image contributes h * w patches.
        cu_seqlens = np.repeat(
            grid_thw_np[:, 1] * grid_thw_np[:, 2],
            grid_thw_np[:, 0],
        ).cumsum(axis=0, dtype=np.int32)
        cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
        cu_seqlens = torch.from_numpy(cu_seqlens).to(
            self.device,
            non_blocking=True,
        )

        # Shape to (S, B, D) with batch dimension 1 as expected by the blocks.
        x = x.unsqueeze(1)

        # Pre-compute seqlens for attention backend.
        max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

        encoder_states = () if output_hidden_states else None

        for blk in self.blocks:
            if output_hidden_states:
                # Store patch-level states (S, D).
                encoder_states = encoder_states + (x.squeeze(1),)

            x = blk(
                x,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen,
            )

        # Final hidden state at patch level (S, D).
        hidden_states = x.squeeze(1)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
        )

    def get_num_tokens(self) -> int:
        # Not used in the current Kanana-V pipeline, kept for API compatibility.
        return -1

_from_config classmethod

_from_config(
    config: Qwen2VLVisionConfig,
) -> CustomQwen2VLVE

Drop-in replacement for the HF `_from_config` constructor.

Source code in vllm/model_executor/models/kanana_v.py
@classmethod
def _from_config(cls, config: Qwen2VLVisionConfig) -> "CustomQwen2VLVE":
    """Build an encoder from `config`; stands in for HF's `_from_config` hook."""
    encoder = cls(config)
    return encoder

forward

forward(
    pixel_values: Tensor,
    grid_thw: Tensor,
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
) -> tuple | BaseModelOutput

Run the vision transformer and optionally return intermediate states.

Unlike the base Qwen2VisionTransformer, this wrapper exposes the pre-merger patch-level representations and a HF-style BaseModelOutput so that the existing projector / abstractor code can be reused.

Source code in vllm/model_executor/models/kanana_v.py
def forward(
    self,
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor | list[list[int]],
    output_hidden_states: bool | None = None,
    return_dict: bool | None = None,
) -> tuple | BaseModelOutput:
    """Run the vision transformer and optionally return intermediate states.

    Unlike the base `Qwen2VisionTransformer`, this wrapper exposes the
    pre-merger patch-level representations and a HF-style `BaseModelOutput`
    so that the existing projector / abstractor code can be reused.

    Args:
        pixel_values: Patchified pixels; moved to the encoder's device and
            dtype before `patch_embed`.
        grid_thw: Per-image `(t, h, w)` patch grid, as a tensor or a nested
            list.
        output_hidden_states: If truthy, also collect the input of every
            block plus the final output (all shaped `(S, D)`).
        return_dict: `False` returns a plain tuple; `True` or `None` (the
            default) return a `BaseModelOutput`.

    Returns:
        `BaseModelOutput(last_hidden_state, hidden_states)`, or the tuple
        `(last_hidden_state[, hidden_states])` when `return_dict` is False.
    """
    # `None` selects the structured output, following the HF convention of
    # returning a ModelOutput by default; only an explicit False yields a
    # tuple (this used to be an assert that made the tuple path dead code).
    if return_dict is None:
        return_dict = True

    # Patchify
    x = pixel_values.to(device=self.device, dtype=self.dtype)
    x = self.patch_embed(x)  # (num_patches, embed_dim)

    # Prepare grid and rotary embeddings – mirror base implementation.
    if isinstance(grid_thw, list):
        grid_thw_list = grid_thw
        grid_thw_np = np.array(grid_thw, dtype=np.int32)
    else:
        grid_thw_list = grid_thw.tolist()
        grid_thw_np = grid_thw.cpu().numpy()

    rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list)

    # Compute cu_seqlens in numpy then move to device, same as base model.
    # Each of the t frames of an image contributes h * w patches.
    cu_seqlens = np.repeat(
        grid_thw_np[:, 1] * grid_thw_np[:, 2],
        grid_thw_np[:, 0],
    ).cumsum(axis=0, dtype=np.int32)
    cu_seqlens = np.concatenate([np.zeros(1, dtype=np.int32), cu_seqlens])
    cu_seqlens = torch.from_numpy(cu_seqlens).to(
        self.device,
        non_blocking=True,
    )

    # Shape to (S, B, D) with batch dimension 1 as expected by the blocks.
    x = x.unsqueeze(1)

    # Pre-compute seqlens for attention backend.
    max_seqlen = self.compute_attn_mask_seqlen(cu_seqlens)

    encoder_states = () if output_hidden_states else None

    for blk in self.blocks:
        if output_hidden_states:
            # Store patch-level states (S, D).
            encoder_states = encoder_states + (x.squeeze(1),)

        x = blk(
            x,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb_cos=rotary_pos_emb_cos,
            rotary_pos_emb_sin=rotary_pos_emb_sin,
            max_seqlen=max_seqlen,
        )

    # Final hidden state at patch level (S, D).
    hidden_states = x.squeeze(1)
    if output_hidden_states:
        encoder_states = encoder_states + (hidden_states,)

    if not return_dict:
        return tuple(v for v in [hidden_states, encoder_states] if v is not None)
    return BaseModelOutput(
        last_hidden_state=hidden_states,
        hidden_states=encoder_states,
    )

DynamicCAbstractor

Bases: Module

Dynamic C-Abstractor based on RegNet blocks.

Source code in vllm/model_executor/models/kanana_v.py
class DynamicCAbstractor(nn.Module):
    """Dynamic C-Abstractor based on RegNet blocks.

    For each image, an optional learned positional embedding (resampled to
    the image's `(H, W)` patch grid) is added to the patch features, which
    then pass through `self.net` (RegNet stage -> `PatchMerge` -> RegNet
    stage, or only `PatchMerge` when `config.depth` is 0) and the
    `self.readout` MLP.
    """

    def __init__(
        self,
        config: Qwen2VLVisionConfig,
        num_input_tokens: int,
    ) -> None:
        """Build the abstractor.

        Args:
            config: Vision config; read for `merge_size`, `pos_emb_size`,
                `encoder_hidden_size`, `hidden_size`, `output_hidden_size`,
                `depth` and `mlp_depth`.
            num_input_tokens: Number of positional-embedding tokens; `-1`
                means "use `config.pos_emb_size`".
        """
        super().__init__()
        assert hasattr(config, "merge_size"), "merge_size must be provided."
        self.config = config
        self.merge_size = config.merge_size
        self.pos_emb_size = config.pos_emb_size
        if num_input_tokens == -1:
            num_input_tokens = config.pos_emb_size
        self.num_input_tokens = num_input_tokens
        self.pos_emb = build_pos_embeds(
            config, num_input_tokens, config.encoder_hidden_size
        )
        self.build_net()

    def _load_from_state_dict(self, state_dict, *args, **kwargs) -> None:
        """Adapt legacy checkpoints before the default module loading runs.

        Older checkpoints stored one extra leading positional embedding
        (originally for a CLS token); it is dropped so the stored tensor
        matches `self.pos_emb`. An empty `state_dict` is ignored entirely.
        """
        if not state_dict:
            return

        if self.pos_emb is not None:
            # First key containing "abstractor" followed by "pos_emb".
            key_re = re.compile(r"[\w,.]*abstractor[\w,.]*pos_emb")
            pos_emb_key = None
            for key in state_dict:
                if key_re.match(key):
                    pos_emb_key = key
                    break

            assert pos_emb_key is not None
            # update old ckpt compatible with current code
            pos_emb = state_dict[pos_emb_key]
            if pos_emb.size(1) == self.pos_emb.size(1) + 1:
                # remove obsolete first pos emb (for cls token originally)
                state_dict[pos_emb_key] = pos_emb[:, 1:]

        super()._load_from_state_dict(state_dict, *args, **kwargs)

    def build_net(self) -> None:
        """Create `self.net` (spatial stages) and `self.readout` (MLP)."""
        encoder_hidden_size = self.config.encoder_hidden_size
        hidden_size = self.config.hidden_size
        output_hidden_size = self.config.output_hidden_size
        depth = self.config.depth
        mlp_depth = self.config.mlp_depth

        sampler = PatchMerge(merge_size=self.merge_size)

        if depth:
            RegBlock = partial(
                RegStage,
                stride=1,
                dilation=1,
                act_layer=nn.SiLU,
                norm_layer=LayerNorm2d,
            )
            s1 = RegBlock(
                depth,
                encoder_hidden_size,
                hidden_size,
            )
            # PatchMerge multiplies the channel count by merge_size**2.
            s2 = RegBlock(
                depth,
                self.merge_size**2 * hidden_size,
                hidden_size,
            )
            self.net = nn.ModuleList([s1, sampler, s2])
            self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
        else:
            self.net = sampler
            # Without RegNet stages the readout sees PatchMerge's output
            # directly, which has merge_size**2 * encoder_hidden_size channels.
            self.readout = build_mlp(
                mlp_depth,
                self.merge_size**2 * encoder_hidden_size,
                output_hidden_size,
            )

    def forward(
        self,
        flattened_visual_embeds: torch.Tensor,
        grid_thw: torch.Tensor | list[list[int]],
        **unused_kwargs: object,
    ) -> BaseModelOutput:
        """Apply the dynamic abstractor over flattened visual embeddings.

        Args:
            flattened_visual_embeds: Patch features of all images concatenated
                along dim 0, `T * H * W` rows per image.
            grid_thw: Per-image `(T, H, W)` patch grid (tensor or nested list).
            **unused_kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            `BaseModelOutput` whose `last_hidden_state` concatenates the
            per-image readout outputs along dim 0.

        Raises:
            NotImplementedError: If any image has `T != 1` (video input).
        """
        # Read the grid once as Python ints rather than per-element tensors.
        rows = grid_thw.tolist() if isinstance(grid_thw, torch.Tensor) else grid_thw
        grid_sizes = [tuple(int(v) for v in row) for row in rows]
        split_visual_embeds = torch.split(
            flattened_visual_embeds, [t * h * w for t, h, w in grid_sizes]
        )

        # Side of the square grid the stored embedding covers; taken from the
        # parameter itself so it also matches an explicit num_input_tokens.
        pos_grid = None
        if self.pos_emb is not None:
            pos_grid = int(self.pos_emb.shape[1] ** 0.5)

        outputs = []
        for visual_embeds, (T, H, W) in zip(split_visual_embeds, grid_sizes):
            if T != 1:
                raise NotImplementedError("T must be 1. Video is not supported yet.")
            # (T*H*W, D) -> (1, T, H, W, D), then drop the temporal dim.
            reshaped = rearrange(
                visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
            )[:, 0]

            if self.pos_emb is not None:
                # interpolate pos emb and add to visual embeds
                local_pos_emb = resample_abs_pos_embed(
                    posemb=self.pos_emb,
                    old_size=(pos_grid, pos_grid),
                    new_size=(H, W),
                    num_prefix_tokens=0,
                )
                reshaped = reshaped + rearrange(
                    local_pos_emb, "1 (h w) d -> 1 h w d", h=H, w=W
                )

            outputs.append(self._forward(reshaped, input_size=(H, W)))
        return BaseModelOutput(last_hidden_state=torch.cat(outputs, dim=0))

    def _forward(
        self,
        x: torch.Tensor,
        input_size: tuple[int, int],
    ) -> torch.Tensor:
        """Run `self.net` and the readout MLP on one image.

        Args:
            x: `(1, H, W, D)` features of a single image.
            input_size: `(H, W)` patch grid of that image.

        Returns:
            2-D tensor with one row per spatial position left after `self.net`.
        """
        h, w = input_size
        x = rearrange(x, "1 h w d -> 1 d h w", h=h, w=w)
        if self.config.depth:
            x = self.net[0](x)
            x = self.net[1](x)
            x = self.net[2](x)
        else:
            # When depth=0, self.net is a single PatchMerge module
            x = self.net(x)
        x = rearrange(x, "1 d h w -> (h w) d")
        x = self.readout(x)
        return x

forward

forward(
    flattened_visual_embeds: Tensor,
    grid_thw: Tensor,
    **unused_kwargs: object,
) -> BaseModelOutput

Apply the dynamic abstractor over flattened visual embeddings.

Source code in vllm/model_executor/models/kanana_v.py
def forward(
    self,
    flattened_visual_embeds: torch.Tensor,
    grid_thw: torch.Tensor | list[list[int]],
    **unused_kwargs: object,
) -> BaseModelOutput:
    """Apply the dynamic abstractor over flattened visual embeddings.

    Args:
        flattened_visual_embeds: Patch features of all images concatenated
            along dim 0, `T * H * W` rows per image.
        grid_thw: Per-image `(T, H, W)` patch grid (tensor or nested list).
        **unused_kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        `BaseModelOutput` whose `last_hidden_state` concatenates the
        per-image readout outputs along dim 0.

    Raises:
        NotImplementedError: If any image has `T != 1` (video input).
    """
    # Read the grid once as Python ints rather than per-element tensors.
    rows = grid_thw.tolist() if isinstance(grid_thw, torch.Tensor) else grid_thw
    grid_sizes = [tuple(int(v) for v in row) for row in rows]
    split_visual_embeds = torch.split(
        flattened_visual_embeds, [t * h * w for t, h, w in grid_sizes]
    )

    # Side of the square grid the stored embedding covers; taken from the
    # parameter itself so it also matches an explicit num_input_tokens.
    pos_grid = None
    if self.pos_emb is not None:
        pos_grid = int(self.pos_emb.shape[1] ** 0.5)

    outputs = []
    for visual_embeds, (T, H, W) in zip(split_visual_embeds, grid_sizes):
        if T != 1:
            raise NotImplementedError("T must be 1. Video is not supported yet.")
        # (T*H*W, D) -> (1, T, H, W, D), then drop the temporal dim.
        reshaped = rearrange(
            visual_embeds, "(t h w) d -> 1 t h w d", t=T, h=H, w=W
        )[:, 0]

        if self.pos_emb is not None:
            # interpolate pos emb and add to visual embeds
            local_pos_emb = resample_abs_pos_embed(
                posemb=self.pos_emb,
                old_size=(pos_grid, pos_grid),
                new_size=(H, W),
                num_prefix_tokens=0,
            )
            reshaped = reshaped + rearrange(
                local_pos_emb, "1 (h w) d -> 1 h w d", h=H, w=W
            )

        outputs.append(self._forward(reshaped, input_size=(H, W)))
    return BaseModelOutput(last_hidden_state=torch.cat(outputs, dim=0))

KananaVImagePixelInputs

Bases: TensorSchema

Dimensions
  • np: The total number of patches over all images in the batch
  • cps: Number of channels * patch_size * patch_size
  • ni: Number of images
Source code in vllm/model_executor/models/kanana_v.py
class KananaVImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Kanana-V image pathway.

    Dimensions:
        - np: The total number of patches over all images in the batch
        - cps: Number of channels * patch_size * patch_size
        - ni: Number of images
    """

    type: Literal["pixel_values"]

    # Patches of every image in the batch, stacked along the first dimension.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # One (t, h, w) patch-grid row per image.
    vision_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

KananaVMultiModalProcessor

Bases: BaseMultiModalProcessor[KananaVProcessingInfo]

vLLM multimodal processor for Kanana-V (text + image).

Source code in vllm/model_executor/models/kanana_v.py
class KananaVMultiModalProcessor(BaseMultiModalProcessor[KananaVProcessingInfo]):
    """vLLM multimodal processor for Kanana-V (text + image)."""

    @property
    def media_token_id(self) -> int:
        """Placeholder token id for images: the text config's `eos_token_id + 1`."""
        return self.info.get_hf_config().text_config.eos_token_id + 1

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the underlying HF processor on text and image data.

        Each `"<image>"` marker in `prompt` is replaced by the media token.
        When the tokenizer yields exactly one media token per image, each is
        expanded to `prod(image_token_thw)` copies so vLLM's placeholder
        detection sees the full span. `mm_kwargs` and `tok_kwargs` are
        currently not forwarded to the processor or tokenizer.
        """
        # Text-only input is handled as a special case here.
        if not mm_data or not mm_data.get("images", []):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        # Convert every non-PIL image individually, so a list whose first
        # item is PIL but later items are arrays is still normalized.
        image_inputs = [
            image if isinstance(image, Image.Image) else Image.fromarray(image)
            for image in mm_data.get("images", [])
        ]

        image_processor = self.info.get_hf_processor().image_processor
        processor_output = [image_processor(image) for image in image_inputs]
        pixel_values_per_image = [o["pixel_values"] for o in processor_output]
        # list of dict -> dict of list
        meta_per_image = [o["image_meta"] for o in processor_output]
        image_meta = {k: [m[k] for m in meta_per_image] for k in meta_per_image[0]}

        pixel_sizes = [pv.shape[0] for pv in pixel_values_per_image]
        # flattened pixel_values for single example (already includes batch dim)
        pixel_values = torch.concat(pixel_values_per_image, dim=0)

        tokenizer = self.info.get_tokenizer()
        media_token_id = self.media_token_id
        media_token = tokenizer.convert_ids_to_tokens([media_token_id])[0]
        input_ids = torch.tensor(tokenizer.encode(prompt.replace("<image>", media_token)))

        # Ensure HF output is consistent with vLLM prompt-update expectations:
        # if the HF tokenizer emits exactly 1 placeholder token per image, expand
        # it to `T*H*W` placeholder tokens per image so placeholder detection works.
        num_images = len(image_inputs)
        image_token_thw = torch.tensor(image_meta["image_token_thw"])
        per_image_token_counts = [int(n) for n in image_token_thw.prod(dim=1).tolist()]
        expected_total = sum(per_image_token_counts)

        n_placeholders = int((input_ids == media_token_id).sum().item())
        if n_placeholders == num_images and expected_total != num_images:
            expanded: list[int] = []
            img_i = 0
            for tok in input_ids.tolist():
                if tok == media_token_id and img_i < num_images:
                    expanded.extend([media_token_id] * per_image_token_counts[img_i])
                    img_i += 1
                else:
                    expanded.append(tok)
            input_ids = input_ids.new_tensor(expanded)

        combined_outputs = dict(
            # Add batch dimension to input_ids.
            input_ids=input_ids.unsqueeze(0),
            pixel_values=pixel_values,
            vision_grid_thw=torch.tensor(image_meta["vision_grid_thw"]),
            image_token_thw=image_token_thw,
            pixel_sizes=torch.tensor(pixel_sizes),
        )
        return BatchFeature(combined_outputs, tensor_type="pt")

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each `"<image>"` with `prod(image_token_thw)` media tokens."""

        def get_replacement(idx: int) -> Sequence[int]:
            out_item = out_mm_kwargs["image"][idx]
            image_token_thw = out_item["image_token_thw"].data
            assert isinstance(image_token_thw, torch.Tensor)

            num_tokens = int(image_token_thw.prod().item())
            return [self.media_token_id] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target="<image>",
                replacement=get_replacement,
            ),
        ]

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how processor outputs are split per image.

        `pixel_values` is a flat concatenation sliced by `pixel_sizes`; the
        grid tensors carry one row per image.
        """
        pixel_sizes = hf_inputs.get("pixel_sizes", torch.empty(0))

        mm_fields_config = dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", pixel_sizes),
            vision_grid_thw=MultiModalFieldConfig.batched("image"),
            image_token_thw=MultiModalFieldConfig.batched("image"),
        )
        return mm_fields_config

_call_hf_processor

_call_hf_processor(
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature

Run the underlying HF processor on text and image data.

Source code in vllm/model_executor/models/kanana_v.py
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Run the underlying HF processor on text and image data.

    Each `"<image>"` marker in `prompt` is replaced by the media token. When
    the tokenizer yields exactly one media token per image, each is expanded
    to `prod(image_token_thw)` copies so vLLM's placeholder detection sees
    the full span. `mm_kwargs` and `tok_kwargs` are currently not forwarded
    to the processor or tokenizer.
    """
    # Text-only input is handled as a special case here.
    if not mm_data or not mm_data.get("images", []):
        prompt_ids = self.info.get_tokenizer().encode(prompt)
        return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

    # Convert every non-PIL image individually, so a list whose first item
    # is PIL but later items are arrays is still normalized.
    image_inputs = [
        image if isinstance(image, Image.Image) else Image.fromarray(image)
        for image in mm_data.get("images", [])
    ]

    image_processor = self.info.get_hf_processor().image_processor
    processor_output = [image_processor(image) for image in image_inputs]
    pixel_values_per_image = [o["pixel_values"] for o in processor_output]
    # list of dict -> dict of list
    meta_per_image = [o["image_meta"] for o in processor_output]
    image_meta = {k: [m[k] for m in meta_per_image] for k in meta_per_image[0]}

    pixel_sizes = [pv.shape[0] for pv in pixel_values_per_image]
    # flattened pixel_values for single example (already includes batch dim)
    pixel_values = torch.concat(pixel_values_per_image, dim=0)

    tokenizer = self.info.get_tokenizer()
    media_token_id = self.media_token_id
    media_token = tokenizer.convert_ids_to_tokens([media_token_id])[0]
    input_ids = torch.tensor(tokenizer.encode(prompt.replace("<image>", media_token)))

    # Ensure HF output is consistent with vLLM prompt-update expectations:
    # if the HF tokenizer emits exactly 1 placeholder token per image, expand
    # it to `T*H*W` placeholder tokens per image so placeholder detection works.
    num_images = len(image_inputs)
    image_token_thw = torch.tensor(image_meta["image_token_thw"])
    per_image_token_counts = [int(n) for n in image_token_thw.prod(dim=1).tolist()]
    expected_total = sum(per_image_token_counts)

    n_placeholders = int((input_ids == media_token_id).sum().item())
    if n_placeholders == num_images and expected_total != num_images:
        expanded: list[int] = []
        img_i = 0
        for tok in input_ids.tolist():
            if tok == media_token_id and img_i < num_images:
                expanded.extend([media_token_id] * per_image_token_counts[img_i])
                img_i += 1
            else:
                expanded.append(tok)
        input_ids = input_ids.new_tensor(expanded)

    combined_outputs = dict(
        # Add batch dimension to input_ids.
        input_ids=input_ids.unsqueeze(0),
        pixel_values=pixel_values,
        vision_grid_thw=torch.tensor(image_meta["vision_grid_thw"]),
        image_token_thw=image_token_thw,
        pixel_sizes=torch.tensor(pixel_sizes),
    )
    return BatchFeature(combined_outputs, tensor_type="pt")

PatchMerge

Bases: Module

Merge neighboring patches spatially to reduce resolution.

Source code in vllm/model_executor/models/kanana_v.py
class PatchMerge(nn.Module):
    """Merge neighboring patches spatially to reduce resolution.

    Every non-overlapping `merge_size x merge_size` window is folded into
    the channel dimension, equivalent to the einops pattern
    `B D (H h2) (W w2) -> B (D h2 w2) H W`.
    """

    def __init__(self, merge_size: int) -> None:
        super().__init__()
        self.merge_size = merge_size

    def forward(
        self,
        x: torch.Tensor,
        channel_last: bool = False,
    ) -> torch.Tensor:
        """Merge patches by `merge_size x merge_size`.

        Args:
            x: `(B, D, H, W)` tensor, or `(B, H, W, D)` when `channel_last`.
            channel_last: Whether `x` stores channels last. The output is
                always channel-first.

        Returns:
            `(B, D * m * m, H // m, W // m)` tensor (`m = merge_size`) where
            output channel `d * m * m + i * m + j` at `(h, w)` holds input
            channel `d` at `(h * m + i, w * m + j)`.

        Raises:
            ValueError: If `x` is not 4-D or `H`/`W` is not divisible by
                `merge_size`.
        """
        if x.dim() != 4:
            raise ValueError(
                f"PatchMerge expects a 4-D tensor, got shape {tuple(x.shape)}"
            )
        if channel_last:
            x = x.permute(0, 3, 1, 2)
        batch, dim, height, width = x.shape
        m = self.merge_size
        if height % m or width % m:
            raise ValueError(
                f"Spatial size ({height}, {width}) is not divisible by "
                f"merge_size={m}"
            )
        out_h, out_w = height // m, width // m
        # (B, D, H*m, W*m) -> (B, D, H, m, W, m) -> (B, D, m, m, H, W)
        x = x.reshape(batch, dim, out_h, m, out_w, m).permute(0, 1, 3, 5, 2, 4)
        return x.reshape(batch, dim * m * m, out_h, out_w)

forward

forward(x: Tensor, channel_last: bool = False) -> Tensor

Merge patches in non-overlapping `merge_size` × `merge_size` windows.

Source code in vllm/model_executor/models/kanana_v.py
def forward(
    self,
    x: torch.Tensor,
    channel_last: bool = False,
) -> torch.Tensor:
    """Merge patches by `merge_size x merge_size`.

    Equivalent to the einops pattern `B D (H h2) (W w2) -> B (D h2 w2) H W`.

    Args:
        x: `(B, D, H, W)` tensor, or `(B, H, W, D)` when `channel_last`.
        channel_last: Whether `x` stores channels last. The output is always
            channel-first.

    Returns:
        `(B, D * m * m, H // m, W // m)` tensor (`m = self.merge_size`) where
        output channel `d * m * m + i * m + j` at `(h, w)` holds input channel
        `d` at `(h * m + i, w * m + j)`.

    Raises:
        ValueError: If `x` is not 4-D or `H`/`W` is not divisible by
            `merge_size`.
    """
    if x.dim() != 4:
        raise ValueError(
            f"PatchMerge expects a 4-D tensor, got shape {tuple(x.shape)}"
        )
    if channel_last:
        x = x.permute(0, 3, 1, 2)
    batch, dim, height, width = x.shape
    m = self.merge_size
    if height % m or width % m:
        raise ValueError(
            f"Spatial size ({height}, {width}) is not divisible by "
            f"merge_size={m}"
        )
    out_h, out_w = height // m, width // m
    # (B, D, H*m, W*m) -> (B, D, H, m, W, m) -> (B, D, m, m, H, W)
    x = x.reshape(batch, dim, out_h, m, out_w, m).permute(0, 1, 3, 5, 2, 4)
    return x.reshape(batch, dim * m * m, out_h, out_w)

build_mlp

build_mlp(
    depth: int, hidden_size: int, output_hidden_size: int
) -> Sequential

Simple SiLU-activated MLP used as a projector readout.

Source code in vllm/model_executor/models/kanana_v.py
def build_mlp(
    depth: int,
    hidden_size: int,
    output_hidden_size: int,
) -> nn.Sequential:
    """Build the SiLU-activated MLP used as the projector readout.

    The first `Linear` maps `hidden_size -> output_hidden_size`; each of the
    remaining `depth - 1` stages appends `SiLU` followed by an
    `output_hidden_size -> output_hidden_size` `Linear`. A `depth` of 1 or
    less yields just the input projection.
    """
    modules: list[nn.Module] = [nn.Linear(hidden_size, output_hidden_size)]
    extra_stages = max(depth - 1, 0)
    for _stage in range(extra_stages):
        modules += [nn.SiLU(), nn.Linear(output_hidden_size, output_hidden_size)]
    return nn.Sequential(*modules)

build_pos_embeds

build_pos_embeds(
    config: Qwen2VLVisionConfig,
    num_input_tokens: int,
    vision_hidden_size: int,
) -> Parameter | None

Build positional embeddings for the visual encoder output.

Source code in vllm/model_executor/models/kanana_v.py
def build_pos_embeds(
    config: Qwen2VLVisionConfig,
    num_input_tokens: int,
    vision_hidden_size: int,
) -> nn.Parameter | None:
    """Create the learnable positional-embedding table, if the config enables it.

    Returns `None` when `config.pos_emb` is falsy; otherwise a parameter of
    shape `(1, num_input_tokens, vision_hidden_size)` initialised from a
    truncated normal with mean 0 and std 0.02.
    """
    if not config.pos_emb:
        return None

    table = nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
    nn.init.trunc_normal_(table, mean=0.0, std=0.02)
    return table