vllm.model_executor.model_loader.weight_utils

Utilities for downloading and initializing model weights.

_natural_sort_key

_natural_sort_key(filepath: str) -> list

Natural sort key for filenames with numeric components, such as model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']

Source code in vllm/model_executor/model_loader/weight_utils.py
def _natural_sort_key(filepath: str) -> list:
    """Natural sort key for filenames with numeric components, such as
    model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']"""
    return [
        int(s) if s.isdigit() else s
        for s in re.split(r"(\d+)", os.path.basename(filepath))
    ]
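
For illustration, a minimal usage sketch (the shard names and the sorted calls below are hypothetical, not code from this module):

from vllm.model_executor.model_loader.weight_utils import _natural_sort_key

shards = ["shard-10.safetensors", "shard-2.safetensors"]
print(sorted(shards))                         # ['shard-10...', 'shard-2...'] (lexicographic)
print(sorted(shards, key=_natural_sort_key))  # ['shard-2...', 'shard-10...'] (numeric)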

atomic_writer

atomic_writer(
    filepath: str | Path,
    mode: str = "w",
    encoding: str | None = None,
) -> Generator[IO]

Context manager that provides an atomic file writing routine.

The context manager writes to a temporary file and, if successful, atomically replaces the original file.

Parameters:

filepath (str | Path): The path to the file to write. Required.
mode (str): The file mode for the temporary file (e.g., 'w', 'wb'). Default: 'w'.
encoding (str | None): The encoding for text mode. Default: None.

Yields:

IO: A file object handle to the temporary file.

Source code in vllm/model_executor/model_loader/weight_utils.py
@contextmanager
def atomic_writer(
    filepath: str | Path, mode: str = "w", encoding: str | None = None
) -> Generator[IO]:
    """
    Context manager that provides an atomic file writing routine.

    The context manager writes to a temporary file and, if successful,
    atomically replaces the original file.

    Args:
        filepath (str or Path): The path to the file to write.
        mode (str): The file mode for the temporary file (e.g., 'w', 'wb').
        encoding (str): The encoding for text mode.

    Yields:
        file object: A handle to the temporary file.
    """
    # Create a temporary file in the same directory as the target file
    # to ensure it's on the same filesystem for an atomic replace.
    temp_dir = os.path.dirname(filepath)
    temp_fd, temp_path = tempfile.mkstemp(dir=temp_dir)

    try:
        # Open the temporary file for writing
        with os.fdopen(temp_fd, mode=mode, encoding=encoding) as temp_file:
            yield temp_file

        # If the 'with' block completes successfully,
        # perform the atomic replace.
        os.replace(temp_path, filepath)

    except Exception:
        logger.exception(
            "Error during atomic write. Original file '%s' not modified", filepath
        )
        raise
    finally:
        # Clean up the temporary file if it still exists.
        if os.path.exists(temp_path):
            os.remove(temp_path)
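
A minimal usage sketch (the file path and payload are hypothetical):

import json

from vllm.model_executor.model_loader.weight_utils import atomic_writer

# Atomically (re)write a small JSON file: if the block raises, the original
# file on disk is left unmodified.
with atomic_writer("/tmp/weight_names.json", mode="w", encoding="utf-8") as f:
    json.dump(["layer.0.weight", "layer.0.bias"], f)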

composed_weight_loader

composed_weight_loader(
    loader: LoaderFunction, fn: Callable[[Tensor], Tensor]
) -> LoaderFunction

Create a weight loader that post-processes the weights after loading

Source code in vllm/model_executor/model_loader/weight_utils.py
def composed_weight_loader(
    loader: LoaderFunction, fn: Callable[[torch.Tensor], torch.Tensor]
) -> LoaderFunction:
    """Create a weight loader that post-processes the weights after loading"""

    def composed_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        loader(param, loaded_weight)
        param.data.copy_(fn(param))
        return

    return composed_loader
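
A sketch of composing the default loader with a post-processing step (the clamp used here is an illustrative choice, not something this module prescribes):

import torch

from vllm.model_executor.model_loader.weight_utils import (
    composed_weight_loader,
    default_weight_loader,
)

# Load with the default copy semantics, then write the clamped values back
# into the parameter.
loader = composed_weight_loader(default_weight_loader, lambda t: t.clamp(-1.0, 1.0))
param = torch.nn.Parameter(torch.empty(4, 4))
loader(param, torch.randn(4, 4) * 3)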

convert_pyslice_to_tensor

convert_pyslice_to_tensor(x: Any) -> Tensor

Convert a PySafeSlice object from safetensors to a torch.Tensor.

A PySafeSlice object supports indexing, which is applied before the actual tensor is loaded and can therefore reduce the amount of data read into memory. However, it does not support more advanced operations such as .view() or .t(). Therefore, if we need to modify the loaded tensor with these more complicated operators, we need to convert it to a tensor first.

Source code in vllm/model_executor/model_loader/weight_utils.py
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """convert PySafeSlice object from safetensors to torch.Tensor

    PySafeSlice object supports indexing, which is done before loading the
    actual tensor and can reduce the amount of memory being read into the
    memory. However, it does not support more advanced functionalities
    like `.view()` or `.t()`. Therefore, if we need to modify the loaded
    tensor with these more complicated operators, we need to convert to
    tensor first.
    """
    if not isinstance(x, torch.Tensor):
        x = x[:]
    return x
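
A sketch of the intended pattern with safetensors' lazy slices (the shard file and tensor name are hypothetical):

from safetensors import safe_open

from vllm.model_executor.model_loader.weight_utils import convert_pyslice_to_tensor

with safe_open("model-00001-of-00005.safetensors", framework="pt") as f:
    lazy = f.get_slice("lm_head.weight")        # PySafeSlice; nothing read yet
    # PySafeSlice supports indexing but not .t()/.view(), so materialize first.
    weight = convert_pyslice_to_tensor(lazy)
    transposed = weight.t()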

default_weight_loader

default_weight_loader(
    param: Tensor, loaded_weight: Tensor
) -> None

Default weight loader.

Source code in vllm/model_executor/model_loader/weight_utils.py
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})"
            )

            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise
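
A small sketch of the two code paths, the exact-shape copy and the scalar broadcast (the shapes are arbitrary):

import torch

from vllm.model_executor.model_loader.weight_utils import default_weight_loader

param = torch.nn.Parameter(torch.empty(2, 3))
default_weight_loader(param, torch.ones(2, 3))     # shapes match: plain copy

scale = torch.nn.Parameter(torch.empty(()))
default_weight_loader(scale, torch.tensor([0.5]))  # both scalars: fill_() broadcast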

download_safetensors_index_file_from_hf

download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: str | None,
    revision: str | None = None,
) -> None

Download the safetensors index file from the Hugging Face Hub.

Parameters:

model_name_or_path (str): The model name or path. Required.
index_file (str): The safetensors index file name. Required.
cache_dir (str | None): The cache directory to store the model weights. If None, will use HF defaults. Required.
revision (str | None): The revision of the model. Default: None.

Source code in vllm/model_executor/model_loader/weight_utils.py
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: str | None,
    revision: str | None = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        index_file (str): The safetensors index file name
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=index_file,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have index_file.
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)

download_weights_from_hf

download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: str | None,
    allow_patterns: list[str],
    revision: str | None = None,
    ignore_patterns: str | list[str] | None = None,
) -> str

Download model weights from Hugging Face Hub.

Parameters:

model_name_or_path (str): The model name or path. Required.
cache_dir (str | None): The cache directory to store the model weights. If None, will use HF defaults. Required.
allow_patterns (list[str]): The allowed patterns for the weight files. Files matched by any of the patterns will be downloaded. Required.
revision (str | None): The revision of the model. Default: None.
ignore_patterns (str | list[str] | None): The patterns to filter out the weight files. Files matched by any of the patterns will be ignored. Default: None.

Returns:

str: The path to the downloaded model weights.

Source code in vllm/model_executor/model_loader/weight_utils.py
@instrument(span_name="Download weights - HF")
def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: str | None,
    allow_patterns: list[str],
    revision: str | None = None,
    ignore_patterns: str | list[str] | None = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (list[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights.
    """
    assert len(allow_patterns) > 0
    local_only = huggingface_hub.constants.HF_HUB_OFFLINE
    if not local_only:
        # Attempt to reduce allow_patterns to a single pattern
        # so we only have to call snapshot_download once.
        try:
            fs = HfFileSystem()
            file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

            # If downloading safetensors and an index file exists, use the
            # specific file names from the index to avoid downloading
            # unnecessary files (e.g., from subdirectories like "original/").
            index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
            if "*.safetensors" in allow_patterns and index_file in file_list:
                index_path = hf_hub_download(
                    repo_id=model_name_or_path,
                    filename=SAFE_WEIGHTS_INDEX_NAME,
                    cache_dir=cache_dir,
                    revision=revision,
                )
                with open(index_path) as f:
                    weight_map = json.load(f)["weight_map"]
                if weight_map:
                    # Extra [] so that weight_map files are treated as a
                    # single allow_pattern in the loop below
                    allow_patterns = [list(set(weight_map.values()))]  # type: ignore[list-item]
                else:
                    allow_patterns = ["*.safetensors"]
            else:
                # Use the first pattern found in the HF repo's files.
                for pattern in allow_patterns:
                    if fnmatch.filter(file_list, pattern):
                        allow_patterns = [pattern]
                        break
        except Exception as e:
            logger.warning(
                "Failed to get file list for '%s'. Trying each pattern in "
                "allow_patterns individually until weights have been "
                "downloaded. Error: %s",
                model_name_or_path,
                e,
            )

    logger.debug("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        start_time = time.perf_counter()
        for allow_pattern in allow_patterns:
            hf_folder = snapshot_download(
                model_name_or_path,
                allow_patterns=allow_pattern,
                ignore_patterns=ignore_patterns,
                cache_dir=cache_dir,
                tqdm_class=DisabledTqdm,
                revision=revision,
                local_files_only=local_only,
            )
            # If we have downloaded weights for this allow_pattern,
            # we don't need to check the rest.
            # allow_pattern can be a list (from weight_map) or str (glob)
            if isinstance(allow_pattern, list):
                break
            if any(Path(hf_folder).glob(allow_pattern)):
                break
        time_taken = time.perf_counter() - start_time
        if time_taken > 0.5:
            logger.info(
                "Time spent downloading weights for %s: %.6f seconds",
                model_name_or_path,
                time_taken,
            )
    return hf_folder
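
A usage sketch (the repo id and patterns are hypothetical; HF defaults are used for the cache):

from vllm.model_executor.model_loader.weight_utils import download_weights_from_hf

# Prefer safetensors shards and fall back to .bin shards if none match.
hf_folder = download_weights_from_hf(
    "org/some-model",
    cache_dir=None,
    allow_patterns=["*.safetensors", "*.bin"],
    ignore_patterns=["original/**/*"],
)
print(hf_folder)  # local snapshot directory containing the downloaded shards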

enable_hf_transfer

enable_hf_transfer()

Automatically activates hf_transfer if the package is available and HF_HUB_ENABLE_HF_TRANSFER has not been set explicitly.

Source code in vllm/model_executor/model_loader/weight_utils.py
def enable_hf_transfer():
    """automatically activates hf_transfer"""
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa

            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass

fastsafetensors_weights_iterator

fastsafetensors_weights_iterator(
    hf_weights_files: list[str], use_tqdm_on_load: bool
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensors files using the fastsafetensors library.

Source code in vllm/model_executor/model_loader/weight_utils.py
def fastsafetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files
    using fastsafetensor library."""
    if torch.distributed.is_initialized():
        pg = torch.distributed.group.WORLD
    else:
        pg = SingleGroup()

    device = torch.device(f"cuda:{current_platform.current_device()}")
    hf_weights_files = sorted(hf_weights_files, key=_natural_sort_key)
    weight_files_sub_lists = [
        hf_weights_files[i : i + pg.size()]
        for i in range(0, len(hf_weights_files), pg.size())
    ]

    # Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which
    # initializes the GDS DMA subsystem for all visible GPUs, creating
    # unwanted CUDA contexts on every device.
    nogds = pg.size() > 1

    for f_list in tqdm(
        weight_files_sub_lists,
        desc="Loading safetensors using Fastsafetensor loader",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    ):
        loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
        try:
            try:
                fb = loader.copy_files_to_device()
            except RuntimeError as e:
                if "gds" not in str(e):
                    raise

                loader.close()
                nogds = True
                logger.warning_once(
                    "GDS not enabled, setting `nogds=True`.\n"
                    "For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
                )
                loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
                fb = loader.copy_files_to_device()

            try:
                keys = list(fb.key_to_rank_lidx.keys())
                for k in keys:
                    t = fb.get_tensor(k)
                    yield k, t
            finally:
                fb.close()
        finally:
            loader.close()

filter_files_not_needed_for_inference

filter_files_not_needed_for_inference(
    hf_weights_files: list[str],
) -> list[str]

Exclude files that are not needed for inference.

See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233

Source code in vllm/model_executor/model_loader/weight_utils.py
def filter_files_not_needed_for_inference(hf_weights_files: list[str]) -> list[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files

get_gguf_weight_type_map

get_gguf_weight_type_map(
    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> dict[str, str]

Return a mapping from each GGUF tensor's HF-mapped weight name to its quantization type.

Source code in vllm/model_executor/model_loader/weight_utils.py
def get_gguf_weight_type_map(
    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> dict[str, str]:
    """
    Return GGUF mapped weight's name and its quant type
    """
    reader = gguf.GGUFReader(gguf_file)
    return {
        gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name
        for tensor in reader.tensors
        if tensor.name in gguf_to_hf_name_map
    }

gguf_quant_weights_iterator

gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the quantized weights in the model GGUF file and convert them to torch tensors. Note the yield order: all weight types must be yielded before any weight data; otherwise, loading fails for packed layers whose sub-weights use different quant types.

Source code in vllm/model_executor/model_loader/weight_utils.py
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """
    Iterate over the quant weights in the model gguf files and convert
    them to torch tensors.
    Be careful of the order of yielding weight types and weights data,
    we have to yield all weight types first before yielding any weights.
    Otherwise it would cause issue when loading weights with for packed
    layer with different quant types.
    """

    reader = gguf.GGUFReader(gguf_file)

    for tensor in reader.tensors:
        if tensor.name in gguf_to_hf_name_map:
            weight_type = tensor.tensor_type
            name = gguf_to_hf_name_map[tensor.name]

            if weight_type.name not in ("F32", "BF16", "F16"):
                weight_type_name = name.replace("weight", "qweight_type")
                weight_type = torch.tensor(weight_type)
                yield weight_type_name, weight_type

    for tensor in reader.tensors:
        if tensor.name in gguf_to_hf_name_map:
            weight = tensor.data
            weight_type = tensor.tensor_type
            name = gguf_to_hf_name_map[tensor.name]
            if weight_type.name not in ("F32", "BF16", "F16"):
                name = name.replace("weight", "qweight")
            if weight_type.name == "BF16" and tensor.data.dtype == np.uint8:
                # BF16 is currently the only "quantization" type that isn't
                # actually quantized but is read as a raw byte tensor.
                # Reinterpret as `torch.bfloat16` tensor.
                weight = weight.view(np.uint16)
                if reader.byte_order == "S":
                    # GGUF endianness != system endianness
                    weight = weight.byteswap()
                param = torch.tensor(weight).view(torch.bfloat16)
            else:
                param = torch.tensor(weight)
            yield name, param
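
A consumption sketch (the GGUF path and the name map are hypothetical); note that for quantized tensors a "...qweight_type" entry is yielded before the corresponding "...qweight" data:

from vllm.model_executor.model_loader.weight_utils import gguf_quant_weights_iterator

gguf_to_hf = {"blk.0.attn_q.weight": "model.layers.0.self_attn.q_proj.weight"}
for name, tensor in gguf_quant_weights_iterator("model.Q4_K_M.gguf", gguf_to_hf):
    print(name, tuple(tensor.shape))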

initialize_dummy_weights

initialize_dummy_weights(
    model: Module,
    model_config: ModelConfig,
    low: float = -0.001,
    high: float = 0.001,
    seed: int = 1234,
) -> None

Initialize model weights with random values.

The model weights must be randomly initialized for accurate performance measurements. Additionally, the model weights should not cause NaNs in the forward pass. We empirically found that initializing the weights with values between -1e-3 and 1e-3 works well for most models.

We use a per-parameter random seed so that dummy weights are consistent even if the model is partitioned across multiple devices. When the seed is fixed, the random values generated by this function depend only on the parameter's number of elements and its data type.

Source code in vllm/model_executor/model_loader/weight_utils.py
def initialize_dummy_weights(
    model: torch.nn.Module,
    model_config: ModelConfig,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    """Initialize model weights with random values.

    The model weights must be randomly initialized for accurate performance
    measurements. Additionally, the model weights should not cause NaNs in the
    forward pass. We empirically found that initializing the weights with
    values between -1e-3 and 1e-3 works well for most models.

    We use per-parameter random seed, so that dummy weights are consistent,
    even if the model is partitioned across multiple devices. When the seed
    is fixed, the random values generated by this function only depends on
    the parameter's number of elements and its data type.
    """
    # TODO(future PR): make the check below more generic as more online
    # quant backends are added
    is_fp8_py_quant = model_config.quantization == "fp8"

    for param in model.state_dict().values():
        if is_fp8_py_quant and param.device == torch.device("meta"):
            # for fp8.py's online quantization, dummy weight init will happen
            # in `process_weights_after_loading`.
            # TODO(future PR): consider refactoring dummy model init to compose
            # better with online quantization
            continue

        initialize_single_dummy_weight(param, low, high, seed)
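
The per-parameter seeding described above is handled by initialize_single_dummy_weight; the helper below is only a plausible reconstruction of that idea for illustration, not the actual implementation:

import torch

def dummy_fill_sketch(param: torch.Tensor, low: float, high: float, seed: int) -> None:
    # Illustrative only: derive a deterministic per-parameter seed from the
    # global seed and the parameter's element count, so identical shards on
    # different devices receive identical random values.
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed + param.numel())
    values = torch.empty(param.shape, dtype=torch.float32)
    values.uniform_(low, high, generator=gen)
    param.data.copy_(values.to(param.dtype))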

maybe_download_from_modelscope

maybe_download_from_modelscope(
    model: str,
    revision: str | None = None,
    download_dir: str | None = None,
    ignore_patterns: str | list[str] | None = None,
    allow_patterns: list[str] | str | None = None,
) -> str | None

Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

Returns the path to the downloaded model, or None if the model is not downloaded from ModelScope.

Source code in vllm/model_executor/model_loader/weight_utils.py
def maybe_download_from_modelscope(
    model: str,
    revision: str | None = None,
    download_dir: str | None = None,
    ignore_patterns: str | list[str] | None = None,
    allow_patterns: list[str] | str | None = None,
) -> str | None:
    """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

    Returns the path to the downloaded model, or None if the model is not
    downloaded from ModelScope."""
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model, download_dir):
            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=ignore_patterns,
                    allow_patterns=allow_patterns,
                )
            else:
                model_path = model
        return model_path
    return None

maybe_remap_kv_scale_name

maybe_remap_kv_scale_name(
    name: str, params_dict: dict
) -> str | None

Remap the name of FP8 k/v_scale parameters.

This function handles the remapping of FP8 k/v_scale parameter names. It detects whether the given name ends with a known scale suffix and attempts to remap it to the name format expected by the model. If the remapped name is not found in params_dict, a warning is logged and None is returned.

Parameters:

name (str): The original loaded checkpoint parameter name. Required.
params_dict (dict): Dictionary containing the model's named parameters. Required.

Returns:

str: The remapped parameter name if remapping succeeds, or the original name if no remapping is needed.
None: If the remapped name is not found in params_dict.

Source code in vllm/model_executor/model_loader/weight_utils.py
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
             if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale"
        )
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  #  noqa: E501
                name,
                remapped_name,
            )
            return None
        return remapped_name

    if any("mla_attn" in key for key in params_dict):
        attn_str = "mla_attn.mla_attn"
        logger.debug_once(
            f"Found mla_attn with k_scale and v_scale in "
            f"the checkpoint, using {attn_str} as attn_str"
        )
    else:
        attn_str = "attn"
    # Define scale name mapping patterns in order of precedence
    scale_mapping_patterns = [
        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
        # .self_attn.attn.{k,v}_scale
        (
            r"\.self_attn\.([kv])_proj\.([kv])_scale$",
            rf".self_attn.{attn_str}.\2_scale",
        ),
        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
        # .self_attn.attn.{k,v}_scale
        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
        # .self_attn.attn.{k,v}_scale
        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
        # NemotronH format: .mixer.{k,v}_proj.{k,v}_scale ->
        # .mixer.attn.{k,v}_scale
        (r"\.mixer\.[kv]_proj\.([kv])_scale$", r".mixer.attn.\1_scale"),
        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
        (r"\.([qkv])_scale$", r".attn.\1_scale"),
        (r"\.([qkv])_zero_point$", r".attn.\1_zero_point"),
    ]

    # Check if name ends with k_scale or v_scale
    if name.endswith(
        (
            ".k_scale",
            ".v_scale",
            ".q_scale",
            ".k_zero_point",
            ".v_zero_point",
            ".q_zero_point",
        )
    ):
        import regex as re

        for pattern, replacement in scale_mapping_patterns:
            if re.search(pattern, name):
                remapped_name = re.sub(pattern, replacement, name)
                if remapped_name not in params_dict:
                    scale_type = name.split(".")[-1]
                    logger.warning_once(
                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
                        scale_type,
                        name,
                        remapped_name,
                        scale_type,
                    )
                    return None
                return remapped_name

    # If there were no matches, return the untouched param name
    return name
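
A small sketch of the remapping behavior (the parameter names below are hypothetical):

from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name

params = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

# Deprecated kv_scale is remapped onto the model's k_scale entry.
maybe_remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params)
# -> 'model.layers.0.self_attn.attn.k_scale'

# Names that need no remapping are returned unchanged.
maybe_remap_kv_scale_name("model.layers.0.mlp.down_proj.weight", params)
# -> 'model.layers.0.mlp.down_proj.weight'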

multi_thread_pt_weights_iterator

multi_thread_pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: str | dict[str, str] = "cpu",
    max_workers: int = 4,
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files using multiple threads.

Source code in vllm/model_executor/model_loader/weight_utils.py
def multi_thread_pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: str | dict[str, str] = "cpu",
    max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Multi-Thread iterate over the weights in the model bin/pt files."""

    def _load_file(bin_file: str):
        return torch.load(
            bin_file, map_location=pt_load_map_location, weights_only=True
        )

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        ]
        futures_iter = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(hf_weights_files),
            desc="Multi-thread loading pt checkpoint shards",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state = future.result()
            yield from state.items()
            del state

multi_thread_safetensors_weights_iterator

multi_thread_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    max_workers: int = 4,
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensors files using multiple threads.

Source code in vllm/model_executor/model_loader/weight_utils.py
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Multi-Thread iterate over the weights in the model safetensor files."""

    def _load_file(st_file: str):
        result = load_file(st_file, device="cpu")
        return result

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
        futures_iter = tqdm(
            concurrent.futures.as_completed(futures),
            total=len(hf_weights_files),
            desc="Multi-thread loading shards",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )

        for future in futures_iter:
            state_dict = future.result()
            yield from state_dict.items()

np_cache_weights_iterator

np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: str | None,
    hf_folder: str,
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model np files.

Will dump the model weights to numpy files if they are not already dumped.

Source code in vllm/model_executor/model_loader/weight_utils.py
def np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: str | None,
    hf_folder: str,
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model np files.

    Will dump the model weights to numpy files if they are not already dumped.
    """
    # Convert the model weights from torch tensors to numpy arrays for
    # faster loading.
    np_folder = os.path.join(hf_folder, "np")
    os.makedirs(np_folder, exist_ok=True)
    weight_names_file = os.path.join(np_folder, "weight_names.json")
    # Use file lock to prevent multiple processes from
    # dumping the same model weights to numpy at the same time.
    with get_lock(model_name_or_path, cache_dir):
        if not os.path.exists(weight_names_file):
            weight_names: list[str] = []
            for bin_file in tqdm(
                hf_weights_files,
                desc="Loading np_cache checkpoint shards",
                disable=not enable_tqdm(use_tqdm_on_load),
                bar_format=_BAR_FORMAT,
            ):
                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                for name, param in state.items():
                    param_path = os.path.join(np_folder, name)
                    with open(param_path, "wb") as f:
                        np.save(f, param.cpu().detach().numpy())
                    weight_names.append(name)
            with open(weight_names_file, "w") as f:
                json.dump(weight_names, f)

    with open(weight_names_file) as f:
        weight_names = json.load(f)

    for name in weight_names:
        param_path = os.path.join(np_folder, name)
        with open(param_path, "rb") as f:
            param = np.load(f)
        yield name, torch.from_numpy(param)

pt_weights_iterator

pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: str | dict[str, str] = "cpu",
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model bin/pt files.

Source code in vllm/model_executor/model_loader/weight_utils.py
def pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: str | dict[str, str] = "cpu",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    for bin_file in tqdm(
        hf_weights_files,
        desc="Loading pt checkpoint shards",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    ):
        state = torch.load(
            bin_file, map_location=pt_load_map_location, weights_only=True
        )
        yield from state.items()
        del state

row_parallel_weight_loader

row_parallel_weight_loader(
    param: Tensor, loaded_weight: Tensor
) -> None

Load weights that are row-parallelized.

Source code in vllm/model_executor/model_loader/weight_utils.py
def row_parallel_weight_loader(
    param: torch.Tensor, loaded_weight: torch.Tensor
) -> None:
    """Load weights that are row-parallelized."""
    tp_rank = get_tensor_model_parallel_rank()
    shard_dim = 0 if param.dim() != 1 else None

    if shard_dim is not None:
        shard_size = param.data.shape[shard_dim]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)

    return default_weight_loader(param, loaded_weight)

runai_safetensors_weights_iterator

runai_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    is_distributed: bool = False,
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files.

Source code in vllm/model_executor/model_loader/weight_utils.py
def runai_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    is_distributed: bool = False,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    with SafetensorsStreamer() as streamer:
        is_cuda_alike = current_platform.is_cuda_alike()
        device = (
            f"cuda:{current_platform.current_device()}"
            if is_distributed and is_cuda_alike
            else "cpu"
        )

        streamer.stream_files(
            hf_weights_files,
            device=device,
            is_distributed=is_distributed,
        )
        total_tensors = sum(
            len(tensors_meta)
            for tensors_meta in streamer.files_to_tensors_metadata.values()
        )

        tensor_iter = tqdm(
            streamer.get_tensors(),
            total=total_tensors,
            desc="Loading safetensors using Runai Model Streamer",
            bar_format=_BAR_FORMAT,
            disable=not enable_tqdm(use_tqdm_on_load),
            mininterval=2,
        )

        yield from tensor_iter

safetensors_weights_iterator

safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    safetensors_load_strategy: str = "lazy",
) -> Generator[tuple[str, Tensor], None, None]

Iterate over the weights in the model safetensor files.

Source code in vllm/model_executor/model_loader/weight_utils.py
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    safetensors_load_strategy: str = "lazy",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    loading_desc = "Loading safetensors checkpoint shards"
    if safetensors_load_strategy == "eager":
        loading_desc += " (eager)"

    leftover_state_dict: dict[str, torch.Tensor] = {}
    for st_file in tqdm(
        sorted(hf_weights_files, key=_natural_sort_key),
        desc=loading_desc,
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    ):
        if safetensors_load_strategy == "eager":
            with open(st_file, "rb") as f:
                state_dict = load(f.read())
            yield from state_dict.items()
        elif safetensors_load_strategy == "torchao":
            # we can't load flattened torchao tensor subclasses directly into the model
            # instead we reconstruct the subclasses here before returning
            if not torchao_version_at_least("0.15.0"):
                raise ValueError(
                    "Please use torchao version >= 0.15.0 "
                    "to load torchao safetensors checkpoint"
                )
            from torchao.prototype.safetensors.safetensors_support import (
                unflatten_tensor_state_dict,
            )

            with safe_open(st_file, framework="pt") as f:
                state_dict = {}
                for name in f.keys():  # noqa: SIM118
                    state_dict[name] = f.get_tensor(name)

                # update with leftover tensor data from previous iteration, if any
                state_dict.update(leftover_state_dict)
                metadata = f.metadata()
                # due to sharded checkpoints, we are not guaranteed that we have all
                # tensor subclass data on one file
                # state_dict has the leftover data from this step and we wait for
                # missing information to be provided in a future iteration
                unflattened_state_dict, leftover_state_dict = (
                    unflatten_tensor_state_dict(state_dict, metadata)
                )
            yield from unflattened_state_dict.items()
        else:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    param = f.get_tensor(name)
                    yield name, param
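
A consumption sketch (the shard directory is hypothetical):

import glob

from vllm.model_executor.model_loader.weight_utils import safetensors_weights_iterator

shard_files = glob.glob("/path/to/model/*.safetensors")
for name, tensor in safetensors_weights_iterator(
    shard_files, use_tqdm_on_load=True, safetensors_load_strategy="lazy"
):
    print(name, tensor.dtype, tuple(tensor.shape))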

sharded_weight_loader

sharded_weight_loader(shard_axis: int) -> LoaderFunction

Create a weight loader that shards the weights along the given axis

Source code in vllm/model_executor/model_loader/weight_utils.py
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Create a weight loader that shards the weights along the given axis"""

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()

        shard_size = param.data.shape[shard_axis]
        start_idx = tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)

        return default_weight_loader(param, loaded_weight)

    return loader
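
A sketch of how the factory is typically used (it assumes tensor-parallel state is already initialized; the shapes are illustrative, with TP size 2):

import torch

from vllm.model_executor.model_loader.weight_utils import sharded_weight_loader

# Shard the checkpoint weight along dim 0: each TP rank copies only its slice.
loader = sharded_weight_loader(shard_axis=0)

full_weight = torch.randn(8, 4)                 # checkpoint tensor seen by all ranks
param = torch.nn.Parameter(torch.empty(4, 4))   # this rank's shard
loader(param, full_weight)                      # narrows rows [rank * 4, rank * 4 + 4)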