Bases: Fp8LinearMethod
Linear method for Per-Token and Per-Channel FP8 Quantization. Only supports loading quantized BF16 model checkpoints with dynamic activation scaling. To load FP16 model checkpoints, the user must request that the FP16 weights be converted to BF16 during loading. The weight scaling factor is initialized after the model weights are loaded.
Limitations:

1. Only supports the `float8_e4m3fnuz` data type, due to a limitation of `torch._scaled_mm` (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041).
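For context, the sketch below shows how this quantization method would typically be selected. It is an illustration only: the `quantization="ptpc_fp8"` name and the example model are assumptions, and `dtype="bfloat16"` reflects the BF16 requirement described above.

```python
# Minimal usage sketch (assumptions: "ptpc_fp8" is the registered quantization
# name for PTPCFp8Config, and the model name below is only a placeholder).
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical example model
    quantization="ptpc_fp8",                   # select the PTPC FP8 linear method
    dtype="bfloat16",                          # FP16 checkpoints must be converted to BF16
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
```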
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `quant_config` | `PTPCFp8Config` | The quantization config. | *required* |
Source code in `vllm/model_executor/layers/quantization/ptpc_fp8.py`
```python
class PTPCFp8LinearMethod(Fp8LinearMethod):
    """Linear method for Per-Token and Per-Channel FP8 Quantization.
    Only supports loading quantized BF16 model checkpoints with dynamic
    activation scaling. To load FP16 model checkpoints, user must specify
    to convert the FP16 model weight loading into BF16.
    The weight scaling factor will be initialized after
    the model weights are loaded.

    Limitations:
    1. Only support float8_e4m3fnuz data type due to the limitation of
       torch._scaled_mm (https://github.com/ROCm/pytorch/blob/8c0504d7f3fb0ee4c278c096a5c3caedb01129fa/aten/src/ATen/native/cuda/Blas.cpp#L1041)

    Args:
        quant_config: The quantization config.
    """

    def __init__(self, quant_config: PTPCFp8Config):
        assert current_platform.is_rocm(), (
            "PTPCFp8LinearMethod is only supported on ROCm."
        )
        super().__init__(quant_config=quant_config)
        # Force weight quantization
        self.fp8_linear = init_fp8_linear_kernel(
            activation_quant_key=kFp8DynamicTokenSym,
            weight_quant_key=kFp8DynamicTokenSym,
            out_dtype=torch.get_default_dtype(),
            module_name=self.__class__.__name__,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        assert layer.weight.data.dtype not in (torch.float16, torch.float32), (
            "Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support "
            f"output dtype of bfloat16. {layer.weight.data.dtype} is specified."
        )
        if layer.weight.data.dtype == torch.bfloat16:
            # Quantize the weights.
            qweight, weight_scale = ops.scaled_fp8_quant(
                layer.weight, scale=None, use_per_token_if_dynamic=True
            )

            # Update the layer with the new values.
            layer.weight = Parameter(
                qweight.t(), requires_grad=False
            )  # Pretranspose the weight
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            assert layer.weight.data.dtype == current_platform.fp8_dtype()
            assert getattr(layer, "weight_scale", None) is not None
        layer.input_scale = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.fp8_linear.apply_weights(layer, x, bias)
```
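To illustrate what the dynamic per-token/per-channel scheme computes, here is a minimal, self-contained sketch of row-wise FP8 quantization in plain PyTorch. It is not the kernel that `ops.scaled_fp8_quant` actually dispatches to; the helper name `quantize_rowwise` and the `float8_e4m3fn` fallback dtype are assumptions made for the example.

```python
import torch


def quantize_rowwise(x: torch.Tensor, fp8_dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative row-wise (per-token / per-channel) FP8 quantization.

    Each row gets its own scale so that the row's maximum magnitude maps to the
    largest representable FP8 value. Returns (q, scale) with x ~= q.float() * scale.
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale per row: per token for activations, per output channel for weights.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = scale.clamp(min=1e-12)  # avoid division by zero for all-zero rows
    q = (x.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return q, scale


if __name__ == "__main__":
    # The method above uses float8_e4m3fnuz on ROCm (see the limitations section);
    # float8_e4m3fn is used here only so the sketch runs on non-ROCm builds too.
    dtype = torch.float8_e4m3fn
    w = torch.randn(4, 8, dtype=torch.bfloat16)
    qw, w_scale = quantize_rowwise(w, dtype)
    # Dequantize to check the approximation error.
    w_hat = qw.float() * w_scale
    print((w.float() - w_hat).abs().max())
```

In the method above, the same row-wise scheme is applied to the weight once at load time (per output channel, in `process_weights_after_loading`) and to the activation at run time (per token), which is why no static input scale is kept (`layer.input_scale = None`).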