
vllm.multimodal

Modules:

Name Description
audio
cache
encoder_budget
evs
hasher
image
inputs
media
parse
processing
registry
utils
video

BatchedTensorInputs module-attribute

BatchedTensorInputs: TypeAlias = dict[str, NestedTensors]

A dictionary containing nested tensors which have been batched via MultiModalKwargsItems.get_data.
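
As a hedged illustration, such a dictionary might look like the following sketch (the key names and shapes are hypothetical and not tied to any particular model):

```python
import torch

# Hypothetical batched inputs: a plain tensor where per-item shapes match,
# and a list of tensors (NestedTensors) where they do not.
batched_inputs = {
    "pixel_values": torch.randn(2, 3, 336, 336),  # 2 images with the same shape
    "input_audio_features": [                     # variable-length audio features
        torch.randn(80, 300),
        torch.randn(80, 512),
    ],
}
```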

MULTIMODAL_REGISTRY module-attribute

MULTIMODAL_REGISTRY = MultiModalRegistry()

The global MultiModalRegistry is used by model runners to dispatch data processing according to the target model.

Info

See mm_processing.

ModalityData module-attribute

ModalityData: TypeAlias = _T | list[_T | None] | None

Either a single data item, or a list of data items. An item (or the entire value) may only be None if a corresponding UUID is provided.

The number of data items allowed per modality is restricted by --limit-mm-per-prompt.
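
As a hedged illustration (the file paths are hypothetical), the same modality can be supplied either as a single item or as a list:

```python
from PIL import Image

# A single data item for the "image" modality...
single_image = Image.open("photo.jpg")

# ...or a list of items. A None entry is only valid when a UUID is supplied
# for that slot via MultiModalUUIDDict.
image_list = [Image.open("photo.jpg"), None]
```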

MultiModalDataDict module-attribute

MultiModalDataDict: TypeAlias = Mapping[
    str, ModalityData[Any]
]

A dictionary containing one entry per modality of input data.

The built-in modalities are defined by MultiModalDataBuiltins.
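
A hedged end-to-end sketch of passing such a dictionary through the offline LLM API (the model name, prompt template, and image file are illustrative):

```python
from PIL import Image
from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # any vision-language model

outputs = llm.generate({
    "prompt": "USER: <image>\nDescribe the picture. ASSISTANT:",
    # One entry per modality; the keys follow MultiModalDataBuiltins.
    "multi_modal_data": {"image": Image.open("photo.jpg")},
})
```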

MultiModalPlaceholderDict module-attribute

MultiModalPlaceholderDict: TypeAlias = Mapping[
    str, Sequence[PlaceholderRange]
]

A dictionary containing per-item placeholder ranges for each modality.

MultiModalUUIDDict module-attribute

MultiModalUUIDDict: TypeAlias = Mapping[
    str, list[str | None] | str
]

A dictionary containing user-provided UUIDs for items in each modality. If a UUID for an item is not provided, its entry will be None and MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes (input processing caching, embedding caching, prefix caching, etc).
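
A hedged example of such a mapping for a request carrying two images and one audio clip (the UUID strings are hypothetical):

```python
multi_modal_uuids = {
    # The first image reuses a caller-assigned ID; the second is None,
    # so MultiModalHasher will hash the raw item instead.
    "image": ["product-42-front-view", None],
    # A bare string is also accepted when the modality has a single item.
    "audio": "voicemail-2024-03-01",
}
```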

NestedTensors module-attribute

NestedTensors: TypeAlias = Union[
    list["NestedTensors"],
    list["torch.Tensor"],
    "torch.Tensor",
    tuple["torch.Tensor", ...],
]

Uses a list instead of a tensor if the dimensions of each element do not match.
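
A minimal sketch of both cases:

```python
import torch

# All elements share a shape, so a single stacked tensor suffices.
uniform = torch.zeros(4, 3, 224, 224)

# Per-element shapes differ, so a list of tensors is used instead.
ragged = [torch.zeros(3, 224, 224), torch.zeros(3, 448, 448)]
```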

MultiModalDataBuiltins

Bases: TypedDict

Type annotations for modality types predefined by vLLM.

Source code in vllm/multimodal/inputs.py
@final
class MultiModalDataBuiltins(TypedDict, total=False):
    """Type annotations for modality types predefined by vLLM."""

    image: ModalityData[ImageItem]
    """The input image(s)."""

    video: ModalityData[VideoItem]
    """The input video(s)."""

    audio: ModalityData[AudioItem]
    """The input audio(s)."""

    vision_chunk: ModalityData[VisionChunk]
    """The input visual atom(s) - unified modality for images and video chunks."""

audio instance-attribute

audio: ModalityData[AudioItem]

The input audio(s).

image instance-attribute

image: ModalityData[ImageItem]

The input image(s).

video instance-attribute

video: ModalityData[VideoItem]

The input video(s).

vision_chunk instance-attribute

vision_chunk: ModalityData[VisionChunk]

The input visual atom(s) - unified modality for images and video chunks.

MultiModalKwargsItems

Bases: UserDict[str, Sequence[_I]]

A dictionary of processed multi-modal inputs by modality.

For example, given a processor that processes images into pixel_values and image_grid_thw, and audios into input_audio_features, a prompt with 2 images and 1 audio will be processed into a MultiModalKwargsItems with the following structure:

MultiModalKwargsItems(
    {
        "image": [
            # For the first image
            MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            # For the second image
            MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
        ],
        "audio": [
            # For the first audio
            MultiModalKwargsItem({"input_audio_features": ...}),
        ],
    }
)

Unlike HF processing, which returns all items in a single dictionary with batched keyword arguments, we split up the items because some of them may already be cached. Also, items from multiple requests may be batched together to improve throughput, using the logic defined by the BaseMultiModalField for each keyword argument.
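
A hedged sketch of building per-item kwargs from an HF-style BatchFeature and re-batching them for the model; the field names, shapes, and use of MultiModalFieldConfig.batched are illustrative rather than tied to a specific model:

```python
import torch
from transformers import BatchFeature

from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems

hf_inputs = BatchFeature({
    "pixel_values": torch.randn(2, 3, 336, 336),  # two images
    "image_grid_thw": torch.tensor([[1, 24, 24], [1, 24, 24]]),
})
config_by_key = {
    "pixel_values": MultiModalFieldConfig.batched("image"),
    "image_grid_thw": MultiModalFieldConfig.batched("image"),
}

items = MultiModalKwargsItems.from_hf_inputs(hf_inputs, config_by_key)
assert len(items["image"]) == 2  # one MultiModalKwargsItem per image

# Re-batch the per-item tensors into model keyword arguments.
model_kwargs = items.get_data()
```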

Source code in vllm/multimodal/inputs.py
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
    """
    A dictionary of processed multi-modal inputs by modality.

    For example, given a processor that processes
    images into `pixel_values` and `image_grid_thw`,
    and audios into `input_audio_features`,
    a prompt with 2 images and 1 audio will be processed
    into a `MultiModalKwargsItems` with the following structure:

    ```python
    MultiModalKwargsItems(
        {
            "image": [
                # For the first image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
                # For the second image
                MultiModalKwargsItem({"pixel_values": ..., "image_grid_thw": ...}),
            ],
            "audio": [
                # For the first audio
                MultiModalKwargsItem({"input_audio_features": ...}),
            ],
        }
    )
    ```

    Unlike HF processing which returns all items
    in a single dictionary with batched keyword arguments,
    we split up the items because some of them may already be cached.
    Also, items from multiple requests may be batched together to improve throughput,
    using the logic defined by the
    [`BaseMultiModalField`][vllm.multimodal.inputs.BaseMultiModalField]
    for each keyword argument.
    """

    @staticmethod
    def from_hf_inputs(
        hf_inputs: "BatchFeature",
        config_by_key: Mapping[str, MultiModalFieldConfig],
    ):
        # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
        # We assume that those fields are not used in vLLM
        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
        keys_by_modality = defaultdict[str, set[str]](set)
        for key, config in config_by_key.items():
            batch = hf_inputs.get(key)
            if batch is not None:
                elems = config.build_elems(key, batch)
                if len(elems) > 0:
                    elems_by_key[key] = elems
                    keys_by_modality[config.modality].add(key)

        items_by_modality = dict[str, list[MultiModalKwargsItem]]()
        for modality, keys in keys_by_modality.items():
            elems_in_modality = {k: elems_by_key[k] for k in keys}
            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}

            if len(set(batch_sizes.values())) > 1:
                raise ValueError(
                    f"Cannot merge different batch sizes for {modality=}! "
                    f"Found: {batch_sizes=}"
                )

            batch_size = next(iter(batch_sizes.values()))
            items_by_modality[modality] = [
                MultiModalKwargsItem({k: v[i] for k, v in elems_in_modality.items()})
                for i in range(batch_size)
            ]

        return MultiModalKwargsItems(items_by_modality)

    def __getitem__(self, modality: str) -> Sequence[_I]:
        if modality not in self:
            raise KeyError(
                f"Modality {modality!r} not found. "
                f"Available modalities: {set(self.keys())}"
            )

        return super().__getitem__(modality)  # type: ignore[return-value]

    def require_data(self) -> "MultiModalKwargsItems[MultiModalKwargsItem]":
        for modality, items in self.items():
            for i, item in enumerate(items):
                if item is None:
                    raise RuntimeError(f"Found empty mm_items[{modality}][{i}]")

        return self  # type: ignore[return-value]

    def get_data(
        self,
        *,
        device: torch.types.Device = None,
        pin_memory: bool = False,
    ) -> BatchedTensorInputs:
        """Construct a dictionary of keyword arguments to pass to the model."""
        from .utils import group_and_batch_mm_items

        items_by_modality = self.require_data()
        batches_by_modality = {
            modality: [
                data
                for _, data in group_and_batch_mm_items(
                    items,
                    device=device,
                    pin_memory=pin_memory,
                )
            ]
            for modality, items in items_by_modality.items()
            if len(items) > 0
        }

        out_data: BatchedTensorInputs = {}
        for _, batches in batches_by_modality.items():
            if len(batches) != 1:
                num_batches_by_modality = {
                    modality: len(batches)
                    for modality, batches in batches_by_modality.items()
                }

                raise RuntimeError(
                    f"Some modalities cannot be merged into a single batch "
                    f"({num_batches_by_modality=})"
                )

            out_data.update(batches[0])

        return out_data

get_data

get_data(
    *, device: Device = None, pin_memory: bool = False
) -> BatchedTensorInputs

Construct a dictionary of keyword arguments to pass to the model.

Source code in vllm/multimodal/inputs.py
def get_data(
    self,
    *,
    device: torch.types.Device = None,
    pin_memory: bool = False,
) -> BatchedTensorInputs:
    """Construct a dictionary of keyword arguments to pass to the model."""
    from .utils import group_and_batch_mm_items

    items_by_modality = self.require_data()
    batches_by_modality = {
        modality: [
            data
            for _, data in group_and_batch_mm_items(
                items,
                device=device,
                pin_memory=pin_memory,
            )
        ]
        for modality, items in items_by_modality.items()
        if len(items) > 0
    }

    out_data: BatchedTensorInputs = {}
    for _, batches in batches_by_modality.items():
        if len(batches) != 1:
            num_batches_by_modality = {
                modality: len(batches)
                for modality, batches in batches_by_modality.items()
            }

            raise RuntimeError(
                f"Some modalities cannot be merged into a single batch "
                f"({num_batches_by_modality=})"
            )

        out_data.update(batches[0])

    return out_data

MultiModalRegistry

A registry that dispatches data processing according to the model.

Source code in vllm/multimodal/registry.py
class MultiModalRegistry:
    """
    A registry that dispatches data processing according to the model.
    """

    def _extract_mm_options(
        self,
        model_config: "ModelConfig",
    ) -> Mapping[str, BaseDummyOptions] | None:
        """
        Extract multimodal dummy options from model config.

        Returns None if no configurable options are found, otherwise returns
        a mapping of modality names to their dummy options.
        """
        if not model_config.multimodal_config:
            return None

        mm_options = {
            m: opt
            for m in model_config.multimodal_config.limit_per_prompt
            if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
        }

        return mm_options if len(mm_options) > 0 else None

    def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
        """
        Checks if the model supports multimodal inputs.
        Returns True if the model is multimodal with any non-zero supported
        modalities, otherwise returns False, effectively running in
        text-only mode.
        """
        if not model_config.is_multimodal_model:
            return False

        mm_config = model_config.get_multimodal_config()
        info = self._create_processing_info(model_config, tokenizer=None)

        # Check if all supported modalities have limit == 0
        if all(
            mm_config.get_limit_per_prompt(modality) == 0
            for modality in info.supported_mm_limits
        ):
            # If enable_mm_embeds is True, we still need MM infrastructure
            # to process pre-computed embeddings even though encoder won't run
            if mm_config.enable_mm_embeds:
                return True

            logger.info_once(
                "All limits of multimodal modalities supported by the model "
                "are set to 0, running in text-only mode."
            )
            return False

        return True

    def register_processor(
        self,
        processor: MultiModalProcessorFactory[_I],
        *,
        info: ProcessingInfoFactory[_I],
        dummy_inputs: DummyInputsBuilderFactory[_I],
    ):
        """
        Register a multi-modal processor to a model class. The processor
        is constructed lazily, hence a factory method should be passed.

        When the model receives multi-modal data, the provided function is
        invoked to transform the data into a dictionary of model inputs.
        """

        def wrapper(model_cls: N) -> N:
            if "_processor_factory" in model_cls.__dict__:
                logger.warning(
                    "Model class %s already has a multi-modal processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls,
                    self,
                )

            model_cls._processor_factory = _ProcessorFactories(
                info=info,
                dummy_inputs=dummy_inputs,
                processor=processor,
            )

            return model_cls

        return wrapper

    def _get_model_cls(self, model_config: "ModelConfig") -> "SupportsMultiModal":
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        assert hasattr(model_cls, "_processor_factory")
        return cast("SupportsMultiModal", model_cls)

    def _create_processing_ctx(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        tokenizer: TokenizerLike | None = None,
    ) -> InputProcessingContext:
        if tokenizer is None:
            tokenizer = cached_tokenizer_from_config(model_config)

        return InputProcessingContext(
            model_config, tokenizer, observability_config=observability_config
        )

    def _create_processing_info(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
    ) -> BaseProcessingInfo:
        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory
        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)
        return factories.info(ctx)

    def create_processor(
        self,
        model_config: "ModelConfig",
        observability_config: "ObservabilityConfig | None" = None,
        *,
        tokenizer: TokenizerLike | None = None,
        cache: BaseMultiModalProcessorCache | None = None,
    ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
        """
        Create a multi-modal processor for a specific model and tokenizer.
        """
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

        model_cls = self._get_model_cls(model_config)
        factories = model_cls._processor_factory

        ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)

        return factories.build_processor(ctx, cache=cache)

    def get_dummy_mm_inputs(
        self,
        model_config: "ModelConfig",
        mm_counts: Mapping[str, int],
        *,
        cache: BaseMultiModalProcessorCache | None = None,
        processor: BaseMultiModalProcessor | None = None,
    ) -> MultiModalInputs:
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by `model_config`.
        """
        seq_len = model_config.max_model_len

        if processor is None:
            processor = self.create_processor(model_config, cache=cache)

        mm_config = model_config.get_multimodal_config()
        processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
            seq_len=seq_len,
            mm_counts=mm_counts,
            mm_options=self._extract_mm_options(model_config),
            mm_processor_kwargs=mm_config.mm_processor_kwargs,
        )
        mm_inputs = processor.apply(
            prompt=processor_inputs.prompt,
            mm_items=processor_inputs.mm_items,
            hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
            tokenization_kwargs=processor_inputs.tokenization_kwargs,
        )

        prompt_token_ids = mm_inputs["prompt_token_ids"]
        total_len = len(prompt_token_ids)
        if total_len < seq_len:
            prompt_token_ids.extend([0] * (seq_len - total_len))

        return mm_inputs

    def _get_cache_type(
        self,
        vllm_config: "VllmConfig",
    ) -> Literal[None, "processor_only", "lru", "shm"]:
        model_config = vllm_config.model_config
        if not self.supports_multimodal_inputs(model_config):
            return None

        # Check if the cache is disabled.
        mm_config = model_config.get_multimodal_config()
        if mm_config.mm_processor_cache_gb <= 0:
            return None

        # Check if IPC caching is supported.
        parallel_config = vllm_config.parallel_config
        is_ipc_supported = parallel_config._api_process_count == 1 and (
            parallel_config.data_parallel_size == 1
            or parallel_config.data_parallel_external_lb
        )

        if not is_ipc_supported:
            return "processor_only"

        mm_config = model_config.get_multimodal_config()
        return mm_config.mm_processor_cache_type

    def processor_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalProcessorCache | None:
        """Return a `BaseMultiModalProcessorCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None
        elif cache_type == "processor_only":
            return MultiModalProcessorOnlyCache(vllm_config.model_config)
        elif cache_type == "lru":
            return MultiModalProcessorSenderCache(vllm_config.model_config)
        elif cache_type == "shm":
            return ShmObjectStoreSenderCache(vllm_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def processor_only_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> MultiModalProcessorOnlyCache | None:
        """Return a `MultiModalProcessorOnlyCache`, if enabled."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type is None:
            return None

        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    def engine_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the engine process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "shm"):
            return None
        elif cache_type == "lru":
            return MultiModalReceiverCache(vllm_config.model_config)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

    def worker_receiver_cache_from_config(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> BaseMultiModalReceiverCache | None:
        """Return a `BaseMultiModalReceiverCache` for the worker process."""
        cache_type = self._get_cache_type(vllm_config)
        if cache_type in (None, "processor_only", "lru"):
            return None
        elif cache_type == "shm":
            return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
        else:
            raise ValueError(f"Unknown cache type: {cache_type!r}")

_extract_mm_options

_extract_mm_options(
    model_config: ModelConfig,
) -> Mapping[str, BaseDummyOptions] | None

Extract multimodal dummy options from model config.

Returns None if no configurable options are found, otherwise returns a mapping of modality names to their dummy options.

Source code in vllm/multimodal/registry.py
def _extract_mm_options(
    self,
    model_config: "ModelConfig",
) -> Mapping[str, BaseDummyOptions] | None:
    """
    Extract multimodal dummy options from model config.

    Returns None if no configurable options are found, otherwise returns
    a mapping of modality names to their dummy options.
    """
    if not model_config.multimodal_config:
        return None

    mm_options = {
        m: opt
        for m in model_config.multimodal_config.limit_per_prompt
        if (opt := model_config.multimodal_config.get_dummy_options(m)) is not None
    }

    return mm_options if len(mm_options) > 0 else None

create_processor

create_processor(
    model_config: ModelConfig,
    observability_config: ObservabilityConfig | None = None,
    *,
    tokenizer: TokenizerLike | None = None,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]

Create a multi-modal processor for a specific model and tokenizer.

Source code in vllm/multimodal/registry.py
def create_processor(
    self,
    model_config: "ModelConfig",
    observability_config: "ObservabilityConfig | None" = None,
    *,
    tokenizer: TokenizerLike | None = None,
    cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
    """
    Create a multi-modal processor for a specific model and tokenizer.
    """
    if not model_config.is_multimodal_model:
        raise ValueError(f"{model_config.model} is not a multimodal model")

    model_cls = self._get_model_cls(model_config)
    factories = model_cls._processor_factory

    ctx = self._create_processing_ctx(model_config, observability_config, tokenizer)

    return factories.build_processor(ctx, cache=cache)
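
A hedged sketch of building a processor directly from the global registry (the model name is illustrative, and ModelConfig is assumed to accept defaults for its remaining fields):

```python
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY

model_config = ModelConfig(model="llava-hf/llava-1.5-7b-hf")
processor = MULTIMODAL_REGISTRY.create_processor(model_config)
```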

engine_receiver_cache_from_config

engine_receiver_cache_from_config(
    vllm_config: VllmConfig,
) -> BaseMultiModalReceiverCache | None

Return a BaseMultiModalReceiverCache for the engine process.

Source code in vllm/multimodal/registry.py
def engine_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalReceiverCache | None:
    """Return a `BaseMultiModalReceiverCache` for the engine process."""
    cache_type = self._get_cache_type(vllm_config)
    if cache_type in (None, "processor_only", "shm"):
        return None
    elif cache_type == "lru":
        return MultiModalReceiverCache(vllm_config.model_config)
    else:
        raise ValueError(f"Unknown cache type: {cache_type!r}")

get_dummy_mm_inputs

get_dummy_mm_inputs(
    model_config: ModelConfig,
    mm_counts: Mapping[str, int],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
    processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs

Create dummy data for profiling the memory usage of a model.

The model is identified by model_config.

Source code in vllm/multimodal/registry.py
def get_dummy_mm_inputs(
    self,
    model_config: "ModelConfig",
    mm_counts: Mapping[str, int],
    *,
    cache: BaseMultiModalProcessorCache | None = None,
    processor: BaseMultiModalProcessor | None = None,
) -> MultiModalInputs:
    """
    Create dummy data for profiling the memory usage of a model.

    The model is identified by `model_config`.
    """
    seq_len = model_config.max_model_len

    if processor is None:
        processor = self.create_processor(model_config, cache=cache)

    mm_config = model_config.get_multimodal_config()
    processor_inputs = processor.dummy_inputs.get_dummy_processor_inputs(
        seq_len=seq_len,
        mm_counts=mm_counts,
        mm_options=self._extract_mm_options(model_config),
        mm_processor_kwargs=mm_config.mm_processor_kwargs,
    )
    mm_inputs = processor.apply(
        prompt=processor_inputs.prompt,
        mm_items=processor_inputs.mm_items,
        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        tokenization_kwargs=processor_inputs.tokenization_kwargs,
    )

    prompt_token_ids = mm_inputs["prompt_token_ids"]
    total_len = len(prompt_token_ids)
    if total_len < seq_len:
        prompt_token_ids.extend([0] * (seq_len - total_len))

    return mm_inputs
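
A hedged sketch of generating profiling inputs (again, the model name is illustrative and ModelConfig is assumed to accept defaults for its remaining fields):

```python
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY

model_config = ModelConfig(model="llava-hf/llava-1.5-7b-hf")

# Request one dummy image per prompt for memory profiling.
dummy_inputs = MULTIMODAL_REGISTRY.get_dummy_mm_inputs(
    model_config,
    mm_counts={"image": 1},
)

# The prompt is padded with token 0 up to max_model_len.
assert len(dummy_inputs["prompt_token_ids"]) == model_config.max_model_len
```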

processor_cache_from_config

processor_cache_from_config(
    vllm_config: VllmConfig,
) -> BaseMultiModalProcessorCache | None

Return a BaseMultiModalProcessorCache, if enabled.

Source code in vllm/multimodal/registry.py
def processor_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalProcessorCache | None:
    """Return a `BaseMultiModalProcessorCache`, if enabled."""
    cache_type = self._get_cache_type(vllm_config)
    if cache_type is None:
        return None
    elif cache_type == "processor_only":
        return MultiModalProcessorOnlyCache(vllm_config.model_config)
    elif cache_type == "lru":
        return MultiModalProcessorSenderCache(vllm_config.model_config)
    elif cache_type == "shm":
        return ShmObjectStoreSenderCache(vllm_config)
    else:
        raise ValueError(f"Unknown cache type: {cache_type!r}")

processor_only_cache_from_config

processor_only_cache_from_config(
    vllm_config: VllmConfig,
) -> MultiModalProcessorOnlyCache | None

Return a MultiModalProcessorOnlyCache, if enabled.

Source code in vllm/multimodal/registry.py
def processor_only_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> MultiModalProcessorOnlyCache | None:
    """Return a `MultiModalProcessorOnlyCache`, if enabled."""
    cache_type = self._get_cache_type(vllm_config)
    if cache_type is None:
        return None

    return MultiModalProcessorOnlyCache(vllm_config.model_config)

register_processor

register_processor(
    processor: MultiModalProcessorFactory[_I],
    *,
    info: ProcessingInfoFactory[_I],
    dummy_inputs: DummyInputsBuilderFactory[_I],
)

Register a multi-modal processor to a model class. The processor is constructed lazily, hence a factory method should be passed.

When the model receives multi-modal data, the provided function is invoked to transform the data into a dictionary of model inputs.
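
A hedged skeleton of the registration pattern. MyMultiModalProcessor, MyProcessingInfo, MyDummyInputsBuilder, and the model class are placeholders for a model implementation's own subclasses, and their bodies are elided, so this is a shape sketch rather than runnable code:

```python
from torch import nn

from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_processor(
    MyMultiModalProcessor,              # placeholder processor factory
    info=MyProcessingInfo,              # placeholder processing-info factory
    dummy_inputs=MyDummyInputsBuilder,  # placeholder dummy-inputs factory
)
class MyModelForConditionalGeneration(nn.Module):  # typically also SupportsMultiModal
    ...
```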

Source code in vllm/multimodal/registry.py
def register_processor(
    self,
    processor: MultiModalProcessorFactory[_I],
    *,
    info: ProcessingInfoFactory[_I],
    dummy_inputs: DummyInputsBuilderFactory[_I],
):
    """
    Register a multi-modal processor to a model class. The processor
    is constructed lazily, hence a factory method should be passed.

    When the model receives multi-modal data, the provided function is
    invoked to transform the data into a dictionary of model inputs.
    """

    def wrapper(model_cls: N) -> N:
        if "_processor_factory" in model_cls.__dict__:
            logger.warning(
                "Model class %s already has a multi-modal processor "
                "registered to %s. It is overwritten by the new one.",
                model_cls,
                self,
            )

        model_cls._processor_factory = _ProcessorFactories(
            info=info,
            dummy_inputs=dummy_inputs,
            processor=processor,
        )

        return model_cls

    return wrapper

supports_multimodal_inputs

supports_multimodal_inputs(
    model_config: ModelConfig,
) -> bool

Checks whether the model supports multimodal inputs. Returns True if the model is multimodal and at least one supported modality has a non-zero limit; otherwise returns False, and the model effectively runs in text-only mode.

Source code in vllm/multimodal/registry.py
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
    """
    Checks if the model supports multimodal inputs.
    Returns True if the model is multimodal with any non-zero supported
    modalities, otherwise returns False, effectively running in
    text-only mode.
    """
    if not model_config.is_multimodal_model:
        return False

    mm_config = model_config.get_multimodal_config()
    info = self._create_processing_info(model_config, tokenizer=None)

    # Check if all supported modalities have limit == 0
    if all(
        mm_config.get_limit_per_prompt(modality) == 0
        for modality in info.supported_mm_limits
    ):
        # If enable_mm_embeds is True, we still need MM infrastructure
        # to process pre-computed embeddings even though encoder won't run
        if mm_config.enable_mm_embeds:
            return True

        logger.info_once(
            "All limits of multimodal modalities supported by the model "
            "are set to 0, running in text-only mode."
        )
        return False

    return True

worker_receiver_cache_from_config

worker_receiver_cache_from_config(
    vllm_config: VllmConfig, shared_worker_lock: Lock
) -> BaseMultiModalReceiverCache | None

Return a BaseMultiModalReceiverCache for the worker process.

Source code in vllm/multimodal/registry.py
def worker_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """Return a `BaseMultiModalReceiverCache` for the worker process."""
    cache_type = self._get_cache_type(vllm_config)
    if cache_type in (None, "processor_only", "lru"):
        return None
    elif cache_type == "shm":
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
    else:
        raise ValueError(f"Unknown cache type: {cache_type!r}")