Layerwise weight reloading utilities for vLLM.
This module provides functionality to reload model weights layer-by-layer, which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented.
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented.
3. Tied parameters will only reflect processing from one of the parent layers (for example, only processing from embed_tokens will have an effect).
4. This design assumes that the number of weights loaded from disk is the same as the number of weights created at model init time. This is not true for quant methods which (1) pad weights or (2) load qkv weights into the same parameter. Both of these cases are non-issues for today's quant methods, but future quantizations may cause reloading to fail.
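A minimal sketch of the intended reload flow, combining the functions documented below. The reload_model_weights wrapper and weights_iter are illustrative, and model.load_weights stands in for whatever loading entry point the caller uses:
| from vllm.model_executor.model_loader.reload.layerwise import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
    record_metadata_for_reloading,
)

def reload_model_weights(model, model_config, weights_iter):
    # One-time setup: record shapes/dtypes of all params/buffers as meta tensors
    record_metadata_for_reloading(model)

    # Per reload: move layers to meta and wrap loaders to defer processing
    initialize_layerwise_reload(model)

    # Stream new weights through the model's usual loading path; each layer is
    # materialized, loaded, processed, and copied back as it completes
    model.load_weights(weights_iter)

    # Unwrap loaders and process Attention/MLA layers last
    finalize_layerwise_reload(model, model_config)
|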
finalize_layerwise_reload
finalize_layerwise_reload(model: Module, model_config: ModelConfig)
Remove the outermost layer of weight loading wrappers.
This function should be applied after initialize_layerwise_reload to unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers.
Source code in vllm/model_executor/model_loader/reload/layerwise.py
| def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` to
    unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all
    other layers.
    """
    model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )
            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)

        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This can
        # happen if the created weight has extra padding elements which are not
        # loaded. Having too many of these delayed layers can lead to excess
        # memory usage, see Limitations (4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)

        info.reset()
|
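The LayerwiseInfo structure returned by get_layerwise_info is not shown on this page; the following is a plausible sketch, inferred only from how the listings on this page use it. The field names and methods appear in the listings, but the exact types and defaults are assumptions:
| from dataclasses import dataclass, field

import torch

@dataclass
class LayerwiseInfo:
    # Live kernel tensors saved by initialize_layerwise_reload so that
    # processed values can be copied back into the original storage
    kernel_tensors: dict[str, torch.Tensor] = field(default_factory=dict)
    # Meta-device copies of params/buffers recorded by
    # record_metadata_for_reloading, used to restore shapes/dtypes
    restore_metadata: dict[str, torch.Tensor] = field(default_factory=dict)
    # Elements loaded so far vs. total expected for the layer; None means the
    # layer has not been initialized for layerwise reloading
    load_numel: int | None = None
    load_numel_total: int | None = None

    def can_process(self) -> bool:
        # True once initialize_layerwise_reload has set up this layer
        return self.load_numel is not None

    def reset(self) -> None:
        self.kernel_tensors = {}
        self.load_numel = None
        self.load_numel_total = None
|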
initialize_layerwise_reload
initialize_layerwise_reload(model: Module)
Set up layerwise weight loading with deferred processing.
Must be called after record_metadata_for_reloading. This function:
1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded
When all weights for a layer are loaded, the wrapped loaders will (see the sketch after the source listing below):
1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage
Source code in vllm/model_executor/model_loader/reload/layerwise.py
| @torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to original tensor storage
    """
    # disable torchao reloading to avoid infinite recursion
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
        if info.can_process():
            continue

        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)

        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)

        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)

        # Wrap each parameter's weight loader
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
|
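make_online_process_loader itself is not listed on this page. Conceptually, the wrapper it returns behaves like the following sketch; the cached_weights field and the exact deferral mechanics are assumptions, while the element counting and the online_process_loader name come from the listings above:
| import torch

# get_layerwise_info and _layerwise_process are this module's own helpers
def make_online_process_loader(layer: torch.nn.Module, name: str):
    def online_process_loader(param, loaded_weight: torch.Tensor, *args, **kwargs):
        info = get_layerwise_info(layer)

        # Defer: cache the incoming shard instead of writing into the
        # meta-device parameter, and track how many elements have arrived
        # (cached_weights is an assumed field on LayerwiseInfo)
        info.cached_weights.append((name, loaded_weight, args, kwargs))
        info.load_numel += loaded_weight.numel()

        # Once the whole layer has arrived, materialize it on the target
        # device, replay the cached loads, run quantization processing, and
        # copy results back into the saved kernel tensors
        if info.load_numel >= info.load_numel_total:
            _layerwise_process(layer, info)

    return online_process_loader
|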
record_metadata_for_reloading
record_metadata_for_reloading(model: Module)
Record layer metadata needed for later reloading.
Stores parameter and buffer metadata as meta tensors for restoration. Must be called before initialize_layerwise_reload.
Source code in vllm/model_executor/model_loader/reload/layerwise.py
| def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    Stores parameter and buffer metadata as meta tensors for restoration.
    Must be called before `initialize_layerwise_reload`.
    """
    for layer in model.modules():
        info = get_layerwise_info(layer)
        info.restore_metadata = capture_layer_to_meta(layer)
|
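The recorded metadata is only shape/dtype information. A sketch of what capture_layer_to_meta plausibly does for a single layer; the body is an assumption, only the name comes from the listing above:
| import itertools

import torch

def capture_layer_to_meta(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
    # Meta tensors record shape/dtype/layout without allocating any data,
    # making them a cheap blueprint for re-materializing the layer later
    return {
        name: torch.empty_like(tensor, device="meta")
        for name, tensor in itertools.chain(
            layer.named_parameters(recurse=False),
            layer.named_buffers(recurse=False),
        )
    }
|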
support_quantized_model_reload_from_hp_weights
support_quantized_model_reload_from_hp_weights(
original_load_weights: FunctionType,
)
Decorator for AutoWeightsLoader.load_weights that supports reloading high-precision (bfloat16/float16/float32) weights into an already quantized model by restoring the weights to high precision and then quantizing them online.
Only applies to torchao-quantized models. Assumes that all model weights are loaded within a single weights iterator (batched updates are not supported).
Source code in vllm/model_executor/model_loader/reload/torchao_decorator.py
| def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Decorator for `AutoWeightsLoader.load_weights` that supports reloading
    high-precision (bfloat16/float16/float32) weights into an already
    quantized model by restoring the weights to high precision and then
    quantizing them online.

    Only applies to torchao-quantized models. Assumes that all model weights
    are loaded within a single weights iterator (batched updates are not
    supported).
    """

    @wraps(original_load_weights)
    def patched_model_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        model = self.module
        # Fall back to the original loader when reloading is not requested
        if not getattr(model, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)

        initialize_layerwise_reload(model)
        loaded_weights = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(model, model._model_config)
        return loaded_weights

    return patched_model_load_weights
|
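A sketch of how the decorator might be wired up. The AutoWeightsLoader import path, constructor call, and the _do_torchao_reload opt-in shown here are assumptions based on the behavior described above; model and hp_weights_iter are assumed to be in scope:
| from vllm.model_executor.model_loader.reload.torchao_decorator import (
    support_quantized_model_reload_from_hp_weights,
)
from vllm.model_executor.models.utils import AutoWeightsLoader  # path assumed

# Patch load_weights so it can accept high-precision weights for a
# torchao-quantized model
AutoWeightsLoader.load_weights = support_quantized_model_reload_from_hp_weights(
    AutoWeightsLoader.load_weights
)

# Opt the model in; without this flag the patched loader falls back to the
# original load_weights behavior
model._do_torchao_reload = True

# `model` is a torchao-quantized vLLM model; `hp_weights_iter` yields
# (name, tensor) pairs. All weights must arrive in a single iterator
# (no batched updates)
AutoWeightsLoader(model).load_weights(hp_weights_iter)
|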