client.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # coding=utf-8
  2. """
  3. AI 客户端模块
  4. 基于 LiteLLM 的统一 AI 模型接口
  5. 支持 100+ AI 提供商(OpenAI、DeepSeek、Gemini、Claude、国内模型等)
  6. """
  7. import os
  8. from typing import Any, Dict, List
  9. from litellm import completion
  10. class AIClient:
  11. """统一的 AI 客户端(基于 LiteLLM)"""
  12. def __init__(self, config: Dict[str, Any]):
  13. """
  14. 初始化 AI 客户端
  15. Args:
  16. config: AI 配置字典
  17. - MODEL: 模型标识(格式: provider/model_name)
  18. - API_KEY: API 密钥
  19. - API_BASE: API 基础 URL(可选)
  20. - TEMPERATURE: 采样温度
  21. - MAX_TOKENS: 最大生成 token 数
  22. - TIMEOUT: 请求超时时间(秒)
  23. - NUM_RETRIES: 重试次数(可选)
  24. - FALLBACK_MODELS: 备用模型列表(可选)
  25. """
  26. self.model = config.get("MODEL", "deepseek/deepseek-chat")
  27. self.api_key = config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
  28. self.api_base = config.get("API_BASE", "")
  29. self.temperature = config.get("TEMPERATURE", 1.0)
  30. self.max_tokens = config.get("MAX_TOKENS", 5000)
  31. self.timeout = config.get("TIMEOUT", 120)
  32. self.num_retries = config.get("NUM_RETRIES", 2)
  33. self.fallback_models = config.get("FALLBACK_MODELS", [])
  34. def chat(
  35. self,
  36. messages: List[Dict[str, str]],
  37. **kwargs
  38. ) -> str:
  39. """
  40. 调用 AI 模型进行对话
  41. Args:
  42. messages: 消息列表,格式: [{"role": "system/user/assistant", "content": "..."}]
  43. **kwargs: 额外参数,会覆盖默认配置
  44. Returns:
  45. str: AI 响应内容
  46. Raises:
  47. Exception: API 调用失败时抛出异常
  48. """
  49. # 构建请求参数
  50. params = {
  51. "model": self.model,
  52. "messages": messages,
  53. "temperature": kwargs.get("temperature", self.temperature),
  54. "timeout": kwargs.get("timeout", self.timeout),
  55. "num_retries": kwargs.get("num_retries", self.num_retries),
  56. }
  57. # 添加 API Key
  58. if self.api_key:
  59. params["api_key"] = self.api_key
  60. # 添加 API Base(如果配置了)
  61. if self.api_base:
  62. params["api_base"] = self.api_base
  63. # 添加 max_tokens(如果配置了且不为 0)
  64. max_tokens = kwargs.get("max_tokens", self.max_tokens)
  65. if max_tokens and max_tokens > 0:
  66. params["max_tokens"] = max_tokens
  67. # 添加 fallback 模型(如果配置了)
  68. if self.fallback_models:
  69. params["fallbacks"] = self.fallback_models
  70. # 合并其他额外参数
  71. for key, value in kwargs.items():
  72. if key not in params:
  73. params[key] = value
  74. # 调用 LiteLLM
  75. response = completion(**params)
  76. # 提取响应内容
  77. # 某些模型/提供商返回 list(内容块)而非 str,统一转为 str
  78. content = response.choices[0].message.content
  79. if isinstance(content, list):
  80. content = "\n".join(
  81. item.get("text", str(item)) if isinstance(item, dict) else str(item)
  82. for item in content
  83. )
  84. return content or ""
  85. def validate_config(self) -> tuple[bool, str]:
  86. """
  87. 验证配置是否有效
  88. Returns:
  89. tuple: (是否有效, 错误信息)
  90. """
  91. if not self.model:
  92. return False, "未配置 AI 模型(model)"
  93. if not self.api_key:
  94. return False, "未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  95. # 验证模型格式(应该包含 provider/model)
  96. if "/" not in self.model:
  97. return False, f"模型格式错误: {self.model},应为 'provider/model' 格式(如 'deepseek/deepseek-chat')"
  98. return True, ""