
vllm.distributed.device_communicators.base_device_communicator

DeviceCommunicatorBase

Base class for device-specific communicators. It can use the cpu_group to initialize the communicator. If the device has PyTorch integration (PyTorch can recognize its communication backend), the device_group will also be given.
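
A minimal usage sketch (not part of the vLLM source; it assumes vLLM is installed and uses an arbitrary local address/port with a single-process gloo group) showing how the communicator is constructed from a CPU process group. Later sketches on this page reuse the comm object created here.

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase,
)

# Single-process "cluster": one rank on CPU with the gloo backend.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",  # arbitrary local address/port
    rank=0,
    world_size=1,
)

# The default group stands in for the CPU group; a real deployment would also
# pass the device-specific group (e.g. an NCCL group) as `device_group`.
comm = DeviceCommunicatorBase(cpu_group=dist.group.WORLD)

x = torch.ones(4)
print(comm.all_reduce(x))         # unchanged with a single rank
print(comm.all_gather(x, dim=0))  # shape (world_size * 4,) == (4,)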

Source code in vllm/distributed/device_communicators/base_device_communicator.py
class DeviceCommunicatorBase:
    """
    Base class for device-specific communicators.
    It can use the `cpu_group` to initialize the communicator.
    If the device has PyTorch integration (PyTorch can recognize its
    communication backend), the `device_group` will also be given.
    """

    def __init__(
        self,
        cpu_group: ProcessGroup,
        device: torch.device | None = None,
        device_group: ProcessGroup | None = None,
        unique_name: str = "",
    ):
        self.device = device or torch.device("cpu")
        self.cpu_group = cpu_group
        self.device_group = device_group
        self.unique_name = unique_name
        self.rank = dist.get_rank(cpu_group)
        self.world_size = dist.get_world_size(cpu_group)
        self.ranks = dist.get_process_group_ranks(cpu_group)
        self.global_rank = dist.get_rank()
        self.global_world_size = dist.get_world_size()
        self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)

        use_ep = False
        all2all_backend = None
        from vllm.config import get_current_vllm_config_or_none

        config = get_current_vllm_config_or_none()
        if config is not None:
            # as long as we use data parallel (coupled data parallel
            # where all data parallel ranks execute forward together),
            # we initialize the all2all manager used in expert parallel.
            use_ep = config.parallel_config.data_parallel_size > 1
            all2all_backend = config.parallel_config.all2all_backend

        self.is_ep_communicator = "ep" in unique_name
        self.use_all2all = self.is_ep_communicator and use_ep
        self.all2all_backend = all2all_backend
        self.all2all_manager: All2AllManagerBase | None = None

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(input_, group=self.device_group)
        return input_

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile. See https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * self.world_size,) + input_size[1:]
        # Allocate output tensor.
        output_tensor = torch.empty(
            output_size, dtype=input_.dtype, device=input_.device
        )
        # All-gather.
        dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
        # Reshape
        output_tensor = output_tensor.reshape((self.world_size,) + input_size)
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(
            input_size[:dim]
            + (self.world_size * input_size[dim],)
            + input_size[dim + 1 :]
        )
        return output_tensor

    def all_gatherv(
        self,
        input_: torch.Tensor | list[torch.Tensor],
        dim: int = 0,
        sizes: list[int] | None = None,
    ) -> torch.Tensor | list[torch.Tensor]:
        raise NotImplementedError

    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )

        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Note: This will produce an incorrect answer if we don't make
        # the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
        input_tensor = input_.movedim(0, dim).contiguous()

        assert input_tensor.shape[0] % world_size == 0
        chunk_size = input_tensor.shape[0] // world_size
        output_shape = (chunk_size,) + input_tensor.shape[1:]

        output_tensor = torch.empty(
            output_shape, dtype=input_tensor.dtype, device=input_tensor.device
        )

        # Perform reduce-scatter operation
        torch.distributed.reduce_scatter_tensor(
            output_tensor, input_tensor, group=self.device_group
        )

        # Reshape before returning
        return output_tensor.movedim(0, dim).contiguous()

    def reduce_scatterv(
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
    ) -> torch.Tensor:
        raise NotImplementedError

    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> torch.Tensor | None:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()

        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(
            input_, gather_list, dst=self.ranks[dst], group=self.device_group
        )
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
        """Sends a tensor to the destination rank in a blocking way"""
        """NOTE: `dst` is the local rank of the destination rank."""
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
    ) -> torch.Tensor:
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        pass

    def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
        """
        Prepare the communication buffer for the model.
        """
        if not self.is_ep_communicator:
            return

        moe_modules = [
            module
            for module in model.modules()
            # TODO(bnell): Should use isinstance but can't.  Maybe search for
            # presence of quant_method.maybe_init_modular_kernel?
            if (
                module.__class__.__name__ == "FusedMoE"
                or module.__class__.__name__ == "SharedFusedMoE"
            )
        ]
        for module in moe_modules:
            module.maybe_init_modular_kernel()

    def dispatch_router_logits(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        Dispatch the hidden states and router logits to the appropriate device.
        This is a no-op in the base class.
        """
        if extra_tensors is not None:
            return hidden_states, router_logits, extra_tensors
        return hidden_states, router_logits

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
        """
        Dispatch the hidden states and topk weights/ids to the appropriate device.
        This is a no-op in the base class.
        """
        if extra_tensors is not None:
            return hidden_states, topk_weights, topk_ids, extra_tensors
        return hidden_states, topk_weights, topk_ids

    def combine(
        self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """
        Combine the hidden states from the appropriate device.
        This is a no-op in the base class.
        """
        return hidden_states
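
The concat-style all-gather in all_gather above relies on a reshape/movedim round trip to emulate gathering along an arbitrary dim. The equivalence can be checked in a single process without any process group; this sketch (not vLLM code) stands in for the gathered tensor with torch.cat:

import torch

world_size = 4
dim = 1                                               # target gather dimension
per_rank = [torch.randn(2, 3, 5) for _ in range(world_size)]

# What all_gather_into_tensor produces: per-rank tensors concatenated along dim 0.
gathered = torch.cat(per_rank, dim=0)                 # (world_size * 2, 3, 5)

# The reshape/movedim steps from DeviceCommunicatorBase.all_gather().
input_size = per_rank[0].size()
out = gathered.reshape((world_size,) + input_size)    # (world_size, 2, 3, 5)
out = out.movedim(0, dim)                             # (2, world_size, 3, 5)
out = out.reshape(
    input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)                                                     # (2, world_size * 3, 5)

# Identical to concatenating the per-rank tensors directly along `dim`.
assert torch.equal(out, torch.cat(per_rank, dim=dim))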

combine

combine(
    hidden_states: Tensor,
    is_sequence_parallel: bool = False,
) -> Tensor

Combine the hidden states from the appropriate device. This is a no-op in the base class.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def combine(
    self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
) -> torch.Tensor:
    """
    Combine the hidden states from the appropriate device.
    This is a no-op in the base class.
    """
    return hidden_states

dispatch

dispatch(
    hidden_states: Tensor,
    topk_weights: Tensor,
    topk_ids: Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[Tensor] | None = None,
) -> (
    tuple[Tensor, Tensor, Tensor]
    | tuple[Tensor, Tensor, Tensor, list[Tensor]]
)

Dispatch the hidden states and topk weights/ids to the appropriate device. This is a no-op in the base class.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def dispatch(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Dispatch the hidden states and topk weights/ids to the appropriate device.
    This is a no-op in the base class.
    """
    if extra_tensors is not None:
        return hidden_states, topk_weights, topk_ids, extra_tensors
    return hidden_states, topk_weights, topk_ids
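
A caller-side sketch of the return arity (reusing the comm object from the construction sketch near the top of this page): the base-class dispatch is a pass-through, and the returned tuple gains a fourth element only when extra_tensors is supplied. dispatch_router_logits follows the same pattern with router logits in place of the top-k tensors.

import torch

hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 2)
topk_ids = torch.randint(0, 4, (8, 2))

# Three-tuple when no extra tensors are passed.
hidden_states, topk_weights, topk_ids = comm.dispatch(
    hidden_states, topk_weights, topk_ids
)

# Four-tuple when extra tensors ride along with the dispatch.
extras_in = [torch.zeros(8, 16)]
hidden_states, topk_weights, topk_ids, extras_out = comm.dispatch(
    hidden_states, topk_weights, topk_ids, extra_tensors=extras_in
)
assert extras_out is extras_in  # the base class returns the inputs unchanged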

dispatch_router_logits

dispatch_router_logits(
    hidden_states: Tensor,
    router_logits: Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[Tensor] | None = None,
) -> (
    tuple[Tensor, Tensor]
    | tuple[Tensor, Tensor, list[Tensor]]
)

Dispatch the hidden states and router logits to the appropriate device. This is a no-op in the base class.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def dispatch_router_logits(
    self,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    is_sequence_parallel: bool = False,
    extra_tensors: list[torch.Tensor] | None = None,
) -> (
    tuple[torch.Tensor, torch.Tensor]
    | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
    """
    Dispatch the hidden states and router logits to the appropriate device.
    This is a no-op in the base class.
    """
    if extra_tensors is not None:
        return hidden_states, router_logits, extra_tensors
    return hidden_states, router_logits

gather

gather(
    input_: Tensor, dst: int = 0, dim: int = -1
) -> Tensor | None

NOTE: We assume that the input tensor is on the same device across all the ranks. NOTE: dst is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def gather(
    self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None:
    """
    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.
    """
    world_size = self.world_size
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    )
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()

    # Allocate output tensor.
    if self.rank_in_group == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    # Gather.
    torch.distributed.gather(
        input_, gather_list, dst=self.ranks[dst], group=self.device_group
    )
    if self.rank_in_group == dst:
        output_tensor = torch.cat(gather_list, dim=dim)
    else:
        output_tensor = None
    return output_tensor
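
Continuing the single-rank sketch from the top of this page: only the local rank passed as dst receives the concatenated result; every other rank gets None. With a single rank the destination simply gets its own tensor back.

import torch

x = torch.arange(6.0).reshape(2, 3)

# Local rank 0 is the destination, so it receives the gathered tensor;
# with world_size == 1 that is just `x` itself.
out = comm.gather(x, dst=0, dim=-1)
assert torch.equal(out, x)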

prepare_communication_buffer_for_model

prepare_communication_buffer_for_model(
    model: Module,
) -> None

Prepare the communication buffer for the model.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None:
    """
    Prepare the communication buffer for the model.
    """
    if not self.is_ep_communicator:
        return

    moe_modules = [
        module
        for module in model.modules()
        # TODO(bnell): Should use isinstance but can't.  Maybe search for
        # presence of quant_method.maybe_init_modular_kernel?
        if (
            module.__class__.__name__ == "FusedMoE"
            or module.__class__.__name__ == "SharedFusedMoE"
        )
    ]
    for module in moe_modules:
        module.maybe_init_modular_kernel()
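
The scan above matches modules by class name instead of isinstance (see the TODO in the source). A self-contained illustration of that duck-typed pattern; the FusedMoE class below is a stand-in defined only for this sketch, not vLLM's layer:

import torch

class FusedMoE(torch.nn.Module):
    """Stand-in module exposing the hook the scan expects."""

    def maybe_init_modular_kernel(self) -> None:
        print("modular kernel initialized for", type(self).__name__)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

model = torch.nn.Sequential(torch.nn.Linear(8, 8), FusedMoE())

# Same class-name filter used by prepare_communication_buffer_for_model().
moe_modules = [
    module
    for module in model.modules()
    if module.__class__.__name__ in ("FusedMoE", "SharedFusedMoE")
]
for module in moe_modules:
    module.maybe_init_modular_kernel()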

recv

recv(
    size: Size, dtype: dtype, src: int | None = None
) -> Tensor

Receives a tensor from the source rank. NOTE: src is the local rank of the source rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def recv(
    self, size: torch.Size, dtype: torch.dtype, src: int | None = None
) -> torch.Tensor:
    """Receives a tensor from the source rank."""
    """NOTE: `src` is the local rank of the source rank."""
    if src is None:
        src = (self.rank_in_group - 1) % self.world_size

    tensor = torch.empty(size, dtype=dtype, device=self.device)
    torch.distributed.recv(tensor, self.ranks[src], self.device_group)
    return tensor

send

send(tensor: Tensor, dst: int | None = None) -> None

Sends a tensor to the destination rank in a blocking way. NOTE: dst is the local rank of the destination rank.

Source code in vllm/distributed/device_communicators/base_device_communicator.py
def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
    """Sends a tensor to the destination rank in a blocking way"""
    """NOTE: `dst` is the local rank of the destination rank."""
    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    torch.distributed.send(tensor, self.ranks[dst], self.device_group)
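
When dst/src are omitted, send() and recv() default to the neighbouring ranks (rank + 1) % world_size and (rank - 1) % world_size, with the local ranks mapped through self.ranks. A plain torch.distributed sketch of that ring pattern with two CPU processes (not vLLM code; the port is an arbitrary choice, in a two-rank world the local-to-global mapping is the identity, and even ranks send first to avoid a blocking deadlock):

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def ring(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"  # arbitrary free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    payload = torch.full((4,), float(rank))
    received = torch.empty(4)
    send_to = (rank + 1) % world_size    # default `dst` in send()
    recv_from = (rank - 1) % world_size  # default `src` in recv()

    # Even ranks send first, odd ranks receive first, so the blocking
    # point-to-point calls cannot deadlock.
    if rank % 2 == 0:
        dist.send(payload, dst=send_to)
        dist.recv(received, src=recv_from)
    else:
        dist.recv(received, src=recv_from)
        dist.send(payload, dst=send_to)

    assert torch.equal(received, torch.full((4,), float(recv_from)))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(ring, args=(2,), nprocs=2)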