"""
LLM 调用模块
用于调用外部 LLM API 进行文本生成
"""
import os
import json
from typing import Dict, Any, Optional, List
import requests
class LLMClient:
    """Client for an OpenAI-compatible chat-completions HTTP API.

    All connection settings fall back to environment variables so a
    zero-argument ``LLMClient()`` is fully configured from the environment.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None
    ):
        """
        Initialize the LLM client.

        Args:
            base_url: API base URL (defaults to env var LLM_BASE_URL).
            model: Model name (defaults to env var LLM_MODEL).
            temperature: Sampling temperature (defaults to env var LLM_TEMPERATURE).
            max_tokens: Maximum generation length (defaults to env var LLM_MAX_TOKENS).
            api_key: API key; optional (defaults to env var LLM_API_KEY).
        """
        # `or` is fine for the string options (empty string == "not set"),
        # but the numeric options need explicit None checks so that an
        # intentional 0 / 0.0 passed by the caller is honored.
        self.base_url = base_url or os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        self.api_key = api_key or os.getenv("LLM_API_KEY", "")
        self._headers = {
            "Content-Type": "application/json"
        }
        # Only attach auth when a key is actually configured.
        if self.api_key:
            self._headers["Authorization"] = f"Bearer {self.api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Send a chat request and return the raw API response.

        Args:
            messages: Message list shaped as
                [{"role": "system/user/assistant", "content": "..."}].
            temperature: Sampling temperature (overrides the instance default).
            max_tokens: Maximum generation length (overrides the instance default).

        Returns:
            The parsed JSON response dict on success, or
            {"error": str, "status_code": int | None} on any request failure.
        """
        url = f"{self.base_url}/chat/completions"
        payload = {
            "model": self.model,
            "messages": messages,
            # BUGFIX: explicit None checks — `temperature or self.temperature`
            # would silently replace a caller-supplied 0.0 (greedy decoding)
            # with the default, and likewise for max_tokens=0.
            "temperature": temperature if temperature is not None else self.temperature,
            "max_tokens": max_tokens if max_tokens is not None else self.max_tokens,
            "stream": False,
            "extra_body": {
                "chat_template_kwargs": {
                    "enable_thinking": False
                }
            }
        }
        try:
            response = requests.post(
                url,
                headers=self._headers,
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # RequestException carries a `response` attribute, which is None
            # for connection-level failures (no HTTP response received).
            resp = getattr(e, "response", None)
            return {
                "error": str(e),
                "status_code": resp.status_code if resp is not None else None
            }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Generate a response (simplified interface).

        Args:
            system_prompt: System prompt text.
            user_prompt: User prompt text.
            temperature: Sampling temperature (overrides the instance default).
            max_tokens: Maximum generation length (overrides the instance default).

        Returns:
            Dict with keys "success", "error", "content", and on success
            additionally "usage" and "model".
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        result = self.chat_completion(messages, temperature, max_tokens)
        if "error" in result:
            return {
                "success": False,
                "error": result["error"],
                "content": None
            }
        if "choices" in result and len(result["choices"]) > 0:
            choice = result["choices"][0]
            message = choice.get("message", {})
            # Try the standard location for the generated text first.
            content = message.get("content")
            if content is None:
                # Some models place the text in a "reasoning" field instead.
                reasoning = message.get("reasoning", "")
                if reasoning:
                    content = reasoning
            # Still nothing usable — report the failure explicitly.
            if content is None:
                return {
                    "success": False,
                    "error": "No content found in response",
                    "content": None
                }
            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model)
            }
        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None
        }
# Module-level factory for an LLM client (configurable via environment variables).
def get_llm_client() -> LLMClient:
    """Build and return an ``LLMClient`` configured entirely from environment variables."""
    client = LLMClient()  # every default is resolved from the environment in __init__
    return client