"""
LLM 调用模块

用于调用外部 LLM API 进行文本生成,支持重试和 JSON 解析辅助。
"""
|
|
|
|
import json
import os
import re
import time
from typing import Any, Dict, List, Optional

import requests
|
|
|
|
|
|
class LLMClient:
    """Client for an OpenAI-compatible LLM chat-completions API.

    Constructor arguments fall back to environment variables
    (LLM_BASE_URL, LLM_MODEL, LLM_TEMPERATURE, LLM_MAX_TOKENS) when omitted.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None,
        timeout: int = 120,
        max_retries: int = 3,
    ):
        """Initialize the client.

        Args:
            base_url: API root URL; trailing slashes are stripped.
                Falls back to the LLM_BASE_URL env var when None.
            model: Model name; falls back to LLM_MODEL ("qwen3.5-35b").
            temperature: Sampling temperature; falls back to LLM_TEMPERATURE.
            max_tokens: Generation token cap; falls back to LLM_MAX_TOKENS.
            api_key: Optional bearer token sent as an Authorization header.
            timeout: Per-request timeout in seconds.
            max_retries: Total number of request attempts before giving up.
        """
        # BUG FIX: the original called base_url.rstrip("/") unconditionally,
        # raising AttributeError when base_url was None (its default value).
        resolved_url = base_url if base_url is not None else os.getenv("LLM_BASE_URL", "")
        self.base_url = resolved_url.rstrip("/")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        self.temperature = (
            temperature
            if temperature is not None
            else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        )
        self.max_tokens = (
            max_tokens
            if max_tokens is not None
            else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        )
        self.timeout = timeout
        self.max_retries = max_retries
        self.api_key = api_key

        self._headers = {"Content-Type": "application/json"}
        if api_key:
            self._headers["Authorization"] = f"Bearer {api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """POST a chat-completions request, retrying on transport errors.

        Args:
            messages: Chat messages as [{"role": ..., "content": ...}, ...].
            temperature: Per-call override of the client temperature.
            max_tokens: Per-call override of the client token cap.

        Returns:
            The decoded JSON response on success, or a dict with "error"
            and "status_code" keys after all retries are exhausted.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            # BUG FIX: the original used `temperature or self.temperature`,
            # which silently ignored an explicit 0 / 0.0 override; compare
            # against None instead.  Same for max_tokens.
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "stream": False,
            "extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
        }

        last_error: Optional[Exception] = None
        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    url, headers=self._headers, json=payload, timeout=self.timeout
                )
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                last_error = e
                if attempt < self.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ...
                    time.sleep(1 * (attempt + 1))

        return {
            "error": str(last_error),
            "status_code": (
                getattr(last_error.response, "status_code", None)
                if hasattr(last_error, "response")
                else None
            ),
        }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Generate a response from a system/user prompt pair (simplified API).

        Returns:
            {"success": bool, "error": Optional[str], "content": Optional[str],
             and on success also "usage" and "model"}.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        result = self.chat_completion(messages, temperature, max_tokens)

        if "error" in result:
            return {"success": False, "error": result["error"], "content": None}

        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            message = choice.get("message", {})
            content = message.get("content")
            if content is None:
                # BUG FIX: the original used message.get("reasoning", ""),
                # whose "" default made the None check below unreachable;
                # use a None default so missing content is reported.
                content = message.get("reasoning")
            if content is None:
                return {
                    "success": False,
                    "error": "No content found in response",
                    "content": None,
                }
            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model),
            }

        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None,
        }

    @staticmethod
    def parse_json_response(content: str) -> Dict[str, Any]:
        """Strip Markdown code fences from an LLM reply and parse it as JSON.

        Raises:
            json.JSONDecodeError: if the cleaned text is not valid JSON.
        """
        cleaned = content.strip()
        if cleaned.startswith("```"):
            # Remove a leading ```/```json fence and a trailing ``` fence.
            cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
            cleaned = re.sub(r"\s*```$", "", cleaned)
        return json.loads(cleaned)
|