vllm.model_executor.models.isaac

IsaacForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE

Source code in vllm/model_executor/models/isaac.py
@MULTIMODAL_REGISTRY.register_processor(
    IsaacMultiModalProcessor,
    info=IsaacProcessingInfo,
    dummy_inputs=IsaacDummyInputsBuilder,
)
class IsaacForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    supports_encoder_tp_data = True

    # To ensure correct weight loading and mapping.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "lm_head.": "language_model.lm_head.",
            "model.text_model.lm_head.": "language_model.lm_head.",
            "model.text_model.": "language_model.model.",
            "model.vision_embedding.0": "vision_embedding.transformer",
            "model.vision_embedding.1": "vision_embedding.linear_fc1",
            "model.vision_embedding.2": "vision_embedding.act",
            "model.vision_embedding.3": "vision_embedding.linear_fc2",
            "model.vision_embedding.": "vision_embedding.",
            "model.lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
        super().__init__()
        config: IsaacConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        head_dim = config.head_dim
        calculated_mrope_section = [
            head_dim // 4,  # 2x more for temporal dim
            head_dim // 8,
            head_dim // 8,
        ]

        self.vision_token_id = _resolve_vision_token_id(
            vllm_config.model_config, config.vision_token
        )
        config.image_token_id = self.vision_token_id

        text_cfg = getattr(config, "text_config", None)
        target_cfg = (
            text_cfg
            if text_cfg is not None and not isinstance(text_cfg, dict)
            else config
        )

        rope_scaling = getattr(target_cfg, "rope_scaling", None)
        if rope_scaling is None and target_cfg is config:
            rope_scaling = getattr(config, "_rope_scaling", None)

        patch_rope_parameters(target_cfg)
        rope_parameters = target_cfg.rope_parameters
        rope_parameters["mrope_section"] = calculated_mrope_section
        if rope_scaling is not None and "mrope_interleaved" in rope_scaling:
            rope_parameters.setdefault(
                "mrope_interleaved", rope_scaling["mrope_interleaved"]
            )
        target_cfg.rope_parameters = rope_parameters

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                architectures=["Qwen3ForCausalLM"],
                prefix=maybe_prefix(prefix, "language_model"),
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        vision_cfg = config.vision_config
        if vision_cfg is None:
            raise ValueError("IsaacConfig should always have vision_config")
        attn_impl = (
            config.vision_attn_implementation
            if config.vision_attn_implementation is not None
            else getattr(config, "_attn_implementation", None)
        )
        if attn_impl is not None:
            vision_cfg._attn_implementation = attn_impl

        hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2)

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_embedding = IsaacVisionEmbedding(
                vision_cfg=vision_cfg,
                hidden_dim=hidden_dim,
                output_dim=config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_embedding"),
            )

    def iter_mm_grid_hw(
        self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int]]:
        spatial_merge_size = self.config.vision_config.pixel_shuffle_scale_factor
        for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
            offset = mm_feature.mm_position.offset
            if mm_feature.modality == "image":
                t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
                assert t == 1, f"Image must have 1 frame, got {t}"
                yield offset, h // spatial_merge_size, w // spatial_merge_size
            else:
                raise ValueError(f"Unsupported modality: {mm_feature.modality}")

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        llm_pos_ids_list = []
        st = 0
        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
            input_tokens, mm_features
        ):
            text_len = offset - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
            grid_indices[0, :] = grid_indices[0, :] + text_len + st_idx
            llm_pos_ids_list.append(grid_indices)
            st = offset + llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1][0, -1] + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> IsaacImagePixelInputs | None:
        pixel_values = kwargs.get("pixel_values")
        image_grid_thw = kwargs.get("image_grid_thw")
        if pixel_values is None or image_grid_thw is None:
            return None

        # TensorSchema will automatically validate shapes on initialization
        return IsaacImagePixelInputs(
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
        )

    def _process_image_input(
        self,
        image_input: IsaacImagePixelInputs,
    ) -> tuple[torch.Tensor, ...]:
        pixel_values = image_input["pixel_values"]
        image_grid_thw = image_input["image_grid_thw"]
        if pixel_values.numel() == 0:
            return ()

        device = next(self.language_model.parameters()).device
        dtype = self.vision_embedding.linear_fc1.weight.dtype
        pixel_values = pixel_values.to(device=device, dtype=dtype)
        spatial_grids = image_grid_thw[:, 1:3].to(device, dtype=torch.int32)

        vision_embeddings = self.vision_embedding((pixel_values, spatial_grids))
        merge_size = self.config.vision_config.pixel_shuffle_scale_factor
        sizes = spatial_grids.prod(-1) // (merge_size * merge_size)
        return tuple(vision_embeddings.split(sizes.tolist()))

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return ()
        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        return self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="vision_embedding.linear_fc2",  # The final linear layer
            tower_model="vision_embedding",
        )
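
The positions returned by get_mrope_input_positions stack three rope axes (temporal, height, width): text tokens advance all three axes together, while image tokens take their height/width coordinates from the merged patch grid. Below is a minimal NumPy sketch of that layout for a toy prompt (3 text tokens, one image with a 2 x 2 merged grid, 2 trailing text tokens); it mirrors the routine above but is illustrative only.

import numpy as np

text_len, grid_h, grid_w, tail_len = 3, 2, 2, 2

pos = []
# Text before the image: all three rope axes advance together.
pos.append(np.broadcast_to(np.arange(text_len), (3, text_len)))
# Image tokens: temporal stays flat, height/width follow the merged grid.
grid = np.indices((1, grid_h, grid_w)).reshape(3, -1)
grid[0, :] += text_len                    # shift the temporal axis past the text
pos.append(grid)
# Trailing text resumes after the largest temporal position used so far.
st_idx = pos[-1][0, -1] + 1
pos.append(np.broadcast_to(np.arange(tail_len), (3, tail_len)) + st_idx)

positions = np.concatenate(pos, axis=1)
print(positions)
# Row 0 (t): [0 1 2 3 3 3 3 4 5]
# Row 1 (h): [0 1 2 0 0 1 1 4 5]
# Row 2 (w): [0 1 2 0 1 0 1 4 5]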

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/isaac.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="vision_embedding.linear_fc2",  # The final linear layer
        tower_model="vision_embedding",
    )

IsaacImagePixelInputs

Bases: TensorSchema

Schema for validating Isaac image inputs.

Dimensions:
  • np: Number of patches
  • d: Patch dimension
  • ni: Number of images

The schema enforces:
  • pixel_values must be 2D: (num_patches, patch_dim)
  • image_grid_thw must be 2D: (num_images, 3) where 3 represents [T, H, W]
Source code in vllm/model_executor/models/isaac.py
class IsaacImagePixelInputs(TensorSchema):
    """
    Schema for validating Isaac image inputs.

    Dimensions:
        - np: Number of patches
        - d: Patch dimension
        - ni: Number of images

    The schema enforces:
        - pixel_values must be 2D: (num_patches, patch_dim)
        - image_grid_thw must be 2D: (num_images, 3)
          where 3 represents [T, H, W]
    """

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "d"),
    ]

    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]
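
A small illustrative sketch of tensors that satisfy this schema, assuming the default patch size of 16 (so the patch dimension is 16 * 16 * 3 = 768); the shapes and the grid/patch-count relationship follow the definitions above.

import torch

# One image whose real patch grid is 32 x 48 (T=1), patch dim 16*16*3 = 768.
num_patches, patch_dim = 32 * 48, 16 * 16 * 3
pixel_values = torch.randn(num_patches, patch_dim)   # (np, d)
image_grid_thw = torch.tensor([[1, 32, 48]])         # (ni, 3) -> [T, H, W]

# The rows of pixel_values must add up to the patch count implied by the grids.
assert pixel_values.shape[0] == int(image_grid_thw.prod(dim=-1).sum())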

IsaacImageProcessor

Source code in vllm/model_executor/models/isaac.py
class IsaacImageProcessor:
    patch_size = 16
    max_num_patches = 6144
    min_num_patches = 256
    pixel_shuffle_scale = 2

    valid_kwargs = IsaacImageProcessorKwargs
    model_input_names = ["pixel_values", "image_grid_thw"]

    def __init__(self, kwargs):
        self.patch_size = kwargs.pop("patch_size", self.patch_size)
        self.vision_max_num_patches = kwargs.pop(
            "vision_max_num_patches", self.max_num_patches
        )
        self.vision_min_num_patches = kwargs.pop(
            "vision_min_num_patches", self.min_num_patches
        )
        self.pixel_shuffle_scale = kwargs.pop("pixel_shuffle_scale", 2)

    def preprocess(
        self,
        images: list[torch.Tensor],
        return_tensors: str | TensorType | None,
        **kwargs: Unpack[IsaacImageProcessorKwargs],
    ) -> BatchFeature:
        """Preprocess images into format compatibile with vLLM input processing."""

        all_pixel_values: list[torch.Tensor] = []
        all_image_grids: list[torch.Tensor] = []

        for image in images:
            image_tensor = extract_image_pil(image)

            patches, dims_virtual = process_vision_for_patches(
                image_tensor,
                patch_size=self.patch_size,
                max_num_patches=self.vision_max_num_patches,
                min_num_patches=self.vision_min_num_patches,
                pixel_shuffle_scale=self.pixel_shuffle_scale,
            )

            # Isaac packs a dummy temporal dim for images
            patches = patches.unsqueeze(1)  # [N, T=1, Hp, Wp, D]

            hp, wp, dim = patches.shape[-3], patches.shape[-2], patches.shape[-1]
            current_num_patches = hp * wp
            pixel_values = patches.reshape(current_num_patches, dim)  # [N_tokens, D]

            # Use real patch dimensions for image_grid_thw, not virtual dimensions
            # This ensures the vision model receives correct grid info for pixel shuffle
            dims_real = [1, hp, wp]  # Real patch dimensions
            image_grid_thw = torch.tensor(dims_real).unsqueeze(0)

            all_pixel_values.append(pixel_values)
            all_image_grids.append(image_grid_thw)

        if all_pixel_values:
            final_pixel_values = torch.cat(all_pixel_values, dim=0)
            final_image_grids = torch.cat(all_image_grids, dim=0)
        else:
            final_pixel_values = torch.empty(0, 0)
            final_image_grids = torch.empty(0, 3)

        return BatchFeature(
            data={
                "pixel_values": final_pixel_values,
                "image_grid_thw": final_image_grids,
            },
            tensor_type=return_tensors,
        )

preprocess

preprocess(
    images: list[Tensor],
    return_tensors: str | TensorType | None,
    **kwargs: Unpack[IsaacImageProcessorKwargs],
) -> BatchFeature

Preprocess images into a format compatible with vLLM input processing.

Source code in vllm/model_executor/models/isaac.py
def preprocess(
    self,
    images: list[torch.Tensor],
    return_tensors: str | TensorType | None,
    **kwargs: Unpack[IsaacImageProcessorKwargs],
) -> BatchFeature:
    """Preprocess images into format compatibile with vLLM input processing."""

    all_pixel_values: list[torch.Tensor] = []
    all_image_grids: list[torch.Tensor] = []

    for image in images:
        image_tensor = extract_image_pil(image)

        patches, dims_virtual = process_vision_for_patches(
            image_tensor,
            patch_size=self.patch_size,
            max_num_patches=self.vision_max_num_patches,
            min_num_patches=self.vision_min_num_patches,
            pixel_shuffle_scale=self.pixel_shuffle_scale,
        )

        # Isaac packs a dummy temporal dim for images
        patches = patches.unsqueeze(1)  # [N, T=1, Hp, Wp, D]

        hp, wp, dim = patches.shape[-3], patches.shape[-2], patches.shape[-1]
        current_num_patches = hp * wp
        pixel_values = patches.reshape(current_num_patches, dim)  # [N_tokens, D]

        # Use real patch dimensions for image_grid_thw, not virtual dimensions
        # This ensures the vision model receives correct grid info for pixel shuffle
        dims_real = [1, hp, wp]  # Real patch dimensions
        image_grid_thw = torch.tensor(dims_real).unsqueeze(0)

        all_pixel_values.append(pixel_values)
        all_image_grids.append(image_grid_thw)

    if all_pixel_values:
        final_pixel_values = torch.cat(all_pixel_values, dim=0)
        final_image_grids = torch.cat(all_image_grids, dim=0)
    else:
        final_pixel_values = torch.empty(0, 0)
        final_image_grids = torch.empty(0, 3)

    return BatchFeature(
        data={
            "pixel_values": final_pixel_values,
            "image_grid_thw": final_image_grids,
        },
        tensor_type=return_tensors,
    )
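
A shape-level sketch of what preprocess emits for one 512 x 768 RGB image under the defaults above (patch_size=16, pixel_shuffle_scale=2). Since 512 and 768 are already multiples of patch_size * pixel_shuffle_scale = 32 and the patch count stays within [256, 6144], no resizing would be needed; the arithmetic below is illustrative only.

patch_size, pixel_shuffle_scale = 16, 2
h, w = 512, 768

hp, wp = h // patch_size, w // patch_size            # real patch grid: 32 x 48
num_patches = hp * wp                                 # 1536 rows in pixel_values
patch_dim = patch_size * patch_size * 3               # 768 columns (flattened RGB patch)
num_llm_tokens = (hp // pixel_shuffle_scale) * (wp // pixel_shuffle_scale)

print(num_patches, patch_dim, num_llm_tokens)         # 1536 768 384
# preprocess would return pixel_values of shape (1536, 768)
# and image_grid_thw == [[1, 32, 48]].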

IsaacProcessor

Processor wrapper (tokenizer + IsaacImageProcessor).

Source code in vllm/model_executor/models/isaac.py
class IsaacProcessor:
    """Processor wrapper (tokenizer + IsaacImageProcessor)."""

    def __init__(self, image_processor=None, tokenizer=None, **kwargs):
        self.image_token = kwargs.pop("image_token", "<image>")
        self.image_processor = image_processor or IsaacImageProcessor(kwargs)
        self.tokenizer = tokenizer

    def __call__(self, text=None, images=None, **kwargs) -> BatchFeature:
        result = {}

        if images is not None:
            image_inputs = self.image_processor.preprocess(images, **kwargs)
            image_grid_thw = image_inputs["image_grid_thw"]
            result.update(image_inputs)

            if text is not None:
                if not isinstance(text, list):
                    text = [text]

                text = text.copy()  # below lines change text in-place
                merge_length = self.image_processor.pixel_shuffle_scale**2
                index = 0
                for i in range(len(text)):
                    while self.image_token in text[i]:
                        num_image_tokens = image_grid_thw[index].prod() // merge_length
                        text[i] = text[i].replace(
                            self.image_token, "<|placeholder|>" * num_image_tokens, 1
                        )
                        index += 1
                    text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")

        if text is not None:
            result.update(self.tokenizer(text, **kwargs))

        return BatchFeature(result)

    def apply_chat_template(
        self,
        messages: list[dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        **kwargs,
    ) -> Any:
        # Convert mixed content messages to simple text format
        processed_messages = []

        for message in messages:
            if "content" in message and isinstance(message["content"], list):
                # Handle mixed content (text + image)
                text_parts = []
                for content_item in message["content"]:
                    if content_item.get("type") == "text":
                        text_parts.append(content_item.get("text", ""))
                    elif content_item.get("type") == "image":
                        # Replace image with vision token
                        text_parts.append(self.image_token)

                processed_message = {
                    "role": message.get("role", "user"),
                    "content": "".join(text_parts),
                }
                processed_messages.append(processed_message)
            else:
                # Regular text message
                processed_messages.append(message)

        kwargs["return_dict"] = False
        return self.tokenizer.apply_chat_template(
            processed_messages,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )
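
A short sketch of the image-token expansion arithmetic performed in __call__: each <image> placeholder is replaced by image_grid_thw.prod() // pixel_shuffle_scale**2 pad tokens. The values below are illustrative.

import torch

image_grid_thw = torch.tensor([[1, 32, 48]])          # one image, 32 x 48 real patches
pixel_shuffle_scale = 2
merge_length = pixel_shuffle_scale ** 2

num_image_tokens = int(image_grid_thw[0].prod()) // merge_length   # 1536 // 4 = 384
prompt = "Describe this image: <image>"
expanded = prompt.replace("<image>", "<|image_pad|>" * num_image_tokens, 1)
print(num_image_tokens, expanded.count("<|image_pad|>"))            # 384 384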

Siglip2VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/isaac.py
class Siglip2VisionTransformer(nn.Module):
    def __init__(
        self,
        config: PixelShuffleSiglip2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        embed_dim = config.hidden_size

        self.embeddings = Siglip2VariableSequenceEmbeddings(config)
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
        self.encoder = Siglip2Encoder(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        r"""
        spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
            Tensor containing the spatial dimensions (height, width)
            of the input images.
        """

        seq_patches, token_grids = packed_seq_patches
        seq_sizes = torch.prod(token_grids, dim=-1)

        # Get embeddings from packed sequence
        hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))

        # Add a pseudo batch dimension for the encoder
        hidden_states = hidden_states.unsqueeze(0)

        cu_seqlens, max_seqlen = create_cumulative_seq_lengths(
            seq_sizes, hidden_states.device
        )

        hidden_states = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        hidden_states = self.post_layernorm(hidden_states)

        if self.pixel_shuffle_scale_factor > 1:
            hidden_states = pixel_shuffle_varlen(
                x=hidden_states,
                token_grids=token_grids,
                scale_factor=self.pixel_shuffle_scale_factor,
            )
        # Remove the pseudo batch dimension we added earlier
        hidden_states = hidden_states.squeeze(0)

        # return last_hidden_state
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

forward

forward(
    packed_seq_patches: tuple[Tensor, Tensor],
) -> Tensor

spatial_shapes (torch.LongTensor of shape (batch_size, 2)): Tensor containing the spatial dimensions (height, width) of the input images.

Source code in vllm/model_executor/models/isaac.py
def forward(
    self,
    packed_seq_patches: tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
    r"""
    spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
        Tensor containing the spatial dimensions (height, width)
        of the input images.
    """

    seq_patches, token_grids = packed_seq_patches
    seq_sizes = torch.prod(token_grids, dim=-1)

    # Get embeddings from packed sequence
    hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids))

    # Add a pseudo batch dimension for the encoder
    hidden_states = hidden_states.unsqueeze(0)

    cu_seqlens, max_seqlen = create_cumulative_seq_lengths(
        seq_sizes, hidden_states.device
    )

    hidden_states = self.encoder(
        inputs_embeds=hidden_states,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
    )
    hidden_states = self.post_layernorm(hidden_states)

    if self.pixel_shuffle_scale_factor > 1:
        hidden_states = pixel_shuffle_varlen(
            x=hidden_states,
            token_grids=token_grids,
            scale_factor=self.pixel_shuffle_scale_factor,
        )
    # Remove the pseudo batch dimension we added earlier
    hidden_states = hidden_states.squeeze(0)

    # return last_hidden_state
    return hidden_states

_make_writeable

_make_writeable(arr: ndarray) -> ndarray

Return arr itself if it is already writeable, otherwise try to flip the write flag in-place and finally fall back to arr.copy(). This guarantees the buffer handed to torch.from_numpy() is always writeable, silencing the PyTorch warning about undefined behaviour.

Source code in vllm/model_executor/models/isaac.py
def _make_writeable(arr: np.ndarray) -> np.ndarray:
    """Return *arr* itself if it is already writeable, otherwise try to flip the
    write flag in-place and finally fall back to `arr.copy()`.
    This guarantees the buffer handed to `torch.from_numpy()` is always
    writeable, silencing the PyTorch warning about undefined behaviour.
    """
    if arr.flags.writeable:
        return arr

    # First, try the cheap path — in-place flag toggle (works for mmap'd arrays
    # and some shared memory buffers):
    try:
        arr.setflags(write=True)
        return arr  # success: no data copy
    except ValueError:
        # Buffer is inherently read-only (e.g. backed by PyAV / PIL): make copy
        return arr.copy()
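
A minimal demonstration of the situation this helper handles, using a read-only array backed by an immutable bytes object (illustrative only).

import numpy as np

arr = np.frombuffer(b"\x00\x01\x02\x03", dtype=np.uint8)
print(arr.flags.writeable)          # False: torch.from_numpy() would warn here

try:
    arr.setflags(write=True)        # cheap path: flip the flag in place
except ValueError:
    arr = arr.copy()                # fallback: copy into a writeable buffer
print(arr.flags.writeable)          # True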

create_cumulative_seq_lengths

create_cumulative_seq_lengths(
    seq_sizes: Tensor, device: device
) -> tuple[Tensor, Tensor]

Create cumulative sequence lengths for variable-length attention.

Source code in vllm/model_executor/models/isaac.py
def create_cumulative_seq_lengths(
    seq_sizes: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """Create cumulative sequence lengths for variable-length attention."""
    cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = seq_sizes.cumsum(0)
    max_seqlen = (
        seq_sizes.max()
        if len(seq_sizes) > 0
        else torch.tensor(0, dtype=torch.int32, device=device)
    )
    return cu_seqlens, max_seqlen
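
A short sketch of the values this helper produces for two packed images of 4 and 6 patches, mirroring the code above.

import torch

seq_sizes = torch.tensor([4, 6])
cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32)
cu_seqlens[1:] = seq_sizes.cumsum(0)
print(cu_seqlens.tolist())   # [0, 4, 10] -> boundaries of each image in the packed sequence
print(int(seq_sizes.max()))  # 6          -> max_seqlen for the varlen attention kernel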

create_pixel_shuffle_index_map

create_pixel_shuffle_index_map(
    seq_sizes: Tensor,
    token_grids: Tensor,
    scale_factor: int = 1,
    device: device | None = None,
) -> Tensor

Build a gather-index map that tells us, for every output token after pixel-shuffle, which scale_factor**2 input tokens are being merged.

Args

  • seq_sizes : (num_images,) - number of patches in each image (row-major order)
  • token_grids : (num_images, 2) - (height, width) for every image
  • scale_factor : spatial down-scale factor (≥2)
  • device : (optional) overrides seq_sizes.device

Returns

  • gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor. gather_idx[i, j] is the flat index into the original packed sequence for the j-th sub-patch that forms the i-th output token.

Source code in vllm/model_executor/models/isaac.py
def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: torch.device | None = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes     : (num_images,)  - #patches in each image (row-major order)
    token_grids   : (num_images,2) - (height, width) for every image
    scale_factor  : spatial down-scale factor (≥2)
    device        : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token.
    """
    if device is None:
        device = seq_sizes.device

    r = int(scale_factor)
    if r < 2:
        raise ValueError("`scale_factor` must be ≥ 2")

    # Safety: all spatial dims must be divisible by r
    # Cannot run under torch compile fullgraph mode hence
    if not torch.compiler.is_compiling() and not (
        (token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()
    ):
        raise AssertionError(
            "Every (H,W) in `token_grids` must be divisible by "
            f"scale_factor={r}, got {token_grids.tolist()}"
        )

    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0

    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False):
        # Build the (H, W) grid of flat indices for this image
        grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset
        grid = grid.view(h, w)  # (H, W)

        # -------- identical ordering to your fixed-res routine --------
        # Step 1: split width into blocks of r
        grid = grid.view(h, w // r, r)  # (H, W/r, r)
        # Step 2: now split height into blocks of r
        grid = grid.view(h // r, r, w // r, r)  # (H/r, r, W/r, r)
        # Step 3: final permutation to (H/r, W/r, r, r)
        grid = grid.permute(0, 2, 1, 3).contiguous()  # (H/r, W/r, r, r)
        # Step 4: each (r, r) block forms one output token
        gather_chunks.append(grid.reshape(-1, r * r))  # (H*W / r², r²)

        tok_offset += seq_len

    # Concatenate over all images in the packed batch
    gather_idx = torch.cat(gather_chunks, dim=0)  # (Σ_i HᵢWᵢ/r², r²)
    return gather_idx
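
A minimal sketch of the index map for a single 2 x 4 patch grid with scale_factor=2, mirroring the reshape/permute steps above: each output token gathers one 2 x 2 block of input patches.

import torch

h, w, r = 2, 4, 2
grid = torch.arange(h * w).view(h, w)                 # flat patch indices laid out row-major
grid = grid.view(h, w // r, r).view(h // r, r, w // r, r)
gather_idx = grid.permute(0, 2, 1, 3).reshape(-1, r * r)
print(gather_idx.tolist())                            # [[0, 1, 4, 5], [2, 3, 6, 7]]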

get_image_size_for_max_num_patches

get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    eps: float = 1e-05,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]

Compute a target resolution whose patch grid satisfies patching parametrization.

Parameters:

  • image_height (`int`, required): Height in pixels of the source image prior to any resizing.
  • image_width (`int`, required): Width in pixels of the source image prior to any resizing.
  • patch_size (`int`, required): Size of the square patch used by the vision encoder.
  • max_num_patches (`int`, required): Upper bound on (height / patch_size) * (width / patch_size) after resizing.
  • min_num_patches (`int`, *optional*, defaults to None): Lower bound on the number of patches. When provided the image will be scaled up if necessary.
  • eps (`float`, *optional*, defaults to 1e-05): Convergence tolerance for the internal binary search to determine the target dimensions.
  • pixel_shuffle_scale (`int`, *optional*, defaults to 1): Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

Returns:

  • tuple[int, int]: Height and width (in pixels) that are multiples of patch_size * pixel_shuffle_scale and respect both the maximum and optional minimum patch-count constraints.

Source code in vllm/model_executor/models/isaac.py
def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Compute a target resolution whose patch grid satisfies patching parametrization.

    Args:
        image_height (`int`):
            Height in pixels of the source image prior to any resizing.
        image_width (`int`):
            Width in pixels of the source image prior to any resizing.
        patch_size (`int`):
            Size of the square patch used by the vision encoder.
        max_num_patches (`int`):
            Upper bound on `(height / patch_size) * (width / patch_size)` after
            resizing.
        min_num_patches (`int`, *optional*):
            Lower bound on the number of patches. When provided the image will
            be scaled up if necessary.
        eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search to determine
            the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Additional stride multiplier applied when pixel shuffle later
            reduces spatial resolution.

    Returns:
        `tuple[int, int]`: Height and width (in pixels) that are multiples of
        `patch_size * pixel_shuffle_scale` and respect both the maximum and
        optional minimum patch-count constraints.
    """

    def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale):
        scaled_size = scale * original_size
        divisor = patch_size * pixel_shuffle_scale
        scaled_size = math.ceil(scaled_size / divisor) * divisor
        scaled_size = max(divisor, scaled_size)
        return int(scaled_size)

    # Ensure divisibility
    divisor = patch_size * pixel_shuffle_scale
    adjusted_height = math.ceil(image_height / divisor) * divisor
    adjusted_height = max(divisor, adjusted_height)
    adjusted_width = math.ceil(image_width / divisor) * divisor
    adjusted_width = max(divisor, adjusted_width)

    num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)

    if min_num_patches is not None and num_patches < min_num_patches:
        # Scale up
        scale_min, scale_max = 1.0, 100.0
        while (scale_max - scale_min) >= eps:
            scale = (scale_min + scale_max) / 2
            target_height = get_scaled_image_size(
                scale, image_height, patch_size, pixel_shuffle_scale
            )
            target_width = get_scaled_image_size(
                scale, image_width, patch_size, pixel_shuffle_scale
            )
            num_patches = (target_height / patch_size) * (target_width / patch_size)
            if num_patches >= min_num_patches:
                scale_max = scale
            else:
                scale_min = scale
        scale = scale_max
        target_height = get_scaled_image_size(
            scale, image_height, patch_size, pixel_shuffle_scale
        )
        target_width = get_scaled_image_size(
            scale, image_width, patch_size, pixel_shuffle_scale
        )
        return target_height, target_width
    elif num_patches <= max_num_patches:
        return adjusted_height, adjusted_width
    else:
        # Scale down
        scale_min, scale_max = eps / 10, 1.0
        while (scale_max - scale_min) >= eps:
            scale = (scale_min + scale_max) / 2
            target_height = get_scaled_image_size(
                scale, image_height, patch_size, pixel_shuffle_scale
            )
            target_width = get_scaled_image_size(
                scale, image_width, patch_size, pixel_shuffle_scale
            )
            num_patches = (target_height / patch_size) * (target_width / patch_size)
            if num_patches <= max_num_patches:
                scale_min = scale
            else:
                scale_max = scale
        scale = scale_min
        target_height = get_scaled_image_size(
            scale, image_height, patch_size, pixel_shuffle_scale
        )
        target_width = get_scaled_image_size(
            scale, image_width, patch_size, pixel_shuffle_scale
        )
        return target_height, target_width
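
A hand-worked example of the no-rescaling branch, using the processor defaults shown earlier (patch_size=16, pixel_shuffle_scale=2, patch budget 256 to 6144); illustrative only.

import math

image_height, image_width = 500, 700
patch_size, pixel_shuffle_scale = 16, 2

divisor = patch_size * pixel_shuffle_scale                              # 32
target_h = max(divisor, math.ceil(image_height / divisor) * divisor)    # 512
target_w = max(divisor, math.ceil(image_width / divisor) * divisor)     # 704
num_patches = (target_h // patch_size) * (target_w // patch_size)       # 32 * 44 = 1408
# 256 <= 1408 <= 6144, so no binary-search rescaling is needed and the
# function returns (512, 704).
print(target_h, target_w, num_patches)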

patchify_vision

patchify_vision(image: Tensor, patch_size: int) -> Tensor

Convert normalized images into flattened ViT-style patches.

Parameters:

  • image (`torch.Tensor`, required): Tensor of shape (num_images, height, width, channels).
  • patch_size (`int`, required): Edge length of the square patches.

Returns:

  • torch.Tensor: Patch tensor where each position stores the flattened pixels belonging to that patch.

Raises:

  • ValueError: If height or width is not divisible by patch_size.

Source code in vllm/model_executor/models/isaac.py
def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor:
    r"""Convert normalized images into flattened ViT-style patches.

    Args:
        image (`torch.Tensor`):
            Tensor of shape `(num_images, height, width, channels)`.
        patch_size (`int`):
            Edge length of the square patches

    Returns:
        `torch.Tensor`:
            Patch tensor where each position stores the flattened pixels
            belonging to that patch.

    Raises:
        ValueError: If `height` or `width` is not divisible by `patch_size`.
    """
    num_images, height, width, channels = image.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            "Dimensions of images "
            f"{image.shape} are not divisible by patch_size={patch_size}."
        )
    patches = image.reshape(
        num_images,
        height // patch_size,
        patch_size,
        width // patch_size,
        patch_size,
        channels,
    )
    patches = patches.permute(0, 1, 3, 2, 4, 5)
    patches = patches.reshape(
        num_images,
        height // patch_size,
        width // patch_size,
        channels * patch_size * patch_size,
    )
    return patches
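
A minimal sketch of the resulting shapes for one 64 x 96 RGB image with patch_size=16, mirroring the reshape/permute above.

import torch

image = torch.randn(1, 64, 96, 3)                  # (num_images, H, W, C)
patch_size = 16
n, h, w, c = image.shape
patches = image.reshape(n, h // patch_size, patch_size,
                        w // patch_size, patch_size, c)
patches = patches.permute(0, 1, 3, 2, 4, 5)
patches = patches.reshape(n, h // patch_size, w // patch_size,
                          c * patch_size * patch_size)
print(patches.shape)                               # torch.Size([1, 4, 6, 768])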

pixel_shuffle_varlen

pixel_shuffle_varlen(
    x: Tensor, token_grids: Tensor, scale_factor: int = 1
) -> Tensor

Apply pixel shuffle to a packed vision sequence without unpacking per image.

Parameters:

  • x (`torch.Tensor`, required): Concatenated vision embeddings. Accepts (seq_len, hidden_size) or (1, seq_len, hidden_size) shapes produced by stacking image patches.
  • token_grids (`torch.Tensor`, required): Integer tensor of shape (num_images, 2) whose rows give the (height, width) patch grid sizes corresponding to each image segment inside x.
  • scale_factor (`int`, *optional*, defaults to 1): Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge scale_factor**2 neighboring patches into a single embedding channel-group.

Returns:

  • torch.Tensor: Pixel-shuffled embeddings with shape matching the input convention: (seq_len, hidden_size * scale_factor**2) when the input was 2D, or (1, seq_len, hidden_size * scale_factor**2) if the singleton batch dimension was present.

Raises:

  • ValueError: If more than one batch item is provided.

Source code in vllm/model_executor/models/isaac.py
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or
            `(1, seq_len, hidden_size)` shapes produced by stacking image
            patches.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the
            `(height, width)` patch grid sizes corresponding to each image
            segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor specific to pixel shuffle. Values
            greater than one merge `scale_factor**2` neighboring patches into a
            single embedding channel-group.

    Returns:
        `torch.Tensor`: Pixel-shuffled embeddings with shape matching the input
        convention: `(seq_len, hidden_size * scale_factor**2)` when the input
        was 2D, or `(1, seq_len, hidden_size * scale_factor**2)` if the
        singleton batch dimension was present.

    Raises:
        ValueError: If more than one batch item is provided.
    """
    keep_batch_dim = x.dim() == 3
    if keep_batch_dim:
        if x.size(0) != 1:
            raise AssertionError("Packed sequence is expected to have batch_size == 1")
        x_ = x.squeeze(0)  # (seq, embed)
    else:
        x_ = x  # (seq, embed)

    embed_dim = x_.size(-1)
    r = int(scale_factor)

    # Calculate seq_sizes from token_grids
    seq_sizes = torch.prod(token_grids, dim=-1)

    # Build index map and gather in one go
    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=seq_sizes,
        token_grids=token_grids,
        scale_factor=r,
        device=x_.device,
    )  # (new_seq, r²)

    # Gather → (new_seq, r², embed_dim)
    gathered = x_[gather_idx]  # fancy indexing keeps gradient

    # Merge the r² group dimension into channels to finish the shuffle
    out = gathered.reshape(gathered.size(0), embed_dim * r * r)

    # Restore batch dimension if needed
    if keep_batch_dim:
        out = out.unsqueeze(0)
    return out
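
A shape-level sketch (not calling the helper) of how the packed sequence length and hidden size change for two images with 2 x 4 and 2 x 2 patch grids and scale_factor=2.

import torch

x = torch.randn(12, 8)                        # packed (seq_len, hidden_size) embeddings
token_grids = torch.tensor([[2, 4], [2, 2]])  # 8 + 4 = 12 input patches
r = 2

new_seq_len = int((token_grids.prod(-1) // (r * r)).sum())   # 2 + 1 = 3 output tokens
new_hidden = x.size(-1) * r * r                              # 8 * 4 = 32 channels
print(new_seq_len, new_hidden)                # pixel_shuffle_varlen would return (3, 32)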

prepare_image_tensor

prepare_image_tensor(
    image: Tensor, scale: float = VISION_SCALE
) -> Tensor

Standardize RGB images prior to patch extraction via rescaling and whitening.

Parameters:

  • image (`torch.Tensor`, required): Tensor with shape (..., height, width, 3) containing RGB values. The tensor is converted to floating point if needed.
  • scale (`float`, *optional*, defaults to VISION_SCALE): Scalar multiplier applied before normalization.

Returns:

  • torch.Tensor: Normalized tensor with the same shape as the input and dtype torch.float32.

Source code in vllm/model_executor/models/isaac.py
def prepare_image_tensor(
    image: torch.Tensor,
    scale: float = VISION_SCALE,
) -> torch.Tensor:
    r"""Standardize RGB images prior to patch extraction via rescaling and whitening.

    Args:
        image (`torch.Tensor`):
            Tensor with shape `(..., height, width, 3)` containing RGB values.
            The tensor is converted to floating point if needed.
        scale (`float`, *optional*, defaults to `VISION_SCALE`):
            Scalar multiplier applied before normalization.
    Returns:
        `torch.Tensor`: Normalized tensor with the same shape as the input and
        dtype `torch.float32`.
    """
    if not torch.is_floating_point(image):
        image = image.float()
    rescaled = image * scale

    # Use precomputed tensors and move to the correct device if needed
    mean_tensor = _MEAN_TENSOR.to(image.device)
    std_tensor = _STD_TENSOR.to(image.device)

    normalized = (rescaled - mean_tensor) / std_tensor
    return normalized
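
A standalone sketch of the rescale-and-whiten operation. The scale, mean, and std values below are illustrative stand-ins, not the module's actual VISION_SCALE / _MEAN_TENSOR / _STD_TENSOR constants.

import torch

# Illustrative constants only; the real module uses VISION_SCALE and
# precomputed mean/std tensors that are not shown on this page.
scale = 1.0 / 255.0
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.5, 0.5, 0.5])

image = torch.randint(0, 256, (64, 96, 3))       # (..., H, W, 3) integer RGB
normalized = (image.float() * scale - mean) / std
print(normalized.shape, normalized.dtype)        # torch.Size([64, 96, 3]) torch.float32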

process_vision_for_patches

process_vision_for_patches(
    images: Tensor,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    pixel_shuffle_scale: int = 1,
) -> tuple[Tensor, list[int]]

Resize, normalize, and patchify RGB images for the vision encoder.

Parameters:

  • images (`torch.Tensor`, required): Either (height, width, channels) for a single image or (num_images, height, width, channels) for a batch. Channels are expected to be RGB.
  • patch_size (`int`, required): Edge length of square patches; implicitly controls resize grid granularity.
  • max_num_patches (`int`, required): Maximum number of patches allowed after resizing.
  • min_num_patches (`int`, *optional*, defaults to None): Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound.
  • pixel_shuffle_scale (`int`, *optional*, defaults to 1): Pixel shuffle scale factor; influences the target grid that the function produces.

Returns:

  • tuple[torch.Tensor, list[int]]: A pair (patches, dims_virtual) where patches has shape (num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2) and dims_virtual encodes the effective (images, height, width) dimensions after optional pixel shuffling.

Source code in vllm/model_executor/models/isaac.py
def process_vision_for_patches(
    images: torch.Tensor,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: int | None = None,
    pixel_shuffle_scale: int = 1,
) -> tuple[torch.Tensor, list[int]]:
    r"""Resize, normalize, and patchify RGB images for the vision encoder.

    Args:
        images (`torch.Tensor`):
            Either `(height, width, channels)` for a single image or
            `(num_images, height, width, channels)` for a batch. Channels are
            expected to be RGB.
        patch_size (`int`):
            Edge length of square patches; implicitly controls resize grid granularity.
        max_num_patches (`int`):
            Maximum number of patches allowed after resizing.
        min_num_patches (`int`, *optional*):
            Minimum number of patches. If provided, the routine upsamples images
            as needed to satisfy the lower bound.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Pixel shuffle scale factor; influences the target grid that the
            function produces.

    Returns:
        `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)`
        where `patches` has shape `(num_images, target_h / patch_size, target_w
        / patch_size, channels * patch_size**2)` and `dims_virtual` encodes
        effective `(images, height, width)` dimensions after optional pixel
        shuffling.
    """
    # Add batch dim if single image
    if images.dim() == 3:
        images = images.unsqueeze(0)

    # Permute to channel first for resize
    images = images.permute(0, 3, 1, 2)

    # Get target dimensions
    _, _, orig_height, orig_width = images.shape
    target_height, target_width = get_image_size_for_max_num_patches(
        orig_height,
        orig_width,
        patch_size,
        max_num_patches,
        min_num_patches=min_num_patches,
        pixel_shuffle_scale=pixel_shuffle_scale,
    )

    # Resize
    images = F.interpolate(
        images,
        size=(target_height, target_width),
        mode="bilinear",
        align_corners=False,
    )

    # Back to channel last
    images = images.permute(0, 2, 3, 1)

    # Normalize
    images = prepare_image_tensor(images)

    # Patchify
    patches = patchify_vision(images, patch_size=patch_size)

    # Calculate dimensions for the patches
    n_images, h_patches, w_patches, _ = patches.shape
    dims_virtual = (
        [1, h_patches, w_patches]
        if pixel_shuffle_scale == 1
        else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale]
    )

    return patches, dims_virtual
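
A worked example of the shapes this function would produce for a 500 x 700 image under the defaults above; the image is first resized to 512 x 704 (see the get_image_size_for_max_num_patches example), and the arithmetic below is illustrative only.

patch_size, pixel_shuffle_scale = 16, 2
target_h, target_w = 512, 704                          # resized dimensions

h_patches, w_patches = target_h // patch_size, target_w // patch_size   # 32, 44
# patches.shape == (1, 32, 44, 768), where 768 = 3 * patch_size**2
dims_virtual = [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale]
print(dims_virtual)                                    # [1, 16, 22]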