Skip to content

vllm.model_executor.models.qwen2_5_vl

Inference-only Qwen2.5-VL model compatible with HuggingFace weights.

Qwen2_5_VLForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsQuant, SupportsEagle3, SupportsMultiModalPruning, SupportsMRoPE

Source code in vllm/model_executor/models/qwen2_5_vl.py
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder,
)
class Qwen2_5_VLForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsLoRA,
    SupportsPP,
    SupportsQuant,
    SupportsEagle3,
    SupportsMultiModalPruning,
    SupportsMRoPE,
):
    """Qwen2.5-VL model for conditional generation.

    Pairs a vision tower (``Qwen2_5_VisionTransformer``, stored as
    ``self.visual``) with a language model created through
    ``init_vllm_registered_model`` for the ``Qwen2ForCausalLM``
    architecture (stored as ``self.language_model``).
    ``embed_multimodal`` turns image/video inputs into per-item
    embeddings. ``forward`` runs only the language model. Input processing
    is registered through ``Qwen2_5_VLMultiModalProcessor``.
    """

    # Fused module name -> names of the checkpoint modules it packs together.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    # To ensure correct weight loading and mapping.
    # NOTE(review): the catch-all "model." prefix comes after the more
    # specific "model.language_model." / "model.visual." entries. This
    # presumably relies on WeightsMapper trying prefixes in insertion order;
    # confirm.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    # Declares that the vision encoder can run in data-parallel mode
    # (see ``use_data_parallel`` in ``__init__``). How vLLM consumes this
    # flag is not visible here.
    supports_encoder_tp_data = True

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int, float]]:
        """
        Iterate over multimodal features and yield grid information.

        Features are visited in ascending ``mm_position.offset`` order. The
        height and width are divided (floor) by the vision config's
        ``spatial_merge_size``. Images must have exactly one temporal frame
        and always get ``t_factor == 1.0``. Videos get
        ``t_factor = second_per_grid_ts * tokens_per_second``. Each factor
        defaults to 1.0 when absent.

        Args:
            mm_features: List of multimodal feature specifications

        Yields:
            Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image

        Raises:
            ValueError: If a feature's modality is neither "image" nor "video".
        """
        spatial_merge_size = self.config.vision_config.spatial_merge_size
        tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
            elif mm_feature.modality == "video":
                t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
                second_per_grid_ts = 1.0
                # NOTE(review): this is a truthiness check on the stored
                # element, not an ``is not None`` check. It presumably relies
                # on the element object being truthy whenever present.
                if mm_feature.data.get("second_per_grid_ts", None):
                    second_per_grid_ts = mm_feature.data[
                        "second_per_grid_ts"
                    ].data.item()
                t_factor = second_per_grid_ts * tokens_per_second
                yield (
                    offset,
                    t,
                    h // spatial_merge_size,
                    w // spatial_merge_size,
                    t_factor,
                )
            else:
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """Compute M-RoPE position ids for a whole prompt.

        The result has three rows (temporal, height, width). Media items are
        walked in offset order via ``iter_mm_grid_thw``:

        * The text span before each media item gets ``arange(text_len)`` on
          all three rows. It starts one past the previous maximum position,
          or at 0 for the first span.
        * Each media item gets its ``(t, h, w)`` grid indices from
          ``np.indices``. When ``t_factor != 1.0`` the temporal row is first
          scaled by ``t_factor`` and truncated to int64. The grid is then
          shifted by the same start plus the preceding text length.

        Trailing text after the last media item is numbered like a text span.

        Args:
            input_tokens: All token ids of the prompt.
            mm_features: Multimodal feature specs for the prompt.

        Returns:
            A ``(3, N)`` position tensor and the delta
            ``max_position + 1 - len(input_tokens)``. N equals
            ``len(input_tokens)`` when each media item occupies exactly
            ``grid_t * grid_h * grid_w`` placeholder tokens.
        """
        llm_pos_ids_list: list = []
        st = 0

        for (
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
            t_factor,
        ) in self.iter_mm_grid_thw(mm_features):
            # Text tokens between the previous media item (or prompt start)
            # and this one.
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # 3D grid positions for the media item; only the temporal axis is
            # scaled.
            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
            if t_factor != 1.0:
                grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
            # Skip past the media placeholder tokens (t * h * w of them).
            st = offset + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        # NOTE(review): with no media and an empty input_tokens the list is
        # still empty here, and np.concatenate raises ValueError.
        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an image or video item.

        ``modality`` is matched by prefix. ``i`` is not used.

        Raises:
            ValueError: For any modality not starting with "image" or "video".
        """
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"

        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the vision tower and the Qwen2 language model.

        Args:
            vllm_config: Global vLLM config. The HF config and multimodal
                config are read from ``vllm_config.model_config``.
            prefix: Parameter-name prefix for the submodules.
        """
        super().__init__()
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # "data" mode routes vision encoding through
        # run_dp_sharded_mrope_vision_model (see _process_*_input).
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.vllm_config = vllm_config
        self.multimodal_config = multimodal_config
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # NOTE(review): self.quant_config is not assigned in this method; it
        # is presumably provided by the SupportsQuant mixin. Confirm.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen2_5_VisionTransformer(
                vision_config=config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        # Reuse the language model's factory for empty intermediate tensors.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Store ``layers`` as the inner model's ``aux_hidden_state_layers``."""
        self.language_model.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return default EAGLE-3 auxiliary layer indices.

        These are ``(2, num_layers // 2, num_layers - 3)``.
        """
        num_layers = len(self.language_model.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Build typed image inputs from kwargs.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns None
        when neither is present. ``image_grid_thw`` is attached to either
        form.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        image_grid_thw = kwargs.pop("image_grid_thw", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
            )

        if image_embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=image_embeds,
                image_grid_thw=image_grid_thw,
            )

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Build typed video inputs from kwargs.

        ``pixel_values_videos`` takes precedence over ``video_embeds``.
        Returns None when neither is present. ``video_grid_thw`` and
        ``second_per_grid_ts`` (possibly None) are attached to either form.
        """
        pixel_values_videos = kwargs.pop("pixel_values_videos", None)
        video_embeds = kwargs.pop("video_embeds", None)
        video_grid_thw = kwargs.pop("video_grid_thw", None)
        second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)

        if pixel_values_videos is None and video_embeds is None:
            return None

        if pixel_values_videos is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixel_values_videos,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
            )

        if video_embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=video_embeds,
                video_grid_thw=video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,
            )

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Produce one embedding tensor per image.

        Precomputed ``image_embeds`` are cast to the vision tower's dtype.
        Otherwise ``pixel_values`` are encoded by ``self.visual`` inside
        ``set_forward_context(None, self.vllm_config)``. The concatenated
        output is split into ``t * h * w // spatial_merge_size**2`` rows
        per image.

        In encoder data-parallel mode the result of
        ``run_dp_sharded_mrope_vision_model`` is returned directly, without
        the split below. It is presumably already per-item; TODO confirm.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            image_embeds = image_input["image_embeds"].type(self.visual.dtype)
        else:
            pixel_values = image_input["pixel_values"]
            with set_forward_context(None, self.vllm_config):
                if self.use_data_parallel:
                    return run_dp_sharded_mrope_vision_model(
                        self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
                    )
                else:
                    image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

        # Split concatenated embeddings for each image item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return image_embeds.split(sizes)

    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Append mrope positions to each image's embeddings.
        Images are not pruned here. The positions are attached so that
        correct mrope positions can be recovered after video pruning.

        Args:
            image_embeds_split: Tuple of image embeddings for
                each image item.
            image_input: Image input data.

        Returns:
            Tuple of image embeddings for each image item.
            Resulting embeddings will have extra 4 channels for
            computed mrope positions.
        """
        merge_size = self.visual.spatial_merge_size
        grid_thw = image_input["image_grid_thw"]
        grid_thw_list = grid_thw.tolist()
        image_embeds_out = []
        for emb, size in zip(image_embeds_split, grid_thw_list):
            positions = compute_mrope_for_media(size, merge_size).to(emb.device)
            emb = torch.cat([emb, positions], dim=1)
            image_embeds_out.append(emb)
        image_embeds_split = image_embeds_out
        return tuple(image_embeds_split)

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Produce one embedding tensor per video.

        This mirrors ``_process_image_input``, using ``video_embeds`` /
        ``pixel_values_videos`` and ``video_grid_thw``. In encoder
        data-parallel mode the sharded result is returned without the
        per-item split.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            video_embeds = video_input["video_embeds"].type(self.visual.dtype)
        else:
            pixel_values_videos = video_input["pixel_values_videos"]
            with set_forward_context(None, self.vllm_config):
                if self.use_data_parallel:
                    return run_dp_sharded_mrope_vision_model(
                        self.visual,
                        pixel_values_videos,
                        grid_thw_list,
                        rope_type="rope_3d",
                    )
                else:
                    video_embeds = self.visual(
                        pixel_values_videos, grid_thw=grid_thw_list
                    )

        # Split concatenated embeddings for each video item.
        merge_size = self.visual.spatial_merge_size
        sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist()
        return video_embeds.split(sizes)

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Prunes video embeddings via Efficient Video Sampling (EVS)
        and then appends mrope positions to each retained embedding.

        Args:
            video_embeds_split: Tuple of video embeddings for each video item.
            video_input: Video input data.

        Returns:
            Tuple of video embeddings for each video item.
            Resulting embeddings will have extra 4 channels for
            computed mrope positions.

        Raises:
            ValueError: If ``second_per_grid_ts`` is missing from the input.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_thw_list = grid_thw.tolist()
        merge_size = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        second_per_grid_ts = video_input.get("second_per_grid_ts")
        if second_per_grid_ts is None:
            raise ValueError(
                "second_per_grid_ts is required when video_pruning_rate > 0 "
                "is enabled for video inputs, including the video_embeds path."
            )
        # NOTE(review): ``.long()`` truncates fractional values (e.g. 0.5 -> 0)
        # before they reach compute_mrope_for_media. iter_mm_grid_thw uses
        # the untruncated float, so pruned and unpruned paths may disagree
        # for non-integer second_per_grid_ts. Confirm this is intended.
        second_per_grid_ts = second_per_grid_ts.long()
        tokens_per_second = self.config.vision_config.tokens_per_second

        video_embeds_out = []
        for emb, size, video_second_per_grid_t in zip(
            video_embeds_split, grid_thw_list, second_per_grid_ts
        ):
            # For each video, we compute retention mask using EVS
            retention_mask = compute_retention_mask(
                emb,
                size,
                spatial_merge_size=self.visual.spatial_merge_size,
                q=self.video_pruning_rate,
            )
            positions = compute_mrope_for_media(
                size,
                merge_size,
                tokens_per_second=tokens_per_second,
                video_second_per_grid=video_second_per_grid_t.item(),
            ).to(emb.device)

            # Keep only retained rows, in both embeddings and positions.
            emb = emb[retention_mask]
            positions = positions[retention_mask]
            emb = torch.cat([emb, positions], dim=1)
            video_embeds_out.append(emb)
        return tuple(video_embeds_out)

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Update part of input mrope positions (starting with
        num_computed_tokens index). Original mrope_positions are computed
        for unpruned sequence and become incorrect once pruning occurs,
        so once we prune media tokens we should reflect this in the
        mrope_positions before we feed it to LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt (Containing
                entire sequence).
            multimodal_embeddings: Tuple of multimodal embeddings.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
                mrope_position_delta).
        """
        image_token_id = self.config.image_token_id
        video_token_id = self.config.video_token_id
        vision_start_token_id = self.config.vision_start_token_id

        # Device
        device = (
            multimodal_embeddings[0].device
            if len(multimodal_embeddings)
            else mrope_positions.device
        )

        # Tensors
        input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long)

        # The last 4 channels of each embedding carry the positions appended by
        # _postprocess_*_embeds_evs. Strip them from the embeddings and
        # transpose them to (4, num_tokens) longs.
        mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings]
        mm_embeddings_pos = [
            mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings
        ]

        positions, mrope_positions_delta = recompute_mrope_positions(
            input_ids_t,
            mm_embeddings_pos,
            mrope_positions,
            num_computed_tokens,
            vision_start_token_id,
            image_token_id,
            video_token_id,
        )

        return tuple(mm_embeddings_out), positions, mrope_positions_delta

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Group parsed inputs by modality, keyed "image" / "video".

        Each parser receives its own copy of ``kwargs`` (``**kwargs``
        unpacking), so its ``pop`` calls do not affect this loop.
        """
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
        return mm_input_by_modality

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video inputs in ``kwargs`` into per-item embeddings.

        Returns an empty list when no multimodal input is present. Otherwise
        returns a tuple with one tensor per item, images and videos in kwargs
        order. When multimodal pruning is enabled, each tensor carries 4
        extra position channels (see ``_postprocess_*_embeds_evs``).
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        # The result multimodal_embeddings is tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image or video).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                image_embeddings = self._process_image_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    image_embeddings = self._postprocess_image_embeds_evs(
                        image_embeddings, multimodal_input
                    )
                multimodal_embeddings += tuple(image_embeddings)
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                if self.is_multimodal_pruning_enabled:
                    video_embeddings = self._postprocess_video_embeds_evs(
                        video_embeddings, multimodal_input
                    )
                multimodal_embeddings += tuple(video_embeddings)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2.5-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. **NOTE**: If mrope is enabled (default setting for
                Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Tensors forwarded to the language model.
                When given, ``inputs_embeds`` is ignored.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted but not used here.
        """

        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load ``(name, tensor)`` weights through AutoWeightsLoader.

        HF names are renamed with ``hf_to_vllm_mapper``. Returns the set
        produced by the loader, presumably the loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Return ``num_image_tokens * spatial_merge_size**2``.

        This presumably maps post-merge token counts to vision-encoder patch
        counts.
        """
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size

        return num_image_tokens * merge_size**2

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Return ``num_vision_tokens // spatial_merge_size**2`` (floor).

        This is the inverse of ``get_num_mm_encoder_tokens``.
        """
        hf_config = self.config
        vision_config = hf_config.vision_config
        merge_size = vision_config.spatial_merge_size
        return num_vision_tokens // merge_size**2

_postprocess_image_embeds_evs

_postprocess_image_embeds_evs(
    image_embeds_split: tuple[Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[Tensor, ...]

Append mrope positions to each image's embeddings. This is necessary to recover correct mrope positions after video pruning.

Parameters:

Name Type Description Default
image_embeds_split tuple[Tensor, ...]

Tuple of image embeddings for each image item.

required
image_input Qwen2_5_VLImageInputs

Image input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of image embeddings for each image item. Resulting embeddings will have extra 4 channels for computed mrope positions.

Source code in vllm/model_executor/models/qwen2_5_vl.py
def _postprocess_image_embeds_evs(
    self,
    image_embeds_split: tuple[torch.Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Concatenate precomputed mrope positions onto each image's embeddings.

    No rows are dropped here. The positions are attached so that correct
    mrope positions can still be recovered after video pruning.

    Args:
        image_embeds_split: One embedding tensor per image item.
        image_input: Parsed image input; only ``image_grid_thw`` is read.

    Returns:
        One tensor per image item: the original embeddings with the output
        of ``compute_mrope_for_media`` concatenated along dim 1 (4 extra
        channels holding the computed mrope positions).
    """
    spatial_merge = self.visual.spatial_merge_size
    per_image_grid = image_input["image_grid_thw"].tolist()

    def _with_positions(embeds, grid):
        # Move positions onto the embeddings' device before concatenating.
        pos = compute_mrope_for_media(grid, spatial_merge).to(embeds.device)
        return torch.cat([embeds, pos], dim=1)

    return tuple(
        _with_positions(embeds, grid)
        for embeds, grid in zip(image_embeds_split, per_image_grid)
    )

_postprocess_video_embeds_evs

_postprocess_video_embeds_evs(
    video_embeds_split: tuple[Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[Tensor, ...]

Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions to each retained embedding.

Parameters:

Name Type Description Default
video_embeds_split tuple[Tensor, ...]

Tuple of video embeddings for each video item.

required
video_input Qwen2_5_VLVideoInputs

Video input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of video embeddings for each video item. Resulting embeddings will have extra 4 channels for computed mrope positions.

Source code in vllm/model_executor/models/qwen2_5_vl.py
def _postprocess_video_embeds_evs(
    self,
    video_embeds_split: tuple[torch.Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Prune each video's embeddings with Efficient Video Sampling (EVS) and
    attach mrope positions to the rows that survive.

    Args:
        video_embeds_split: One embedding tensor per video item.
        video_input: Parsed video input; ``video_grid_thw`` and
            ``second_per_grid_ts`` are read.

    Returns:
        One tensor per video item containing only the retained rows, with
        the matching mrope positions concatenated along dim 1 (4 extra
        channels).

    Raises:
        ValueError: If ``second_per_grid_ts`` is missing from ``video_input``.
    """
    video_grids = video_input["video_grid_thw"]
    assert video_grids.ndim == 2
    per_video_grid = video_grids.tolist()
    spatial_merge = self.visual.spatial_merge_size

    # Cast to long to match the original code
    # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
    seconds_per_grid = video_input.get("second_per_grid_ts")
    if seconds_per_grid is None:
        raise ValueError(
            "second_per_grid_ts is required when video_pruning_rate > 0 "
            "is enabled for video inputs, including the video_embeds path."
        )
    # NOTE(review): ``.long()`` truncates fractional values (e.g. 0.5 -> 0)
    # before they reach compute_mrope_for_media. Confirm this is intended.
    seconds_per_grid = seconds_per_grid.long()
    tps = self.config.vision_config.tokens_per_second

    kept = []
    for embeds, grid, seconds in zip(
        video_embeds_split, per_video_grid, seconds_per_grid
    ):
        # EVS decides which rows of this video are retained.
        keep_mask = compute_retention_mask(
            embeds,
            grid,
            spatial_merge_size=spatial_merge,
            q=self.video_pruning_rate,
        )
        mrope_pos = compute_mrope_for_media(
            grid,
            spatial_merge,
            tokens_per_second=tps,
            video_second_per_grid=seconds.item(),
        ).to(embeds.device)
        kept.append(torch.cat([embeds[keep_mask], mrope_pos[keep_mask]], dim=1))
    return tuple(kept)

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor | IntermediateTensors

Run forward pass for Qwen2.5-VL.

Parameters:

Name Type Description Default
input_ids Tensor | None

Flattened (concatenated) input_ids corresponding to a batch.

required
positions Tensor

Flattened (concatenated) position ids corresponding to a batch. NOTE: If mrope is enabled (default setting for Qwen2.5-VL opensource models), the shape will be (3, seq_len), otherwise it will be (seq_len,).

required
Source code in vllm/model_executor/models/qwen2_5_vl.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run the Qwen2.5-VL language backbone over a flattened batch.

    Args:
        input_ids: Flattened (concatenated) token ids for the batch.
        positions: Flattened (concatenated) position ids for the batch.
            **NOTE**: With mrope enabled (the default for the open-source
            Qwen2.5-VL models) the shape is `(3, seq_len)`; otherwise it
            is `(seq_len,)`.
        intermediate_tensors: Tensors handed to the language model. When
            present, ``inputs_embeds`` is discarded.
        inputs_embeds: Optional precomputed input embeddings.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        Whatever ``self.language_model.model`` returns (annotated as a
        tensor or ``IntermediateTensors``).
    """
    embeds = None if intermediate_tensors is not None else inputs_embeds
    return self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=embeds,
    )

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen2_5_vl.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Describe how this model's submodules are grouped by name prefix.

    Returns:
        A ``MultiModelKeys`` built from string prefixes: the language model
        under ``language_model``, the connector under ``visual.merger.``,
        and the vision tower under ``visual.``.
    """
    prefixes = dict(
        language_model="language_model",
        connector="visual.merger.",
        tower_model="visual.",
    )
    return MultiModelKeys.from_string_field(**prefixes)

iter_mm_grid_thw

iter_mm_grid_thw(
    mm_features: list[MultiModalFeatureSpec],
) -> Iterator[tuple[int, int, int, int, float]]

Iterate over multimodal features and yield grid information.

Parameters:

Name Type Description Default
mm_features list[MultiModalFeatureSpec]

List of multimodal feature specifications

required

Yields:

Type Description
tuple[int, int, int, int, float]

Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each frame/image

Source code in vllm/model_executor/models/qwen2_5_vl.py
def iter_mm_grid_thw(
    self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
    """
    Yield merged grid sizes for each multimodal item, in prompt order.

    Features are visited sorted by ``mm_position.offset``. The spatial grid
    sizes (h, w) are floor-divided by the vision config's
    ``spatial_merge_size``.

    Args:
        mm_features: List of multimodal feature specifications.

    Yields:
        Tuple of (offset, grid_t, grid_h, grid_w, t_factor), one per image
        or video item. Images always yield grid_t == 1 and t_factor == 1.0.
        Videos use t_factor = second_per_grid_ts * tokens_per_second, where
        second_per_grid_ts falls back to 1.0 when absent or falsy and
        tokens_per_second falls back to 1.0 when the vision config lacks it.

    Raises:
        ValueError: If a feature's modality is neither "image" nor "video".
    """
    vision_cfg = self.config.vision_config
    merge = vision_cfg.spatial_merge_size
    tokens_per_second = getattr(vision_cfg, "tokens_per_second", 1.0)

    for feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
        start = feature.mm_position.offset

        if feature.modality == "image":
            grid_t, grid_h, grid_w = feature.data["image_grid_thw"].data.tolist()
            assert grid_t == 1, f"Image must have 1 frame, got {grid_t}"
            yield start, 1, grid_h // merge, grid_w // merge, 1.0
            continue

        if feature.modality != "video":
            raise ValueError(f"Unsupported modality: {feature.modality}")

        grid_t, grid_h, grid_w = feature.data["video_grid_thw"].data.tolist()
        # The per-grid time interval is optional; default to 1 second.
        interval_elem = feature.data.get("second_per_grid_ts", None)
        seconds_per_grid = interval_elem.data.item() if interval_elem else 1.0
        t_factor = seconds_per_grid * tokens_per_second
        yield (start, grid_t, grid_h // merge, grid_w // merge, t_factor)

recompute_mrope_positions

recompute_mrope_positions(
    input_ids: list[int],
    multimodal_embeddings: tuple[Tensor, ...],
    mrope_positions: LongTensor,
    num_computed_tokens: int,
) -> tuple[tuple[Tensor, ...], Tensor, int]

Update the part of the input mrope positions starting at index num_computed_tokens. The original mrope_positions are computed for the unpruned sequence and become incorrect once pruning occurs, so after media tokens are pruned this must be reflected in mrope_positions before they are fed to the LLM.

Parameters:

Name Type Description Default
input_ids list[int]

(N,) All input tokens of the prompt (the entire sequence).

required
multimodal_embeddings tuple[Tensor, ...]

Tuple of multimodal embeddings.

required
mrope_positions LongTensor

Existing mrope positions (3, N) for entire sequence

required
num_computed_tokens int

The number of tokens computed so far.

required

Returns:

Type Description
tuple[tuple[Tensor, ...], Tensor, int]

Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta).

Source code in vllm/model_executor/models/qwen2_5_vl.py
def recompute_mrope_positions(
    self,
    input_ids: list[int],
    multimodal_embeddings: tuple[torch.Tensor, ...],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
    """
    Recompute mrope positions after media tokens have been pruned.

    The incoming ``mrope_positions`` were computed for the unpruned
    sequence and become incorrect once pruning occurs. The part starting
    at index ``num_computed_tokens`` is therefore recomputed before it is
    fed to the LLM.

    Each multimodal embedding is treated as 2-D and its last 4 columns are
    split off. They are transposed to shape ``(4, n)`` and cast to
    ``long``, then handed to the module-level
    ``recompute_mrope_positions`` helper. The helper also receives the
    prompt token ids and the config's vision-start, image and video token
    ids.

    Args:
        input_ids: (N,) All input tokens of the prompt (the entire
            sequence).
        multimodal_embeddings: Tuple of multimodal embeddings.
        mrope_positions: Existing mrope positions (3, N) for the entire
            sequence.
        num_computed_tokens: The number of tokens computed so far.

    Returns:
        Tuple of (multimodal_embeddings with their last 4 columns removed,
        mrope_positions, mrope_position_delta).
    """
    cfg = self.config
    img_tok = cfg.image_token_id
    vid_tok = cfg.video_token_id
    vis_start_tok = cfg.vision_start_token_id

    # Place the token ids on the embeddings' device when there are any,
    # otherwise on the device of the existing positions.
    if len(multimodal_embeddings) > 0:
        target_device = multimodal_embeddings[0].device
    else:
        target_device = mrope_positions.device
    token_ids = torch.as_tensor(input_ids, device=target_device, dtype=torch.long)

    # Separate each embedding into its feature columns and the 4 trailing
    # columns that the helper consumes as position data.
    stripped_embeddings: list[torch.Tensor] = []
    packed_positions: list[torch.Tensor] = []
    for embedding in multimodal_embeddings:
        stripped_embeddings.append(embedding[:, :-4])
        packed_positions.append(embedding[:, -4:].permute(1, 0).long())

    new_positions, position_delta = recompute_mrope_positions(
        token_ids,
        packed_positions,
        mrope_positions,
        num_computed_tokens,
        vis_start_tok,
        img_tok,
        vid_tok,
    )
    return tuple(stripped_embeddings), new_positions, position_delta

Qwen2_5_VLImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of image features
  • hs: Hidden size
  • ni: Number of images
Historical context
  • image_embeds shape: (num_image_features, hidden_size)
  • num_image_features varies based on the number and resolution of the images.
  • hidden_size must match the hidden size of language model backbone.
  • image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
    """
    Image inputs supplied as embeddings rather than pixel values.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of
          the images.
        - hidden_size must match the hidden size of the language model
          backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format.
    """

    # Literal tag identifying this input variant.
    type: Literal["image_embeds"]

    # Shape (nf, hs): one row per image feature.
    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # Shape (ni, 3): (grid_t, grid_h, grid_w) per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

Qwen2_5_VLImagePixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • ni: Number of images
  • cps: Number of channels * patch_size * patch_size
Historical context
  • pixel_values shape: (num_patches, num_channels * patch_size * patch_size)
  • image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLImagePixelInputs(TensorSchema):
    """
    Image inputs supplied as flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - ni: Number of images
        - cps: Number of channels * patch_size * patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels * patch_size *
          patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format.
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values"]

    # Shape (np, cps): one row per patch.
    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]

    # Shape (ni, 3): (grid_t, grid_h, grid_w) per image.
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

Qwen2_5_VLVideoEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of video features
  • hs: Hidden size
  • nv: Number of videos
Historical context
  • video_embeds shape: (num_video_features, hidden_size)
  • num_video_features varies based on the number and resolution of the videos.
  • hidden_size must match the hidden size of language model backbone.
  • video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format
  • second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when videos is not None.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
    """
    Video inputs supplied as embeddings rather than pixel values.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of
          the videos.
        - hidden_size must match the hidden size of the language model
          backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format.
        - second_per_grid_ts: The video time interval (in seconds) for each
          grid along the temporal dimension in the 3D position IDs. Returned
          when `videos` is not `None`.
    """

    # Literal tag identifying this input variant.
    type: Literal["video_embeds"]

    # Shape (nf, hs): one row per video feature.
    video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]

    # Shape (nv, 3): (grid_t, grid_h, grid_w) per video.
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

    # Shape (nv,): seconds per temporal grid step; optional, defaults to None.
    second_per_grid_ts: Annotated[torch.Tensor | None, TensorShape("nv")] = None

Qwen2_5_VLVideoPixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • nv: Number of videos
  • ctps: Number of channels * temporal_patch_size * patch_size * patch_size
Historical context
  • pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
  • video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format
  • second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when videos is not None.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
    """
    Video inputs supplied as flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - nv: Number of videos
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
        - second_per_grid_ts: The video time interval (in seconds) for each
          grid along the temporal dimension in the 3D position IDs. Returned
          when `videos` is not `None`.
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values_videos"]

    # Shape (np, ctps): one row per spatio-temporal patch.
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    # Shape (nv, 3): (grid_t, grid_h, grid_w) per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]

    # Shape (nv,): seconds per temporal grid step. The type already admits
    # None; default it to None (matching Qwen2_5_VLVideoEmbeddingInputs) so
    # reading the attribute on an instance built without it does not raise
    # AttributeError.
    second_per_grid_ts: Annotated[
        torch.Tensor | None,
        TensorShape("nv"),
    ] = None