class MinTokensLogitsProcessor(LogitsProcessor):
    """Mask stop tokens for requests that have not yet emitted ``min_tokens``.

    For every tracked request, the logits of all of its stop token ids are
    overwritten with ``-inf`` until the request's output length reaches its
    ``min_tokens`` value, at which point the request is dropped from tracking.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Set up empty tracking state and the constant ``-inf`` fill value.

        Args:
            vllm_config: Accepted as part of the constructor signature; it is
                not read by this constructor.
            device: Device on which the index tensors and fill value live.
            is_pin_memory: Whether host staging tensors are allocated pinned.
        """
        # Both attributes must exist before the first _device_tensor() call.
        self.device = device
        self.pin_memory = is_pin_memory

        # request index -> (min_toks, output_token_ids, stop_token_ids)
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # Pair of equal-length index tensors: (request rows, stop token ids).
        # Together they address every logit that has to be masked.
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

        # Scalar written into the masked positions by apply().
        self.neg_inf_tensor = torch.tensor(
            float("-inf"), dtype=torch.float32, device=self.device
        )

    def is_argmax_invariant(self) -> bool:
        """Report that this processor may alter the greedy (argmax) choice.

        Suppressing stop tokens can remove the highest-scoring token, so the
        argmax result is not guaranteed to be preserved.
        """
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """Build the tracking entry for a request, if it needs one.

        Args:
            params: Sampling parameters; ``min_tokens`` and
                ``all_stop_token_ids`` are read from it.
            _: Unused positional argument.
            output_tok_ids: Output tokens produced so far. The list object is
                stored as-is (not copied); ``update_state`` re-reads its
                length later, so it is presumably extended by the caller as
                generation proceeds.

        Returns:
            ``(min_tokens, output_tok_ids, stop_token_ids)`` while the request
            is still below its minimum length, otherwise ``None`` (no minimum
            configured, or the minimum is already satisfied).
        """
        required = params.min_tokens
        if required and len(output_tok_ids) < required:
            return required, output_tok_ids, params.all_stop_token_ids
        return None

    def update_state(self, batch_update: BatchUpdate | None):
        """Sync tracked requests with the batch and rebuild the mask indices.

        Args:
            batch_update: Batch change description forwarded to
                ``process_dict_updates``; may be ``None``.
        """
        changed = process_dict_updates(
            self.min_toks, batch_update, self.add_request
        )

        # Collect first, delete afterwards: the dict must not be mutated
        # while it is being iterated.
        satisfied = [
            req_idx
            for req_idx, (target, produced, _) in self.min_toks.items()
            if len(produced) >= target
        ]
        if satisfied:
            changed = True
            for req_idx in satisfied:
                del self.min_toks[req_idx]

        if not changed:
            return

        # One (row, token) pair per stop token of every still-tracked request.
        pairs = [
            (req_idx, tok)
            for req_idx, (_, _, stop_ids) in self.min_toks.items()
            for tok in stop_ids
        ]
        self.logits_slice = (
            self._device_tensor([row for row, _ in pairs], torch.int32),
            self._device_tensor([tok for _, tok in pairs], torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Create a tensor from ``data`` on the CPU and move it to ``self.device``.

        The host tensor is allocated with ``pin_memory=self.pin_memory`` and the
        transfer is requested with ``non_blocking=True``.
        """
        staged = torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        )
        return staged.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Write ``-inf`` into the stop-token logits of tracked requests.

        Args:
            logits: Logits tensor, modified in place.

        Returns:
            The same ``logits`` tensor object.
        """
        if not self.min_toks:
            # Nothing is being tracked, so there is nothing to mask.
            return logits

        # Suppress stop tokens for requests still below their minimum length.
        logits.index_put_(self.logits_slice, self.neg_inf_tensor)
        return logits