vllm.model_executor.layers.fla.ops.layernorm_guard

LayerNormFn

Bases: Function

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
class LayerNormFn(torch.autograd.Function):
    @input_guard
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(x_shape_og)

forward staticmethod

forward(
    ctx,
    x,
    weight,
    bias,
    z=None,
    eps=1e-06,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
)

If z is not None, compute norm(x) * silu(z) when norm_before_gate is True; otherwise compute norm(x * silu(z)).

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
@input_guard
@staticmethod
def forward(
    ctx,
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.reshape(-1, x.shape[-1])
    if x.stride(-1) != 1:
        x = x.contiguous()
    if z is not None:
        assert z.shape == x_shape_og
        z = z.reshape(-1, z.shape[-1])
        if z.stride(-1) != 1:
            z = z.contiguous()
    weight = weight.contiguous()
    if bias is not None:
        bias = bias.contiguous()
    y, mean, rstd = layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        z=z,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )
    ctx.save_for_backward(x, weight, bias, mean, rstd, z)
    ctx.x_shape_og = x_shape_og
    ctx.eps = eps
    ctx.group_size = group_size
    ctx.norm_before_gate = norm_before_gate
    ctx.is_rms_norm = is_rms_norm
    return y.reshape(x_shape_og)
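
For orientation, the gating semantics described above can be written in plain PyTorch. This is a minimal reference sketch for the single-group case (group_size=None), not the fused Triton path that layer_norm_fwd dispatches to; gated_norm_reference is a hypothetical name used only for illustration.

import torch
import torch.nn.functional as F

def gated_norm_reference(x, weight, bias=None, z=None, eps=1e-6,
                         norm_before_gate=True, is_rms_norm=False):
    # Pure-PyTorch sketch of the behavior computed by the fused kernel,
    # restricted to a single normalization group for clarity.
    def _norm(t):
        if is_rms_norm:
            # RMSNorm: no mean subtraction, scale by the reciprocal RMS.
            rstd = torch.rsqrt(t.float().pow(2).mean(-1, keepdim=True) + eps)
            out = (t.float() * rstd).to(t.dtype) * weight
            return out + bias if bias is not None else out
        return F.layer_norm(t, t.shape[-1:], weight, bias, eps)

    if z is None:
        return _norm(x)
    if norm_before_gate:
        return _norm(x) * F.silu(z)   # norm(x) * silu(z)
    return _norm(x * F.silu(z))       # norm(x * silu(z))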

LayerNormGated

Bases: Module

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
class LayerNormGated(nn.Module):
    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = True,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )

__init__

__init__(
    hidden_size,
    eps: float = 1e-05,
    group_size: int | None = None,
    norm_before_gate: bool = True,
    device: device | None = None,
    dtype: dtype | None = None,
)

If group_size is not None, GroupNorm is applied with each group containing group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there is only one group).

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
def __init__(
    self,
    hidden_size,
    eps: float = 1e-5,
    group_size: int | None = None,
    norm_before_gate: bool = True,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
):
    """If group_size is not None, we do GroupNorm with each group having group_size elements.
    group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
    """

    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
    self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
    self.group_size = group_size
    self.norm_before_gate = norm_before_gate
    self.reset_parameters()

forward

forward(x, z=None)

If z is not None, compute norm(x) * silu(z) when norm_before_gate is True; otherwise compute norm(x * silu(z)).

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
def forward(self, x, z=None):
    """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
    return layernorm_fn(
        x,
        self.weight,
        self.bias,
        z=z,
        group_size=self.group_size,
        eps=self.eps,
        norm_before_gate=self.norm_before_gate,
    )
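
A brief usage sketch (hypothetical sizes; assumes a CUDA device, since the underlying fused kernel runs on the GPU):

import torch

hidden_size = 64  # hypothetical size for illustration
norm = LayerNormGated(
    hidden_size,
    eps=1e-5,
    group_size=None,          # None -> a single group over all hidden_size elements
    norm_before_gate=True,
    device=torch.device("cuda"),
    dtype=torch.float16,
)

x = torch.randn(2, 16, hidden_size, device="cuda", dtype=torch.float16)
z = torch.randn_like(x)       # gate tensor; must match x's shape

y_gated = norm(x, z)          # layer_norm(x) * silu(z)
y_plain = norm(x)             # no gate: plain LayerNorm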

RMSNormGated

Bases: Module

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
class RMSNormGated(nn.Module):
    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )

__init__

__init__(
    hidden_size,
    eps: float = 1e-05,
    group_size: int | None = None,
    norm_before_gate: bool = False,
    device: device | None = None,
    dtype: dtype | None = None,
)

If group_size is not None, GroupNorm is applied with each group containing group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there is only one group).

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
def __init__(
    self,
    hidden_size,
    eps: float = 1e-5,
    group_size: int | None = None,
    norm_before_gate: bool = False,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
):
    """If group_size is not None, we do GroupNorm with each group having group_size elements.
    group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
    """
    factory_kwargs = {"device": device, "dtype": dtype}
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
    self.register_parameter("bias", None)
    self.group_size = group_size
    self.norm_before_gate = norm_before_gate
    self.reset_parameters()

forward

forward(x, z=None)

If z is not None, compute norm(x) * silu(z) when norm_before_gate is True; otherwise compute norm(x * silu(z)).

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
def forward(self, x, z=None):
    """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
    return rmsnorm_fn(
        x,
        self.weight,
        self.bias,
        z=z,
        eps=self.eps,
        group_size=self.group_size,
        norm_before_gate=self.norm_before_gate,
    )
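
A corresponding sketch for the gated RMSNorm (hypothetical sizes; note the default norm_before_gate=False, which applies the gate before normalizing):

import torch

hidden_size = 128  # hypothetical size for illustration
norm = RMSNormGated(
    hidden_size,
    eps=1e-5,
    group_size=32,            # 128 / 32 = 4 groups, each normalized independently
    norm_before_gate=False,   # default: rms_norm(x * silu(z))
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)

x = torch.randn(4, hidden_size, device="cuda", dtype=torch.bfloat16)
z = torch.randn_like(x)

y = norm(x, z)                # rms_norm(x * silu(z)), per 32-element group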

_get_sm_count cached

_get_sm_count(device: device) -> int

Get and cache the SM count for a given device.

Source code in vllm/model_executor/layers/fla/ops/layernorm_guard.py
@lru_cache
def _get_sm_count(device: torch.device) -> int:
    """Get and cache the SM count for a given device."""
    props = torch.cuda.get_device_properties(device)
    return props.multi_processor_count
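
A small usage note: because of the lru_cache decorator, the device-properties query runs only once per distinct torch.device value (a sketch that assumes a CUDA-capable runtime):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    sm_count = _get_sm_count(device)    # first call queries the device properties
    _ = _get_sm_count(device)           # subsequent calls are served from the cache
    print(f"SMs on {device}: {sm_count}")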