"""
|
||
LLM 调用模块
|
||
用于调用外部 LLM API 进行文本生成
|
||
"""
|
||
|
||
import os
from typing import Any, Dict, List, Optional

import requests


class LLMClient:
    """LLM API client."""

    def __init__(
        self,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        api_key: Optional[str] = None
    ):
        """
        Initialize the LLM client.

        Args:
            base_url: API base URL (defaults to the LLM_BASE_URL environment variable)
            model: model name (defaults to LLM_MODEL)
            temperature: sampling temperature (defaults to LLM_TEMPERATURE)
            max_tokens: maximum number of tokens to generate (defaults to LLM_MAX_TOKENS)
            api_key: API key (optional, read from LLM_API_KEY)
        """
        self.base_url = base_url or os.getenv("LLM_BASE_URL", "https://glm47flash.cloyir.com/v1")
        self.model = model or os.getenv("LLM_MODEL", "qwen3.5-35b")
        self.temperature = temperature if temperature is not None else float(os.getenv("LLM_TEMPERATURE", "0.3"))
        self.max_tokens = max_tokens if max_tokens is not None else int(os.getenv("LLM_MAX_TOKENS", "2048"))
        self.api_key = api_key or os.getenv("LLM_API_KEY", "")

        self._headers = {
            "Content-Type": "application/json"
        }
        if self.api_key:
            self._headers["Authorization"] = f"Bearer {self.api_key}"

    def chat_completion(
        self,
        messages: List[Dict[str, str]],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Send a chat request and return the response.

        Args:
            messages: message list of the form [{"role": "system/user/assistant", "content": "..."}]
            temperature: sampling temperature (overrides the default)
            max_tokens: maximum number of tokens to generate (overrides the default)

        Returns:
            The API response as a dictionary.
        """
url = f"{self.base_url}/chat/completions"
|
||
|
||
payload = {
|
||
"model": self.model,
|
||
"messages": messages,
|
||
"temperature": temperature or self.temperature,
|
||
"max_tokens": max_tokens or self.max_tokens,
|
||
"stream": False,
|
||
"extra_body": {
|
||
"chat_template_kwargs": {
|
||
"enable_thinking": False
|
||
}
|
||
}
|
||
}

        try:
            response = requests.post(
                url,
                headers=self._headers,
                json=payload,
                timeout=120
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # e.response may be None (e.g. for connection errors), so guard both levels.
            return {
                "error": str(e),
                "status_code": getattr(getattr(e, "response", None), "status_code", None)
            }

    def generate_response(
        self,
        system_prompt: str,
        user_prompt: str,
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Generate a response (simplified interface).

        Args:
            system_prompt: system prompt
            user_prompt: user prompt
            temperature: sampling temperature
            max_tokens: maximum number of tokens to generate

        Returns:
            A dictionary containing the response content.
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        result = self.chat_completion(messages, temperature, max_tokens)

if "error" in result:
|
||
return {
|
||
"success": False,
|
||
"error": result["error"],
|
||
"content": None
|
||
}
|
||
|
||
if "choices" in result and len(result["choices"]) > 0:
|
||
choice = result["choices"][0]
|
||
message = choice.get("message", {})
|
||
|
||
# 尝试从不同位置获取 content
|
||
content = message.get("content")
|
||
|
||
# 如果 content 是 None,检查是否有 reasoning 或其他字段
|
||
if content is None:
|
||
# 有些模型可能将内容放在 reasoning 字段中
|
||
reasoning = message.get("reasoning", "")
|
||
if reasoning:
|
||
content = reasoning
|
||
|
||
# 如果仍然没有内容,返回错误
|
||
if content is None:
|
||
return {
|
||
"success": False,
|
||
"error": "No content found in response",
|
||
"content": None
|
||
}
|
||
|
||
            return {
                "success": True,
                "error": None,
                "content": content,
                "usage": result.get("usage", {}),
                "model": result.get("model", self.model)
            }

        return {
            "success": False,
            "error": f"Unexpected response format: {result}",
            "content": None
        }


# Module-level LLM client factory (configurable via environment variables).
def get_llm_client() -> LLMClient:
    """Return an LLM client instance (configuration is read from environment variables)."""
    return LLMClient()  # all defaults are read from environment variables
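

# Usage sketch (illustrative only, not part of the module's public surface).
# It assumes a reachable OpenAI-compatible endpoint; the base URL, model name,
# and key all come from the LLM_* environment variables documented above.
if __name__ == "__main__":
    client = get_llm_client()
    result = client.generate_response(
        system_prompt="You are a concise assistant.",
        user_prompt="Say hello in one sentence.",
        temperature=0.0  # an explicit 0.0 is honored thanks to the None checks
    )
    if result["success"]:
        print(result["content"])
        print("usage:", result["usage"])
    else:
        print("request failed:", result["error"])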