2026-03-29 23:47:20 +08:00

82 lines
2.8 KiB
Python

"""
Agent 配置模块
支持从环境变量、构造函数参数加载配置。
"""
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
def _load_env_file():
"""加载项目根目录的 .env 文件"""
# 尝试多个可能的位置
possible_paths = [
Path.cwd() / ".env",
Path(__file__).parent.parent.parent / ".env",
Path(__file__).parent.parent.parent.parent / ".env",
]
for env_path in possible_paths:
if env_path.exists():
with open(env_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, value = line.split("=", 1)
# 只在环境变量不存在时才设置
if key not in os.environ:
os.environ[key] = value
break
# 模块加载时自动读取 .env 文件
_load_env_file()
@dataclass
class AgentConfig:
    """Configuration for the Causal Inference Agent.

    Instances can be built directly via the constructor or from environment
    variables via :meth:`from_env`. The defaults declared below are used only
    when the corresponding environment variable is unset or blank.
    """

    # LLM settings (defaults apply only when neither the environment nor .env
    # provides a value).
    llm_base_url: str = "https://glm47flash.cloyir.com/v1"
    llm_model: str = "qwen3.5-35b"
    llm_temperature: float = 0.3
    llm_max_tokens: int = 2048
    # repr=False keeps the secret out of repr()/logging of the config object.
    llm_api_key: Optional[str] = field(default=None, repr=False)
    llm_timeout: int = 120
    llm_max_retries: int = 3
    # Statistical screening settings
    corr_threshold: float = 0.1
    alpha: float = 0.05
    bootstrap_iterations: int = 500
    bootstrap_alpha: float = 0.05
    # Path settings
    log_path: str = "causal_analysis.log.md"
    output_dir: str = "."

    @staticmethod
    def _read_env(name: str, default, cast=str):
        """Return env var *name* converted with *cast*, or *default* if unset/blank.

        The raw value is stripped of surrounding whitespace before conversion.

        Raises:
            ValueError: if *cast* rejects the value; the message names *name*.
        """
        raw = os.environ.get(name)
        if raw is None or not raw.strip():
            return default
        try:
            return cast(raw.strip())
        except ValueError as err:
            raise ValueError(
                f"Invalid value for environment variable {name}: {raw!r} "
                f"(expected {cast.__name__})"
            ) from err

    @classmethod
    def from_env(cls) -> "AgentConfig":
        """Build a configuration from environment variables.

        Each field is read from its upper-case environment variable
        (``LLM_BASE_URL``, ``LLM_MODEL``, ..., ``LOG_PATH``, ``OUTPUT_DIR``).
        An unset or whitespace-only variable falls back to the class default.
        Numeric fields are converted with ``int``/``float``.

        Returns:
            A new ``AgentConfig`` instance.

        Raises:
            ValueError: if a numeric variable cannot be parsed; the message
                names the offending variable.
        """
        env = cls._read_env
        return cls(
            llm_base_url=env("LLM_BASE_URL", cls.llm_base_url),
            llm_model=env("LLM_MODEL", cls.llm_model),
            llm_temperature=env("LLM_TEMPERATURE", cls.llm_temperature, float),
            llm_max_tokens=env("LLM_MAX_TOKENS", cls.llm_max_tokens, int),
            llm_api_key=env("LLM_API_KEY", cls.llm_api_key),
            llm_timeout=env("LLM_TIMEOUT", cls.llm_timeout, int),
            llm_max_retries=env("LLM_MAX_RETRIES", cls.llm_max_retries, int),
            corr_threshold=env("CORR_THRESHOLD", cls.corr_threshold, float),
            alpha=env("ALPHA", cls.alpha, float),
            bootstrap_iterations=env(
                "BOOTSTRAP_ITERATIONS", cls.bootstrap_iterations, int
            ),
            bootstrap_alpha=env("BOOTSTRAP_ALPHA", cls.bootstrap_alpha, float),
            log_path=env("LOG_PATH", cls.log_path),
            output_dir=env("OUTPUT_DIR", cls.output_dir),
        )