@default_pooling_type(seq_pooling_type="CLS", tok_pooling_type="ALL")
class ColBERTModel(BertEmbeddingModel):
    """ColBERT late interaction model for retrieval/reranking.

    This model extends BertEmbeddingModel with a ColBERT-style linear
    projection layer for per-token embeddings. It supports only:

    - "token_embed" task: Per-token embeddings for late interaction

    ColBERT is fundamentally a per-token embedding model - the linear
    projection is trained for per-token representations, not for CLS
    pooling. Use a dedicated dense embedding model if you need single-
    vector representations.

    The ColBERT scoring (MaxSim) is computed externally, either client-side
    or via the late interaction scoring path in ServingScores.

    Attributes:
        colbert_linear: Linear projection from hidden_size to colbert_dim
        supports_late_interaction: Flag indicating this model uses late
            interaction scoring
    """

    # Mark this model as supporting late interaction scoring
    supports_late_interaction: ClassVar[Literal[True]] = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Read the HF config before super().__init__ so that
        # _build_pooler (invoked from the parent constructor) can see
        # hidden_size / head_dtype / colbert_dim.
        config = vllm_config.model_config.hf_config
        self.hidden_size = config.hidden_size
        self.head_dtype = vllm_config.model_config.head_dtype

        # ColBERT dimension - check various config field names used by
        # different ColBERT implementations. If not found in config, it
        # will be inferred from loaded weights in load_weights().
        self.colbert_dim: int | None = (
            getattr(config, "colbert_dim", None)
            or getattr(config, "dim", None)
            or getattr(config, "projection_dim", None)
        )

        # Initialize parent (this will call _build_pooler)
        super().__init__(vllm_config=vllm_config, prefix=prefix)

    def _build_model(self, vllm_config: VllmConfig, prefix: str = "") -> BertModel:
        """Build the underlying BERT trunk."""
        return BertModel(vllm_config=vllm_config, prefix=prefix)

    def _build_colbert_linear(self) -> nn.Linear:
        """Build the ColBERT linear projection layer.

        Raises:
            ValueError: If ``colbert_dim`` has not been resolved yet
                (neither from config nor from checkpoint weights).
        """
        if self.colbert_dim is None:
            raise ValueError("colbert_dim must be set before building the linear layer")
        # Original ColBERT uses bias=False.
        return nn.Linear(
            self.hidden_size,
            self.colbert_dim,
            bias=False,
            dtype=self.head_dtype,
        )

    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
        """Build the token-embedding pooler with the ColBERT projector.

        The projector maps hidden_size -> colbert_dim. If colbert_dim is
        not known from the config, the projector is left as None here and
        created lazily in load_weights() once the checkpoint reveals the
        output dimension.
        """
        if self.colbert_dim is not None:
            self.colbert_linear = self._build_colbert_linear()
        else:
            # Placeholder - will be created when weights are loaded
            self.colbert_linear = None

        # ColBERT only supports token_embed - it's fundamentally a per-token
        # embedding model.
        return pooler_for_token_embed(
            pooler_config,
            projector=self.colbert_linear,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights.

        Routes the ColBERT projection weight into the pooler's projector
        (creating it first when ``colbert_dim`` must be inferred from the
        checkpoint) and everything else into the BERT trunk.

        Args:
            weights: Iterable of (checkpoint name, tensor) pairs.

        Returns:
            The set of fully-qualified parameter names that were loaded.
        """

        def _strip(name: str) -> str:
            # Drop common checkpoint prefixes ("model.", "bert.").
            for p in ("model.", "bert."):
                if name.startswith(p):
                    name = name[len(p) :]
            return name

        model_side: list[tuple[str, torch.Tensor]] = []
        colbert_side: list[tuple[str, torch.Tensor]] = []
        for name, weight in weights:
            stripped = _strip(name)
            # Handle different checkpoint naming conventions for the
            # ColBERT linear layer ("linear.*" or "colbert_linear.*"),
            # normalizing both to "colbert_linear.*". Strip whichever
            # prefix is present rather than using str.replace, which
            # would corrupt names already starting with "colbert_linear."
            # into "colbert_colbert_linear.*".
            if stripped.startswith(("linear.", "colbert_linear.")):
                suffix = stripped.removeprefix("colbert_linear.").removeprefix(
                    "linear."
                )
                colbert_side.append((f"colbert_linear.{suffix}", weight))
            else:
                model_side.append((stripped, weight))

        # Load base BERT weights using BertModel.load_weights which handles
        # QKV fusion.
        loaded: set[str] = set()
        loaded_model = self.model.load_weights(model_side)
        loaded.update("model." + n for n in loaded_model)

        # Load the ColBERT projection weight (bias=False, so only the
        # weight tensor is expected).
        for name, weight in colbert_side:
            if name != "colbert_linear.weight":
                continue
            if self.colbert_dim is None:
                # Weight shape is [colbert_dim, hidden_size]; infer the
                # output dimension and create the linear layer now.
                self.colbert_dim = weight.shape[0]
                self.colbert_linear = self._build_colbert_linear()
                # Move to the same device as the model's existing parameters.
                device = next(self.model.parameters()).device
                self.colbert_linear.to(device)
                # Update the pooler's projector to use the new linear layer.
                self.pooler.head.projector = self.colbert_linear
            # Copy the checkpoint tensor directly into the pooler's projector.
            projector = self.pooler.head.projector
            weight = weight.to(projector.weight.device)
            projector.weight.data.copy_(weight)
            loaded.add("pooler.head.projector.weight")
            break
        return loaded