Skip to content

vllm.tool_parsers.hunyuan_a13b_tool_parser

HunyuanA13BToolParser

Bases: ToolParser

Source code in vllm/tool_parsers/hunyuan_a13b_tool_parser.py
class HunyuanA13BToolParser(ToolParser):
    """Tool-call parser for Hunyuan-A13B output.

    Tool calls are expected as a JSON array wrapped in
    ``<tool_calls>...</tool_calls>``; each element is an object carrying a
    ``"name"`` and an ``"arguments"`` entry, e.g.::

        <tool_calls>[{"name": "f", "arguments": {"x": 1}}]</tool_calls>

    ``extract_tool_calls`` handles a finished response, while
    ``extract_tool_calls_streaming`` incrementally emits tool names and
    argument fragments, keeping its progress in ``streaming_state``.
    """

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)

        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.streamed_args: list[str] = []  # Argument text sent so far, per tool

        # For backward compatibility with tests
        self.current_tools_sent: list[bool] = []

        # For backward compatibility with serving code: the last tool-call
        # array that parsed as complete JSON while streaming.
        self.prev_tool_call_arr: list[Any] = []

        # Regex patterns for preprocessing
        self.answer_tool_calls_pattern = re.compile(
            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL
        )

        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')

        self.tool_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
        )

        # TODO: the arguments pattern only tolerates objects nested one level
        # deep; deeper nesting is not matched.
        self.tool_non_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
        )

        self.bot_string = "<tool_calls>"

        # Streaming progress: index of the tool currently being streamed, the
        # id sent for each tool, and per-tool records of what was emitted.
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def preprocess_model_output(
        self, model_output: str
    ) -> tuple[str | None, str | None]:
        """Split ``model_output`` into leading content and a tool-call payload.

        Returns ``(content, tool_calls_json)`` for the first
        ``<tool_calls>...</tool_calls>`` block that is not inside a
        ``<think>...</think>`` region and whose body parses as JSON;
        ``content`` is the text before that block. Returns
        ``(model_output, None)`` when there is no such block.
        """
        # <think> regions do not depend on the candidate being examined, so
        # locate them once rather than once per <tool_calls> match.
        think_regions = [
            (m.start(), m.end())
            for m in re.finditer(
                r"<think>(.*?)</think>", model_output, flags=re.DOTALL
            )
        ]
        for match in self.answer_tool_calls_pattern.finditer(model_output):
            start, end = match.span()
            in_think = any(
                t_start < start and end < t_end for t_start, t_end in think_regions
            )
            if in_think:
                continue
            tool_calls_content = match.group(1).strip()
            try:
                json.loads(tool_calls_content)
            except (ValueError, RecursionError):
                # Not (parseable) JSON: try the next candidate block.
                continue
            return model_output[:start], tool_calls_content
        return model_output, None

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model output.

        Args:
            model_output: The full generated text.
            request: The originating request (not consulted here).

        Returns:
            ExtractedToolCallInformation whose ``tools_called`` is True when at
            least one ``{"name": ..., "arguments": ...}`` entry was found. Text
            before the ``<tool_calls>`` block is returned as ``content``
            (``None`` when it is only whitespace). If anything fails, the raw
            output is returned as content with no tool calls.
        """
        try:
            content, potential_tool_calls = self.preprocess_model_output(model_output)

            if not potential_tool_calls:
                # Without a function call, drop the "助手:" ("Assistant:")
                # prefix that a13b's chat template can leave in the text.
                if content:
                    content = content.replace("助手:", "", 1)
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=content
                )

            tool_calls_data = json.loads(potential_tool_calls)

            # The payload must be a JSON array of calls.
            if not isinstance(tool_calls_data, list):
                logger.debug("Tool calls data is not an array")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )

            tool_calls: list[ToolCall] = []
            for call in tool_calls_data:
                # Skip entries that are not {"name": ..., "arguments": ...}.
                if (
                    not isinstance(call, dict)
                    or "name" not in call
                    or "arguments" not in call
                ):
                    continue

                arguments = call["arguments"]
                if not isinstance(arguments, str):
                    # Arguments are forwarded as a JSON string. Serialize any
                    # decoded JSON value (not only dicts) and keep non-ASCII
                    # text readable instead of \u-escaping it.
                    arguments = json.dumps(arguments, ensure_ascii=False)

                tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(name=call["name"], arguments=arguments),
                    )
                )

            if not content or not content.strip():
                # Whitespace-only content is reported as no content.
                content = None

            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=content,
            )

        except Exception:
            # Best effort: never fail the response, but leave a trace.
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool calls for streaming mode.

        If ``current_text`` (after leading whitespace and an optional
        ``<tool_calls>`` marker) does not start with ``[``, ``delta_text`` is
        passed through as plain content. Otherwise at most one delta is
        returned per call: a new tool name, or the next argument fragment of
        the current tool; None when there is nothing new to send.
        """

        start_idx = consume_space(0, current_text)
        if current_text[start_idx:].startswith(self.bot_string):
            start_idx = consume_space(start_idx + len(self.bot_string), current_text)
        if (
            not current_text
            or start_idx >= len(current_text)
            or current_text[start_idx] != "["
        ):
            return DeltaMessage(content=delta_text)

        self._try_parse_json_tools(current_text[start_idx:])

        test_delta = self._handle_test_compatibility(current_text)
        if test_delta:
            return test_delta

        # NOTE(review): every '"name": "..."' in the text counts as a tool,
        # including a "name" key inside some tool's arguments — confirm that
        # this cannot occur for the tools in use.
        name_matches = list(self.tool_name_reg.finditer(current_text))
        tool_count = len(name_matches)
        if tool_count == 0:
            return None
        self._ensure_state_arrays(tool_count)
        current_idx = self.streaming_state["current_tool_index"]

        name_delta = self._handle_tool_name_streaming(
            current_idx, tool_count, name_matches
        )
        if name_delta:
            return name_delta

        args_delta = self._handle_tool_args_streaming(
            current_text, current_idx, tool_count
        )
        if args_delta:
            return args_delta

        return None

    def _try_parse_json_tools(self, current_text: str):
        """Record the tool-call array in ``prev_tool_call_arr`` once complete.

        ``current_text`` is expected to start at the opening ``[``.
        ``raw_decode`` ignores anything after the array (such as a closing
        ``</tool_calls>`` tag), which a plain ``json.loads`` would reject.
        """
        try:
            parsed_tools, _ = json.JSONDecoder().raw_decode(current_text)
        except json.JSONDecodeError:
            # Array still incomplete; try again on a later chunk.
            return
        if isinstance(parsed_tools, list):
            self.prev_tool_call_arr = parsed_tools

    def _handle_test_compatibility(self, current_text: str):
        """Legacy path used when ``current_tools_sent`` was primed to ``[False]``.

        Emits the first tool's name once (marking it sent in both the legacy
        and the ``streaming_state`` bookkeeping); returns None otherwise.
        """
        if len(self.current_tools_sent) > 0:
            if (
                len(self.current_tools_sent) == 1
                and self.current_tools_sent[0] is False
            ):
                name_match = self.tool_name_reg.search(current_text)
                if name_match:
                    function_name = name_match.group(1)
                    tool_id = f"chatcmpl-tool-{random_uuid()}"
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=0,
                                type="function",
                                id=tool_id,
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tools_sent = [True]
                    self.current_tool_id = 0
                    self.streaming_state["current_tool_index"] = 0
                    if len(self.streaming_state["sent_tools"]) == 0:
                        self.streaming_state["sent_tools"].append(
                            {
                                "sent_name": True,
                                "sent_arguments_prefix": False,
                                "sent_arguments": "",
                            }
                        )
                    else:
                        self.streaming_state["sent_tools"][0]["sent_name"] = True
                    self.current_tool_name_sent = True
                    return delta
        return None

    def _ensure_state_arrays(self, tool_count: int):
        """Grow the per-tool bookkeeping lists to hold ``tool_count`` tools."""
        while len(self.streaming_state["sent_tools"]) < tool_count:
            self.streaming_state["sent_tools"].append(
                {
                    "sent_name": False,
                    "sent_arguments_prefix": False,
                    "sent_arguments": "",
                }
            )
        while len(self.streaming_state["tool_ids"]) < tool_count:
            self.streaming_state["tool_ids"].append(None)

    def _handle_tool_name_streaming(
        self, current_idx: int, tool_count: int, name_matches
    ):
        """Emit the name (and a fresh id) of the next tool, if it has appeared.

        Advances ``current_tool_index`` to that tool. Returns None when no
        unsent tool name follows the current one.
        """
        # NOTE(review): a newly visible tool name is emitted before anything
        # else, and argument streaming only looks at current_tool_index; if one
        # chunk completes a tool's arguments *and* the next tool's name, the
        # earlier tool's remaining arguments may never be sent — confirm.
        if current_idx == -1 or current_idx < tool_count - 1:
            next_idx = current_idx + 1
            if (
                next_idx < tool_count
                and not self.streaming_state["sent_tools"][next_idx]["sent_name"]
            ):
                self.streaming_state["current_tool_index"] = next_idx
                self.current_tool_id = next_idx
                current_idx = next_idx
                tool_name = name_matches[current_idx].group(1)
                tool_id = f"call_{current_idx}_{random_uuid()}"
                self.streaming_state["tool_ids"][current_idx] = tool_id
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(name=tool_name).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )
                self.streaming_state["sent_tools"][current_idx]["sent_name"] = True
                self.current_tool_name_sent = True
                while len(self.streamed_args) <= current_idx:
                    self.streamed_args.append("")
                return delta
        return None

    def _emit_args_fragment(self, tool_idx: int, fragment: str) -> DeltaMessage:
        """Append ``fragment`` to ``streamed_args[tool_idx]`` and wrap it in a delta."""
        while len(self.streamed_args) <= tool_idx:
            self.streamed_args.append("")
        self.streamed_args[tool_idx] += fragment
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=tool_idx,
                    function=DeltaFunctionCall(arguments=fragment).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )

    def _advance_tool_index(self, current_idx: int, tool_count: int) -> None:
        """Move ``current_tool_index`` forward if a later tool already exists."""
        if current_idx < tool_count - 1:
            self.streaming_state["current_tool_index"] += 1
            self.current_tool_id = self.streaming_state["current_tool_index"]

    def _handle_tool_args_streaming(
        self, current_text: str, current_idx: int, tool_count: int
    ):
        """Emit the next argument fragment for the tool at ``current_idx``.

        Empty arguments are sent as ``"{}"`` in one delta. Otherwise, once the
        tool's arguments object is matched (the pattern needs the closing
        brace), ``"{"`` is sent first and the rest on a following call.
        Returns None when there is nothing new to send.
        """
        if not 0 <= current_idx < tool_count:
            return None
        tool_state = self.streaming_state["sent_tools"][current_idx]

        # Empty arguments: the empty-arguments match must start at *this*
        # tool's own "name" key. Taking the first such match anywhere in the
        # text would send "{}" for a later tool whenever an earlier tool had
        # empty arguments.
        if not tool_state["sent_arguments_prefix"]:
            name_matches = list(self.tool_name_reg.finditer(current_text))
            if current_idx < len(name_matches):
                tool_start = name_matches[current_idx].start()
                if any(
                    m.start() == tool_start
                    for m in self.tool_empty_arg_reg.finditer(current_text)
                ):
                    tool_state["sent_arguments_prefix"] = True
                    tool_state["sent_arguments"] = "{}"
                    delta = self._emit_args_fragment(current_idx, "{}")
                    self._advance_tool_index(current_idx, tool_count)
                    return delta

        args_matches = list(self.tool_non_empty_arg_reg.finditer(current_text))
        if current_idx >= len(args_matches):
            return None
        # group(1) is this tool's brace-balanced arguments object. Slicing up to
        # the next '},{' instead would, for compact JSON, include the tool
        # object's own closing brace (e.g. '{"x":1}}').
        args_text = args_matches[current_idx].group(1)
        sent_args = tool_state["sent_arguments"]

        if not tool_state["sent_arguments_prefix"] and args_text.startswith("{"):
            tool_state["sent_arguments_prefix"] = True
            tool_state["sent_arguments"] = "{"
            return self._emit_args_fragment(current_idx, "{")

        if args_text.startswith(sent_args):
            args_diff = args_text[len(sent_args) :]
            if args_diff:
                tool_state["sent_arguments"] = args_text
                return self._emit_args_fragment(current_idx, args_diff)

        # Everything for this tool has been sent: move on if possible.
        if args_text.endswith("}") and args_text == sent_args:
            self._advance_tool_index(current_idx, tool_count)
        return None

extract_tool_calls

extract_tool_calls(
    model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation

Extract tool calls from a complete model output.

Source code in vllm/tool_parsers/hunyuan_a13b_tool_parser.py
def extract_tool_calls(
    self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
    """
    Extract tool calls from a complete model output.

    Args:
        model_output: The full generated text.
        request: The originating request (not consulted here).

    Returns:
        ExtractedToolCallInformation whose ``tools_called`` is True when at
        least one ``{"name": ..., "arguments": ...}`` entry was found. Text
        before the ``<tool_calls>`` block is returned as ``content`` (``None``
        when it is only whitespace). If anything fails, the raw output is
        returned as content with no tool calls.
    """
    try:
        content, potential_tool_calls = self.preprocess_model_output(model_output)

        if not potential_tool_calls:
            # Without a function call, drop the "助手:" ("Assistant:") prefix
            # that a13b's chat template can leave in the text.
            if content:
                content = content.replace("助手:", "", 1)
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=content
            )

        tool_calls_data = json.loads(potential_tool_calls)

        # The payload must be a JSON array of calls.
        if not isinstance(tool_calls_data, list):
            logger.debug("Tool calls data is not an array")
            return ExtractedToolCallInformation(
                tools_called=False,
                tool_calls=[],
                content=content or model_output,
            )

        tool_calls: list[ToolCall] = []
        for call in tool_calls_data:
            # Skip entries that are not {"name": ..., "arguments": ...}.
            if (
                not isinstance(call, dict)
                or "name" not in call
                or "arguments" not in call
            ):
                continue

            arguments = call["arguments"]
            if not isinstance(arguments, str):
                # Arguments are forwarded as a JSON string. Serialize any
                # decoded JSON value (not only dicts) and keep non-ASCII text
                # readable instead of \u-escaping it.
                arguments = json.dumps(arguments, ensure_ascii=False)

            tool_calls.append(
                ToolCall(
                    id=f"call_{random_uuid()}",
                    type="function",
                    function=FunctionCall(name=call["name"], arguments=arguments),
                )
            )

        if not content or not content.strip():
            # Whitespace-only content is reported as no content.
            content = None

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            content=content,
        )

    except Exception:
        # Best effort: never fail the response, but leave a trace.
        logger.exception("Error in extracting tool call from response.")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

extract_tool_calls_streaming

extract_tool_calls_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None

Extract tool calls for streaming mode.

Source code in vllm/tool_parsers/hunyuan_a13b_tool_parser.py
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """
    Extract tool calls for streaming mode.

    Plain content is passed through unless the accumulated text (after
    leading whitespace and an optional ``<tool_calls>`` marker) opens a JSON
    array. In that case at most one delta is produced per call — a new tool
    name first, otherwise the next argument fragment — or None if nothing
    new is available.
    """
    # Position of the first significant character, skipping the marker.
    cursor = consume_space(0, current_text)
    if current_text[cursor:].startswith(self.bot_string):
        cursor = consume_space(cursor + len(self.bot_string), current_text)

    opens_array = (
        bool(current_text)
        and cursor < len(current_text)
        and current_text[cursor] == "["
    )
    if not opens_array:
        return DeltaMessage(content=delta_text)

    self._try_parse_json_tools(current_text[cursor:])

    compat_delta = self._handle_test_compatibility(current_text)
    if compat_delta:
        return compat_delta

    found_names = list(self.tool_name_reg.finditer(current_text))
    if not found_names:
        return None
    n_tools = len(found_names)
    self._ensure_state_arrays(n_tools)
    active_idx = self.streaming_state["current_tool_index"]

    # A pending tool name takes priority over argument fragments.
    delta = self._handle_tool_name_streaming(active_idx, n_tools, found_names)
    if not delta:
        delta = self._handle_tool_args_streaming(current_text, active_idx, n_tools)
    return delta or None