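"""
Agent 配置模块 (Agent configuration module).

支持从环境变量(含项目根目录下的 .env 文件)、构造函数参数加载配置。
Configuration can be loaded from environment variables (including a .env file
that is read when this module is imported) or passed as constructor arguments.
"""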
|
|
|
|
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Optional
|
|
|
|
|
|
def _parse_env_line(line: str) -> Optional[tuple]:
    """Parse a single ``.env`` line into a ``(key, value)`` pair.

    Returns ``None`` for blank lines, ``#`` comment lines, lines without ``=``
    and lines whose key is empty (``os.environ`` rejects an empty name).
    Surrounding whitespace is stripped from key and value, a leading
    ``export `` on the key is dropped, and a value wrapped in one matching
    pair of single or double quotes is unquoted.  Only the first ``=`` splits,
    so values may themselves contain ``=``.
    """
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
        return None
    key, value = line.split("=", 1)
    key = key.strip()
    # Allow shell-style "export KEY=VALUE" lines.
    if key.startswith("export "):
        key = key[len("export "):].strip()
    if not key:
        return None
    value = value.strip()
    if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
        value = value[1:-1]
    return key, value


def _load_env_file(paths: Optional[Iterable[Path]] = None) -> None:
    """Load ``KEY=VALUE`` pairs from the project's ``.env`` file into ``os.environ``.

    Args:
        paths: Candidate ``.env`` locations, tried in order.  When omitted,
            the candidates are ``.env`` in the current working directory,
            then ``.env`` two and three directories above this module's
            own directory.

    Only the first candidate that is a regular file is read; the rest are
    ignored.  Variables already present in ``os.environ`` are never
    overwritten, so the real environment takes precedence over the file.
    See ``_parse_env_line`` for the accepted line syntax.
    """
    if paths is None:
        module_dir = Path(__file__).parent
        paths = (
            Path.cwd() / ".env",
            module_dir.parent.parent / ".env",
            module_dir.parent.parent.parent / ".env",
        )

    for env_path in paths:
        env_path = Path(env_path)
        # is_file() rather than exists(): a directory named ".env" would
        # otherwise make open() raise IsADirectoryError at import time.
        if not env_path.is_file():
            continue
        # "utf-8-sig" drops a leading BOM (e.g. from Windows editors) that
        # would otherwise become part of the first key.
        with open(env_path, "r", encoding="utf-8-sig") as f:
            for raw_line in f:
                parsed = _parse_env_line(raw_line)
                if parsed is None:
                    continue
                key, value = parsed
                # Only set the variable when it is not already in the environment.
                os.environ.setdefault(key, value)
        break
|
|
|
|
|
|
# Read the .env file automatically when this module is imported, so that
# variables defined there are already in os.environ when AgentConfig.from_env()
# looks them up.
_load_env_file()
|
|
|
|
|
|
@dataclass
class AgentConfig:
    """Configuration for the Causal Inference Agent.

    Build it directly (keyword arguments replace the defaults below) or via
    :meth:`from_env`, which reads each field from the environment variable
    named after the field in upper case.
    """

    # LLM settings. from_env() falls back to these defaults when the matching
    # environment variable is not set.
    llm_base_url: str = "https://glm47flash.cloyir.com/v1"
    llm_model: str = "qwen3.5-35b"
    llm_temperature: float = 0.3
    llm_max_tokens: int = 2048
    llm_api_key: Optional[str] = None
    llm_timeout: int = 120  # NOTE(review): presumably seconds — confirm against the LLM client
    llm_max_retries: int = 3

    # Statistical screening settings.
    corr_threshold: float = 0.1
    alpha: float = 0.05  # NOTE(review): looks like a significance level — confirm with callers
    bootstrap_iterations: int = 500
    bootstrap_alpha: float = 0.05

    # Path settings.
    log_path: str = "causal_analysis.log.md"
    output_dir: str = "."

    @classmethod
    def from_env(cls, **overrides) -> "AgentConfig":
        """Build a configuration from environment variables.

        Each field is read from the upper-cased variable of the same name
        (``llm_base_url`` -> ``LLM_BASE_URL``, ``alpha`` -> ``ALPHA``, ...).
        Unset variables fall back to the class default.  Numeric fields are
        converted with ``int``/``float``; string fields are used verbatim.

        Args:
            **overrides: Field values that take precedence over both the
                environment and the defaults.  The environment variable of
                an overridden field is not read at all.

        Returns:
            A new instance of ``cls``.

        Raises:
            ValueError: If a numeric environment variable cannot be parsed.
                The message names the offending variable.
            TypeError: If ``overrides`` contains a name that is not a field.
        """
        # (field name, environment variable, converter applied to its raw text)
        env_spec = (
            ("llm_base_url", "LLM_BASE_URL", str),
            ("llm_model", "LLM_MODEL", str),
            ("llm_temperature", "LLM_TEMPERATURE", float),
            ("llm_max_tokens", "LLM_MAX_TOKENS", int),
            ("llm_api_key", "LLM_API_KEY", str),
            ("llm_timeout", "LLM_TIMEOUT", int),
            ("llm_max_retries", "LLM_MAX_RETRIES", int),
            ("corr_threshold", "CORR_THRESHOLD", float),
            ("alpha", "ALPHA", float),
            ("bootstrap_iterations", "BOOTSTRAP_ITERATIONS", int),
            ("bootstrap_alpha", "BOOTSTRAP_ALPHA", float),
            ("log_path", "LOG_PATH", str),
            ("output_dir", "OUTPUT_DIR", str),
        )

        kwargs = {}
        for field_name, env_name, convert in env_spec:
            if field_name in overrides:
                continue
            raw = os.getenv(env_name)
            if raw is None:
                kwargs[field_name] = getattr(cls, field_name)
                continue
            try:
                kwargs[field_name] = convert(raw)
            except ValueError as err:
                # Without this, the error is only "could not convert string
                # to float: ''" with no hint which variable is wrong.
                raise ValueError(
                    f"Invalid value for environment variable {env_name}: {raw!r}"
                ) from err

        kwargs.update(overrides)
        return cls(**kwargs)
|