# Snapshot metadata (viewer chrome, commented out so the module is importable):
# 2026-03-29 23:47:20 +08:00
# 130 lines, 4.3 KiB, Python

"""
LLM 调用模块
用于调用外部 LLM API 进行文本生成,支持重试和 JSON 解析辅助。
"""
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import requests
class LLMClient:
    """Client for an OpenAI-compatible LLM chat-completions HTTP API.

    Unset constructor arguments fall back to environment variables
    (LLM_BASE_URL, LLM_MODEL, LLM_TEMPERATURE, LLM_MAX_TOKENS). Failed
    requests are retried with linear backoff instead of raising.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        """Create a client.

        Args:
            base_url: API root, e.g. "http://host:8000/v1"; trailing slashes
                are stripped. Falls back to LLM_BASE_URL ("" if unset).
            model: Model name; falls back to LLM_MODEL, then "qwen3.5-35b".
            temperature: Sampling temperature; falls back to LLM_TEMPERATURE,
                then 0.3. An explicit 0.0 is respected (not treated as unset).
            max_tokens: Completion token cap; falls back to LLM_MAX_TOKENS,
                then 2048.
            api_key: Optional bearer token for the Authorization header.
            timeout: Per-request timeout in seconds.
            max_retries: Total number of attempts before giving up.
        """
        # Fix: the original called base_url.rstrip("/") unconditionally and
        # raised AttributeError when base_url was None (its documented default).
        self.base_url = (base_url or os.getenv("LLM_BASE_URL", "")).rstrip("/")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        self.temperature = (
            temperature
            if temperature is not None
            else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        )
        self.max_tokens = (
            max_tokens
            if max_tokens is not None
            else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        )
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key
        self._headers = {"Content-Type": "application/json"}
        if api_key:
            self._headers["Authorization"] = f"Bearer {api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """POST the messages to /chat/completions, retrying on request errors.

        Args:
            messages: OpenAI-style message dicts ({"role": ..., "content": ...}).
            temperature: Per-call override; None uses the client default.
            max_tokens: Per-call override; None uses the client default.

        Returns:
            The decoded JSON response on success, or — after all retries are
            exhausted — a dict with "error" (str) and "status_code" keys.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            # Fix: the original used `temperature or self.temperature`, which
            # silently discarded an explicit 0 / 0.0 per-call override.
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "stream": False,
            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }
        last_error: Optional[Exception] = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    url, headers=self._headers, json=payload, timeout=self.timeout
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ...
                    time.sleep(1 * (attempt + 1))
        # All attempts failed; report the last error instead of raising.
        # RequestException carries .response (may be None when no HTTP
        # response was received, e.g. connection errors / timeouts).
        failed_response = getattr(last_error, "response", None)
        return {
            "error": str(last_error),
            "status_code": getattr(failed_response, "status_code", None),
        }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Simplified interface: one system + one user message, normalized result.

        Returns:
            On success: {"success": True, "error": None, "content": str,
            "usage": dict, "model": str}. On failure: {"success": False,
            "error": str, "content": None}.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        result = self.chat_completion(messages, temperature, max_tokens)
        if "error" in result:
            return {"success": False, "error": result["error"], "content": None}
        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            message = choice.get("message", {})
            content = message.get("content")
            if content is None:
                # Some backends put text under "reasoning" instead of "content".
                # Fix: the original defaulted this lookup to "", which made the
                # None check below unreachable; a response with neither key was
                # wrongly reported as success with empty content.
                content = message.get("reasoning")
            if content is None:
                return {
                    "success": False,
                    "error": "No content found in response",
                    "content": None,
                }
            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model),
            }
        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None,
        }

    @staticmethod
    def parse_json_response(content: str) -> Dict[str, Any]:
        """Strip Markdown code fences (``` / ```json) and parse the rest as JSON.

        Raises:
            json.JSONDecodeError: if the cleaned text is not valid JSON.
        """
        cleaned = content.strip()
        if cleaned.startswith("```"):
            cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
            cleaned = re.sub(r"\s*```$", "", cleaned)
        return json.loads(cleaned)