class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Abstract interface for pre/post-processing of engine I/O.

    Subclasses must implement :meth:`pre_process` and :meth:`post_process`
    and should override :meth:`parse_data`. The ``merge_*_params`` hooks and
    the ``*_async`` variants have working defaults that may be overridden.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__()
        self.vllm_config = vllm_config

    def parse_data(self, data: object) -> IOProcessorInput:
        """Convert raw request data into this processor's input type.

        The default implementation only supports plugins that still define
        the legacy ``parse_request`` method: it forwards to that method and
        emits a ``DeprecationWarning``.

        Raises:
            NotImplementedError: if the subclass overrides neither
                ``parse_data`` nor the legacy ``parse_request``.
        """
        if callable(parse_request := getattr(self, "parse_request", None)):
            warnings.warn(
                "`parse_request` has been renamed to `parse_data`. "
                "Please update your IO Processor Plugin to use the new name. "
                "The old name will be removed in v0.19.",
                DeprecationWarning,
                stacklevel=2,
            )
            return parse_request(data)  # type: ignore
        raise NotImplementedError(
            f"{type(self).__name__} must implement `parse_data`."
        )

    def _legacy_validate_or_generate_params(self):
        """Return the legacy ``validate_or_generate_params`` hook, or None.

        If the subclass still defines the legacy hook, a ``DeprecationWarning``
        is emitted before the hook is returned. ``stacklevel=3`` skips this
        helper and the public ``merge_*_params`` method, so the warning points
        at the code that called ``merge_*_params``.
        """
        legacy = getattr(self, "validate_or_generate_params", None)
        if not callable(legacy):
            return None
        warnings.warn(
            "`validate_or_generate_params` has been split into "
            "`merge_sampling_params` and `merge_pooling_params`. "
            "Please update your IO Processor Plugin to use the new methods. "
            "The old name will be removed in v0.19.",
            DeprecationWarning,
            stacklevel=3,
        )
        return legacy

    def merge_sampling_params(
        self,
        params: SamplingParams | None = None,
    ) -> SamplingParams:
        """Return the sampling params to use for a request.

        Args:
            params: Params supplied by the caller, if any.

        Returns:
            The legacy ``validate_or_generate_params`` result when the subclass
            still defines it; otherwise ``params``, or a default-constructed
            ``SamplingParams`` when ``params`` is None.
        """
        if (legacy := self._legacy_validate_or_generate_params()) is not None:
            return legacy(params)  # type: ignore
        return params if params is not None else SamplingParams()

    def merge_pooling_params(
        self,
        params: PoolingParams | None = None,
    ) -> PoolingParams:
        """Return the pooling params to use for a request.

        Args:
            params: Params supplied by the caller, if any.

        Returns:
            The legacy ``validate_or_generate_params`` result when the subclass
            still defines it; otherwise ``params``, or
            ``PoolingParams(task="plugin")`` when ``params`` is None.
        """
        if (legacy := self._legacy_validate_or_generate_params()) is not None:
            return legacy(params)  # type: ignore
        return params if params is not None else PoolingParams(task="plugin")

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Turn a parsed plugin input into one or more engine prompts."""
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        """Async variant of :meth:`pre_process`.

        The default implementation delegates to the synchronous method.
        """
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Turn the engine's pooling outputs into the plugin's output type."""
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: str | None = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Async variant of :meth:`post_process`.

        Drains ``model_output``, which yields ``(id, output)`` pairs. It then
        orders the outputs by that integer id and passes them to
        :meth:`post_process`.
        """
        # Outputs are not guaranteed to come back in the order they were fed
        # to vLLM, so restore that order by the paired id before
        # post-processing.
        indexed_outputs = sorted(
            [(idx, item) async for idx, item in model_output],
            key=lambda pair: pair[0],
        )
        collected_output = [item for _, item in indexed_outputs]
        return self.post_process(collected_output, request_id=request_id, **kwargs)