class MoeWNA16Method(FusedMoEMethodBase):
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
Args:
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
"""

    def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None:
super().__init__(moe)
self.quant_config = quant_config

    def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
layer.quant_config = self.quant_config
bit8_pack_factor = self.quant_config.bit8_pack_factor
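        # bit8_pack_factor = 8 // weight_bits (2 for 4-bit, 1 for 8-bit): the
        # number of quantized values stored per uint8 element in the packed
        # qweight/qzeros tensors below.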
group_size = self.quant_config.group_size
group_size_div_factor = 1
        # Make intermediate_size_per_partition and hidden_size divisible by
        # group_size: halve the group size until they are, and repeat the
        # loaded scales/zeros later to compensate.
while intermediate_size_per_partition % group_size or hidden_size % group_size:
group_size = group_size // 2
group_size_div_factor *= 2
assert group_size >= 32
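        # Hypothetical example: group_size=128 with
        # intermediate_size_per_partition=192 halves to group_size=64 and
        # group_size_div_factor=2; the weight loader later repeats the loaded
        # scales/zeros by that factor to match.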
layer.group_size = group_size
layer.group_size_div_factor = group_size_div_factor
strategy = FusedMoeWeightScaleSupported.GROUP.value
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
assert "weight_loader" in extra_weight_attrs
weight_loader = extra_weight_attrs["weight_loader"]
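        # Wrap the loader so that AWQ/GPTQ checkpoint tensors are converted to
        # the packed uint8 layout declared below before delegating to the
        # original FusedMoE weight loader.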
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
extra_weight_attrs["weight_loader"] = wrapped_weight_loader
# Fused gate_up_proj (column parallel)
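        # Shape (num_experts, 2 * intermediate_size_per_partition,
        # hidden_size // bit8_pack_factor): w1 (gate) and w3 (up) are stacked
        # along dim 1, with the input dim packed into uint8.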
w13_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // bit8_pack_factor,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_qweight", w13_qweight)
set_weight_attrs(w13_qweight, extra_weight_attrs)
# down_proj (row parallel)
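        # Shape (num_experts, hidden_size,
        # intermediate_size_per_partition // bit8_pack_factor): the down_proj
        # input dim is packed into uint8.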
w2_qweight = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition // bit8_pack_factor,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_qweight", w2_qweight)
set_weight_attrs(w2_qweight, extra_weight_attrs)
w13_scales = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition,
hidden_size // group_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w13_scales", w13_scales)
set_weight_attrs(w13_scales, extra_weight_attrs)
w2_scales = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size,
intermediate_size_per_partition // group_size,
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_scales", w2_scales)
set_weight_attrs(w2_scales, extra_weight_attrs)
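        # Optional zero points: same layout as the scales, but with dim 1
        # (the per-channel dim) additionally packed into uint8.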
if self.quant_config.has_zp:
w13_qzeros = torch.nn.Parameter(
torch.zeros(
num_experts,
2 * intermediate_size_per_partition // bit8_pack_factor,
hidden_size // group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w13_qzeros", w13_qzeros)
set_weight_attrs(w13_qzeros, extra_weight_attrs)
w2_qzeros = torch.nn.Parameter(
torch.zeros(
num_experts,
hidden_size // bit8_pack_factor,
intermediate_size_per_partition // group_size,
dtype=torch.uint8,
),
requires_grad=False,
)
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)
if self.quant_config.linear_quant_method == "gptq":
            # Some params are unused, but we need to initialize them in order
            # to load weights.
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
if not self.quant_config.has_zp:
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
for key in invalid_param_keys:
param = torch.nn.Parameter(
torch.empty((0,), dtype=torch.int32), requires_grad=False
)
layer.register_parameter(key, param)
set_weight_attrs(param, extra_weight_attrs)

    def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
assert weight_bits == 4 or weight_bits == 8
config_builder = (
int4_w4a16_moe_quant_config
if weight_bits == 4
else int8_w8a16_moe_quant_config
)
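        # block_shape's second entry is the per-group quantization size along
        # the input dim, i.e. the (possibly reduced) group size from
        # create_weights.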
return config_builder(
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
)

    def apply(
self,
layer: FusedMoE,
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)
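        # The packed uint8 weights are passed to fused_experts as-is; the
        # per-group scales/zero points come from self.moe_quant_config, built
        # in get_fused_moe_quant_config above.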
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)

    @staticmethod
def get_weight_loader(layer, weight_loader):
def convert_awq_tensor(tensor, tensor_type):
# convert awq qweight/qzeros to a standard format (assume int4)
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
# qzeros: (k // group_size, n // pack_factor_bit32) ->
# (n // pack_factor_bit8, k // group_size)
# pack_factor_bit32 = 32 // weight_bits
# pack_factor_bit8 = 8 // weight_bits
# 0. suppose origin shape (a, b), dtype int32
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
size0 = tensor.size(0)
tensor = tensor.view(torch.uint8)
# 2. unpack to uint4 (only when weight_bits == 4)
# shape (a, 4 * b) -> (a, 4 * b, 2)
shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
tensor = (tensor[:, :, None] >> shifter) & 0xF
# 3. change order, see
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
# shape -> (a, 4 * b * pack_factor_bit8)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
tensor = tensor.view(size0, -1)
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
tensor = tensor.T.contiguous()
# 5. repack (only when weight_bits == 4)
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
# qzeros shape -> (4 * b, a)
if tensor_type == "qweight":
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
elif tensor_type == "qzeros":
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
return tensor
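        # Hypothetical sizes for convert_awq_tensor (4-bit): a qweight of
        # shape (k=4096, n // 8=512) int32 becomes a uint8 tensor of shape
        # (n=4096, k // 2=2048); qzeros of shape (k // group_size, n // 8)
        # become (n // 2, k // group_size).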

        def convert_gptq_int4_qzeros(tensor):
            # Unpack int32 GPTQ zero points into uint8 nibbles, add 1 to
            # align with the AWQ convention (GPTQ stores zero_point - 1),
            # then repack two 4-bit values per uint8 in low/high order.
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor
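        # After the transpose at the call site, the converted GPTQ qzeros have
        # the same (n // 2, k // group_size) layout as the AWQ-converted
        # qzeros above.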

        def moe_wna16_weight_loader(
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False,
):
if "g_idx" in weight_name:
return False if return_success else None
if not layer.quant_config.has_zp and "qzeros" in weight_name:
return False if return_success else None
device = get_tp_group().device
tp_rank = get_tensor_model_parallel_rank()
loaded_weight = loaded_weight.to(device)
shard_size = layer.intermediate_size_per_partition
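            # Per-rank intermediate size; used below to split the loaded w13
            # zero points into their w1/w3 halves.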
# convert gptq and awq weight to a standard format
# awq_marlin uses the same weight format as awq
if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
assert layer.quant_config.weight_bits == 4
if "weight" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
elif "zeros" in weight_name:
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
else:
loaded_weight = loaded_weight.T
elif layer.quant_config.linear_quant_method == "gptq":
assert layer.quant_config.weight_bits in [4, 8]
if "weight" in weight_name:
loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
elif "zeros" in weight_name:
# add 1 to gptq qzeros to align with awq
loaded_weight = loaded_weight.view(torch.uint8)
if layer.quant_config.weight_bits == 4:
loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
else:
loaded_weight = loaded_weight.T + 1
else:
loaded_weight = loaded_weight.T
            # Repeat the qzeros/scales along dim 1 to fit the group size that
            # was reduced in create_weights.
            if layer.group_size_div_factor > 1 and (
                "qzeros" in weight_name or "scales" in weight_name
            ):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1
                )
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
tp_rank
]
if shard_id == "w1":
param.data[expert_id, : shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2 :] = tensor
return True if return_success else None
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1
)[:, tp_rank]
return True if return_success else None
else:
# Delegate to the original loader, passing return_success
return weight_loader(
param,
loaded_weight,
weight_name,
shard_id,
expert_id,
return_success=return_success,
)
return moe_wna16_weight_loader