vllm.distributed.weight_transfer.packed_tensor

Packed tensor utilities for efficient weight transfer. Instead of broadcasting each parameter individually, many small tensors are concatenated into large contiguous uint8 buffers that are broadcast in chunks, with multiple CUDA streams providing double/triple buffering so that packing and communication can overlap.

packed_broadcast_consumer

packed_broadcast_consumer(
    iterator: Iterator[tuple[str, tuple[list[int], dtype]]],
    group: Any,
    src: int,
    post_unpack_func: Callable[
        [list[tuple[str, Tensor]]], None
    ],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None

Consume packed tensors and unpack them into a list of tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `iterator` | `Iterator[tuple[str, tuple[list[int], dtype]]]` | Iterator of parameter metadata. Yields `(name, (shape, dtype))` tuples. | *required* |
| `group` | `Any` | Process group (`PyNcclCommunicator`). | *required* |
| `src` | `int` | Source rank (0 in the current implementation). | *required* |
| `post_unpack_func` | `Callable[[list[tuple[str, Tensor]]], None]` | Function applied to each unpacked list of `(name, tensor)` pairs. | *required* |
| `buffer_size_bytes` | `int` | Size in bytes of each packed tensor buffer. Both producer and consumer must use the same value. | `DEFAULT_PACKED_BUFFER_SIZE_BYTES` |
| `num_buffers` | `int` | Number of buffers for double/triple buffering. Both producer and consumer must use the same value. | `DEFAULT_PACKED_NUM_BUFFERS` |
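For orientation, here is a minimal consumer-side sketch. The `receive_weights` wrapper, its arguments, and the weight-loading callback are illustrative assumptions, not part of this module; it presumes an initialized `PyNcclCommunicator` and that the producer sends parameters in `model.named_parameters()` order.

```python
import torch

from vllm.distributed.weight_transfer.packed_tensor import packed_broadcast_consumer


def receive_weights(model: torch.nn.Module, communicator) -> None:
    """Hypothetical wrapper: receive weights broadcast from rank 0 by a
    matching packed_broadcast_producer call."""
    # Metadata must describe parameters in the same order the producer sends.
    meta = ((name, (list(p.shape), p.dtype)) for name, p in model.named_parameters())

    def load(named_tensors: list[tuple[str, torch.Tensor]]) -> None:
        # Copy each received tensor into the live (GPU-resident) model.
        params = dict(model.named_parameters())
        for name, tensor in named_tensors:
            params[name].data.copy_(tensor)

    packed_broadcast_consumer(
        iterator=meta,
        group=communicator,  # assumed: an initialized PyNcclCommunicator
        src=0,
        post_unpack_func=load,
    )
```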
Source code in vllm/distributed/weight_transfer/packed_tensor.py
```python
def packed_broadcast_consumer(
    iterator: Iterator[tuple[str, tuple[list[int], torch.dtype]]],
    group: Any,
    src: int,
    post_unpack_func: Callable[[list[tuple[str, torch.Tensor]]], None],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Consume packed tensors and unpack them into a list of tensors.

    Args:
        iterator: Iterator of parameter metadata. Yields (name, (shape, dtype))
            tuples.
        group: Process group (PyNcclCommunicator).
        src: Source rank (0 in the current implementation).
        post_unpack_func: Function applied to each unpacked list of
            (name, tensor) pairs, e.g. to load the weights into the model.
        buffer_size_bytes: Size in bytes of each packed tensor buffer.
            Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
            Both producer and consumer must use the same value.

    """

    def unpack_tensor(
        packed_tensor: torch.Tensor,
        names: list[str],
        shapes: list[list[int]],
        dtypes: list[torch.dtype],
        tensor_sizes: list[int],
    ) -> list[tuple[str, torch.Tensor]]:
        """Unpack a single tensor into a list of tensors.

        Args:
            packed_tensor: The packed torch.uint8 tensor to unpack
            names: List of tensor names
            shapes: List of tensor shapes
            dtypes: List of tensor dtypes
            tensor_sizes: List of tensor sizes in bytes

        Returns:
            List of (name, tensor) tuples.
        """
        unpacked_tensors = packed_tensor.split(tensor_sizes)

        unpacked_list = [
            (name, tensor.contiguous().view(dtype).view(*shape))
            for name, shape, dtype, tensor in zip(
                names, shapes, dtypes, unpacked_tensors
            )
        ]

        return unpacked_list

    target_packed_tensor_size = buffer_size_bytes

    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0

    packing_tensor_meta_data: list[list[tuple[str, list[int], torch.dtype, int]]] = [
        [] for _ in range(num_buffers)
    ]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Synchronize the current stream
        streams[buffer_idx].synchronize()
        with torch.cuda.stream(streams[buffer_idx]):
            # Initialize the packing tensor meta data
            packing_tensor_meta_data[buffer_idx] = []
            packing_tensor_sizes[buffer_idx] = 0
            try:
                # Form a packed tensor
                while True:
                    name, (shape, dtype) = next(iterator)
                    tensor_size = math.prod(shape) * dtype.itemsize
                    packing_tensor_meta_data[buffer_idx].append(
                        (name, shape, dtype, tensor_size)
                    )
                    packing_tensor_sizes[buffer_idx] += tensor_size
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Create a packed tensor and broadcast it
                packed_tensors[buffer_idx] = torch.empty(
                    packing_tensor_sizes[buffer_idx], dtype=torch.uint8, device="cuda"
                )
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Load the packed tensor into the model
                names, shapes, dtypes, tensor_sizes = zip(
                    *packing_tensor_meta_data[buffer_idx]
                )
                post_unpack_func(
                    unpack_tensor(
                        packed_tensors[buffer_idx],
                        list(names),
                        list(shapes),
                        list(dtypes),
                        list(tensor_sizes),
                    )
                )
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do the last broadcast if there are remaining tensors
                if len(packing_tensor_meta_data[buffer_idx]) > 0:
                    # Create a packed tensor and broadcast it
                    packed_tensors[buffer_idx] = torch.empty(
                        packing_tensor_sizes[buffer_idx],
                        dtype=torch.uint8,
                        device="cuda",
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                    # Load the packed tensor into the model
                    names, shapes, dtypes, tensor_sizes = zip(
                        *packing_tensor_meta_data[buffer_idx]
                    )
                    post_unpack_func(
                        unpack_tensor(
                            packed_tensors[buffer_idx],
                            list(names),
                            list(shapes),
                            list(dtypes),
                            list(tensor_sizes),
                        )
                    )
                break
```
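The unpacking above relies only on torch view semantics: the flat uint8 buffer is split at byte offsets computed from the metadata, and each slice is reinterpreted in place without an extra copy. A minimal standalone illustration of the same round trip (CPU tensors, hypothetical names):

```python
import math

import torch

# Pack two tensors into one flat uint8 buffer, as the producer does.
a = torch.randn(4, 4, dtype=torch.float32)
b = torch.arange(8, dtype=torch.int64)
packed = torch.cat([t.contiguous().view(torch.uint8).view(-1) for t in (a, b)])

# Recover them by splitting on byte sizes and reinterpreting the views,
# mirroring unpack_tensor above.
sizes = [math.prod(t.shape) * t.dtype.itemsize for t in (a, b)]
chunk_a, chunk_b = packed.split(sizes)
assert torch.equal(chunk_a.contiguous().view(torch.float32).view(4, 4), a)
assert torch.equal(chunk_b.contiguous().view(torch.int64).view(8), b)
```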

packed_broadcast_producer

packed_broadcast_producer(
    iterator: Iterator[tuple[str, Tensor]],
    group: Any,
    src: int,
    post_iter_func: Callable[[tuple[str, Tensor]], Tensor],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None

Broadcast tensors in a packed manner from trainer to workers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `iterator` | `Iterator[tuple[str, Tensor]]` | Iterator of model parameters. Yields `(name, tensor)` tuples. | *required* |
| `group` | `Any` | Process group (`PyNcclCommunicator`). | *required* |
| `src` | `int` | Source rank (0 in the current implementation). | *required* |
| `post_iter_func` | `Callable[[tuple[str, Tensor]], Tensor]` | Function applied to each `(name, tensor)` pair before packing; must return a tensor. | *required* |
| `buffer_size_bytes` | `int` | Size in bytes of each packed tensor buffer. Both producer and consumer must use the same value. | `DEFAULT_PACKED_BUFFER_SIZE_BYTES` |
| `num_buffers` | `int` | Number of buffers for double/triple buffering. Both producer and consumer must use the same value. | `DEFAULT_PACKED_NUM_BUFFERS` |
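A matching producer-side sketch (again illustrative: `send_weights` and `communicator` are assumptions, and the model's parameters are assumed to be GPU-resident):

```python
import torch

from vllm.distributed.weight_transfer.packed_tensor import packed_broadcast_producer


def send_weights(model: torch.nn.Module, communicator) -> None:
    """Hypothetical wrapper: broadcast all weights from rank 0 to consumers
    running packed_broadcast_consumer with the same buffer settings."""
    packed_broadcast_producer(
        iterator=iter(model.named_parameters()),
        group=communicator,  # same communicator as on the consumer side
        src=0,
        # Drop the name and hand the raw tensor to the packer.
        post_iter_func=lambda named: named[1].data,
    )
```

Both sides must pass identical `buffer_size_bytes` and `num_buffers`, since chunk boundaries are derived independently on each side (see the sketch at the end of this page).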
Source code in vllm/distributed/weight_transfer/packed_tensor.py
```python
def packed_broadcast_producer(
    iterator: Iterator[tuple[str, torch.Tensor]],
    group: Any,
    src: int,
    post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor],
    buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
    num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
) -> None:
    """Broadcast tensors in a packed manner from trainer to workers.

    Args:
        iterator: Iterator of model parameters. Yields (name, tensor) tuples.
        group: Process group (PyNcclCommunicator).
        src: Source rank (0 in the current implementation).
        post_iter_func: Function applied to each (name, tensor) pair before
            packing; must return a tensor.
        buffer_size_bytes: Size in bytes of each packed tensor buffer.
            Both producer and consumer must use the same value.
        num_buffers: Number of buffers for double/triple buffering.
            Both producer and consumer must use the same value.

    """
    target_packed_tensor_size = buffer_size_bytes

    streams = [torch.cuda.Stream() for _ in range(num_buffers)]
    buffer_idx = 0

    packing_tensor_list: list[list[torch.Tensor]] = [[] for _ in range(num_buffers)]
    packing_tensor_sizes: list[int] = [0 for _ in range(num_buffers)]
    packed_tensors: list[torch.Tensor] = [
        torch.empty(0, dtype=torch.uint8, device="cuda") for _ in range(num_buffers)
    ]

    while True:
        # Synchronize the current stream
        streams[buffer_idx].synchronize()
        # Start tasks for the new buffer in a new stream
        with torch.cuda.stream(streams[buffer_idx]):
            try:
                # Initialize the packing tensor list and sizes
                packing_tensor_list[buffer_idx] = []
                packing_tensor_sizes[buffer_idx] = 0
                # Pack the tensors
                while True:
                    # Apply post processing and convert to linearized uint8 tensor
                    tensor = (
                        post_iter_func(next(iterator))
                        .contiguous()
                        .view(torch.uint8)
                        .view(-1)
                    )
                    packing_tensor_list[buffer_idx].append(tensor)
                    packing_tensor_sizes[buffer_idx] += tensor.numel()
                    if packing_tensor_sizes[buffer_idx] > target_packed_tensor_size:
                        break
                # Pack the tensors and call broadcast collective
                packed_tensors[buffer_idx] = torch.cat(
                    packing_tensor_list[buffer_idx], dim=0
                )
                group.broadcast(packed_tensors[buffer_idx], src=src)
                # Move to the next buffer
                buffer_idx = (buffer_idx + 1) % num_buffers
            except StopIteration:
                # Do the last broadcast if there are remaining tensors
                if len(packing_tensor_list[buffer_idx]) > 0:
                    packed_tensors[buffer_idx] = torch.cat(
                        packing_tensor_list[buffer_idx], dim=0
                    )
                    group.broadcast(packed_tensors[buffer_idx], src=src)
                break
```
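Note that producer and consumer never exchange chunk boundaries: each side independently applies the same greedy rule, appending tensors until the accumulated byte size exceeds `buffer_size_bytes` and then flushing, so a chunk can overshoot the target by up to one tensor. A small sketch of that shared rule (the `chunk_boundaries` helper is hypothetical, not part of the module):

```python
import math

import torch


def chunk_boundaries(metadata, buffer_size_bytes):
    """Group (name, shape, dtype) entries into chunks exactly as both
    packed_broadcast_producer and packed_broadcast_consumer do."""
    chunks, current, size = [], [], 0
    for name, shape, dtype in metadata:
        current.append(name)
        size += math.prod(shape) * dtype.itemsize
        if size > buffer_size_bytes:  # overshoot triggers the flush
            chunks.append(current)
            current, size = [], 0
    if current:
        chunks.append(current)  # final partial chunk (StopIteration path)
    return chunks


meta = [
    ("a", [4, 4], torch.float32),  # 64 bytes
    ("b", [8], torch.float16),     # 16 bytes
    ("c", [2, 2], torch.float32),  # 16 bytes
]
print(chunk_boundaries(meta, buffer_size_bytes=64))  # [['a', 'b'], ['c']]
```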