|
@@ -3,15 +3,16 @@
|
|
|
AI 分析器模块
|
|
AI 分析器模块
|
|
|
|
|
|
|
|
调用 AI 大模型对热点新闻进行深度分析
|
|
调用 AI 大模型对热点新闻进行深度分析
|
|
|
-支持 OpenAI、Google Gemini、Azure OpenAI 等兼容接口
|
|
|
|
|
|
|
+基于 LiteLLM 统一接口,支持 100+ AI 提供商
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
import json
|
|
import json
|
|
|
-import os
|
|
|
|
|
from dataclasses import dataclass
|
|
from dataclasses import dataclass
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
|
|
|
|
|
|
+from trendradar.ai.client import AIClient
|
|
|
|
|
+
|
|
|
|
|
|
|
|
@dataclass
|
|
@dataclass
|
|
|
class AIAnalysisResult:
|
|
class AIAnalysisResult:
|
|
@@ -50,7 +51,7 @@ class AIAnalyzer:
|
|
|
初始化 AI 分析器
|
|
初始化 AI 分析器
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
- ai_config: AI 模型共享配置(provider, api_key, model 等)
|
|
|
|
|
|
|
+ ai_config: AI 模型配置(LiteLLM 格式)
|
|
|
analysis_config: AI 分析功能配置(language, prompt_file 等)
|
|
analysis_config: AI 分析功能配置(language, prompt_file 等)
|
|
|
get_time_func: 获取当前时间的函数
|
|
get_time_func: 获取当前时间的函数
|
|
|
debug: 是否开启调试模式
|
|
debug: 是否开启调试模式
|
|
@@ -60,14 +61,13 @@ class AIAnalyzer:
|
|
|
self.get_time_func = get_time_func
|
|
self.get_time_func = get_time_func
|
|
|
self.debug = debug
|
|
self.debug = debug
|
|
|
|
|
|
|
|
- # 从共享配置获取模型参数
|
|
|
|
|
- self.api_key = ai_config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
|
|
|
|
|
- self.provider = ai_config.get("PROVIDER", "deepseek")
|
|
|
|
|
- self.model = ai_config.get("MODEL", "deepseek-chat")
|
|
|
|
|
- self.base_url = ai_config.get("BASE_URL", "")
|
|
|
|
|
- self.timeout = ai_config.get("TIMEOUT", 90)
|
|
|
|
|
- self.temperature = ai_config.get("TEMPERATURE", 1.0)
|
|
|
|
|
- self.max_tokens = ai_config.get("MAX_TOKENS", 5000)
|
|
|
|
|
|
|
+ # 创建 AI 客户端(基于 LiteLLM)
|
|
|
|
|
+ self.client = AIClient(ai_config)
|
|
|
|
|
+
|
|
|
|
|
+ # 验证配置
|
|
|
|
|
+ valid, error = self.client.validate_config()
|
|
|
|
|
+ if not valid:
|
|
|
|
|
+ print(f"[AI] 配置警告: {error}")
|
|
|
|
|
|
|
|
# 从分析配置获取功能参数
|
|
# 从分析配置获取功能参数
|
|
|
self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
|
|
self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
|
|
@@ -75,18 +75,6 @@ class AIAnalyzer:
|
|
|
self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
|
|
self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
|
|
|
self.language = analysis_config.get("LANGUAGE", "Chinese")
|
|
self.language = analysis_config.get("LANGUAGE", "Chinese")
|
|
|
|
|
|
|
|
- # 额外的自定义参数(支持字典或 JSON 字符串)
|
|
|
|
|
- self.extra_params = ai_config.get("EXTRA_PARAMS", {})
|
|
|
|
|
- if isinstance(self.extra_params, str) and self.extra_params.strip():
|
|
|
|
|
- try:
|
|
|
|
|
- self.extra_params = json.loads(self.extra_params)
|
|
|
|
|
- except json.JSONDecodeError:
|
|
|
|
|
- print(f"[AI] 解析 extra_params 失败,将忽略: {self.extra_params}")
|
|
|
|
|
- self.extra_params = {}
|
|
|
|
|
-
|
|
|
|
|
- if not isinstance(self.extra_params, dict):
|
|
|
|
|
- self.extra_params = {}
|
|
|
|
|
-
|
|
|
|
|
# 加载提示词模板
|
|
# 加载提示词模板
|
|
|
self.system_prompt, self.user_prompt_template = self._load_prompt_template(
|
|
self.system_prompt, self.user_prompt_template = self._load_prompt_template(
|
|
|
analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
|
|
analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
|
|
@@ -146,7 +134,7 @@ class AIAnalyzer:
|
|
|
Returns:
|
|
Returns:
|
|
|
AIAnalysisResult: 分析结果
|
|
AIAnalysisResult: 分析结果
|
|
|
"""
|
|
"""
|
|
|
- if not self.api_key:
|
|
|
|
|
|
|
+ if not self.client.api_key:
|
|
|
return AIAnalysisResult(
|
|
return AIAnalysisResult(
|
|
|
success=False,
|
|
success=False,
|
|
|
error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
|
|
error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
|
|
@@ -198,9 +186,9 @@ class AIAnalyzer:
|
|
|
print(user_prompt)
|
|
print(user_prompt)
|
|
|
print("=" * 80 + "\n")
|
|
print("=" * 80 + "\n")
|
|
|
|
|
|
|
|
- # 调用 AI API
|
|
|
|
|
|
|
+ # 调用 AI API(使用 LiteLLM)
|
|
|
try:
|
|
try:
|
|
|
- response = self._call_ai_api(user_prompt)
|
|
|
|
|
|
|
+ response = self._call_ai(user_prompt)
|
|
|
result = self._parse_response(response)
|
|
result = self._parse_response(response)
|
|
|
|
|
|
|
|
# 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
|
|
# 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
|
|
@@ -215,30 +203,13 @@ class AIAnalyzer:
|
|
|
result.max_news_limit = self.max_news
|
|
result.max_news_limit = self.max_news
|
|
|
return result
|
|
return result
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
- import requests
|
|
|
|
|
error_type = type(e).__name__
|
|
error_type = type(e).__name__
|
|
|
error_msg = str(e)
|
|
error_msg = str(e)
|
|
|
|
|
|
|
|
- # 针对不同错误类型提供更友好的提示
|
|
|
|
|
- if isinstance(e, requests.exceptions.Timeout):
|
|
|
|
|
- friendly_msg = f"AI API 请求超时({self.timeout}秒),请检查网络或增加超时时间"
|
|
|
|
|
- elif isinstance(e, requests.exceptions.ConnectionError):
|
|
|
|
|
- friendly_msg = f"无法连接到 AI API ({self.base_url or self.provider}),请检查网络和 API 地址"
|
|
|
|
|
- elif isinstance(e, requests.exceptions.HTTPError):
|
|
|
|
|
- status_code = e.response.status_code if hasattr(e, 'response') and e.response else "未知"
|
|
|
|
|
- if status_code == 401:
|
|
|
|
|
- friendly_msg = "AI API 认证失败,请检查 API Key 是否正确"
|
|
|
|
|
- elif status_code == 429:
|
|
|
|
|
- friendly_msg = "AI API 请求频率过高,请稍后重试"
|
|
|
|
|
- elif status_code == 500:
|
|
|
|
|
- friendly_msg = "AI API 服务器内部错误,请稍后重试"
|
|
|
|
|
- else:
|
|
|
|
|
- friendly_msg = f"AI API 返回错误 (HTTP {status_code}): {error_msg[:100]}"
|
|
|
|
|
- else:
|
|
|
|
|
- # 截断过长的错误消息
|
|
|
|
|
- if len(error_msg) > 150:
|
|
|
|
|
- error_msg = error_msg[:150] + "..."
|
|
|
|
|
- friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
|
|
|
|
|
|
|
+ # 截断过长的错误消息
|
|
|
|
|
+ if len(error_msg) > 200:
|
|
|
|
|
+ error_msg = error_msg[:200] + "..."
|
|
|
|
|
+ friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
|
|
|
|
|
|
|
|
return AIAnalysisResult(
|
|
return AIAnalysisResult(
|
|
|
success=False,
|
|
success=False,
|
|
@@ -364,6 +335,15 @@ class AIAnalyzer:
|
|
|
|
|
|
|
|
return news_content, rss_content, hotlist_total, rss_total, total_count
|
|
return news_content, rss_content, hotlist_total, rss_total, total_count
|
|
|
|
|
|
|
|
|
|
+ def _call_ai(self, user_prompt: str) -> str:
|
|
|
|
|
+ """调用 AI API(使用 LiteLLM)"""
|
|
|
|
|
+ messages = []
|
|
|
|
|
+ if self.system_prompt:
|
|
|
|
|
+ messages.append({"role": "system", "content": self.system_prompt})
|
|
|
|
|
+ messages.append({"role": "user", "content": user_prompt})
|
|
|
|
|
+
|
|
|
|
|
+ return self.client.chat(messages)
|
|
|
|
|
+
|
|
|
def _format_time_range(self, first_time: str, last_time: str) -> str:
|
|
def _format_time_range(self, first_time: str, last_time: str) -> str:
|
|
|
"""格式化时间范围(简化显示,只保留时分)"""
|
|
"""格式化时间范围(简化显示,只保留时分)"""
|
|
|
def extract_time(time_str: str) -> str:
|
|
def extract_time(time_str: str) -> str:
|
|
@@ -409,116 +389,6 @@ class AIAnalyzer:
|
|
|
|
|
|
|
|
return "→".join(parts)
|
|
return "→".join(parts)
|
|
|
|
|
|
|
|
- def _call_ai_api(self, user_prompt: str) -> str:
|
|
|
|
|
- """调用 AI API"""
|
|
|
|
|
- if self.provider == "gemini":
|
|
|
|
|
- return self._call_gemini(user_prompt)
|
|
|
|
|
- return self._call_openai_compatible(user_prompt)
|
|
|
|
|
-
|
|
|
|
|
- def _get_api_url(self) -> str:
|
|
|
|
|
- """获取完整 API URL"""
|
|
|
|
|
- if self.base_url:
|
|
|
|
|
- return self.base_url
|
|
|
|
|
-
|
|
|
|
|
- # 预设完整端点
|
|
|
|
|
- urls = {
|
|
|
|
|
- "deepseek": "https://api.deepseek.com/v1/chat/completions",
|
|
|
|
|
- "openai": "https://api.openai.com/v1/chat/completions",
|
|
|
|
|
- }
|
|
|
|
|
- url = urls.get(self.provider)
|
|
|
|
|
- if not url:
|
|
|
|
|
- raise ValueError(f"{self.provider} 需要配置 base_url(完整 API 地址)")
|
|
|
|
|
- return url
|
|
|
|
|
-
|
|
|
|
|
- def _call_openai_compatible(self, user_prompt: str) -> str:
|
|
|
|
|
- """调用 OpenAI 兼容接口"""
|
|
|
|
|
- import requests
|
|
|
|
|
-
|
|
|
|
|
- url = self._get_api_url()
|
|
|
|
|
-
|
|
|
|
|
- headers = {
|
|
|
|
|
- "Authorization": f"Bearer {self.api_key}",
|
|
|
|
|
- "Content-Type": "application/json",
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- messages = []
|
|
|
|
|
- if self.system_prompt:
|
|
|
|
|
- messages.append({"role": "system", "content": self.system_prompt})
|
|
|
|
|
- messages.append({"role": "user", "content": user_prompt})
|
|
|
|
|
-
|
|
|
|
|
- payload = {
|
|
|
|
|
- "model": self.model,
|
|
|
|
|
- "messages": messages,
|
|
|
|
|
- "temperature": self.temperature,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- # 某些 API 不支持 max_tokens
|
|
|
|
|
- if self.max_tokens:
|
|
|
|
|
- payload["max_tokens"] = self.max_tokens
|
|
|
|
|
-
|
|
|
|
|
- if self.extra_params:
|
|
|
|
|
- payload.update(self.extra_params)
|
|
|
|
|
-
|
|
|
|
|
- response = requests.post(
|
|
|
|
|
- url,
|
|
|
|
|
- headers=headers,
|
|
|
|
|
- json=payload,
|
|
|
|
|
- timeout=self.timeout,
|
|
|
|
|
- )
|
|
|
|
|
- response.raise_for_status()
|
|
|
|
|
-
|
|
|
|
|
- data = response.json()
|
|
|
|
|
- return data["choices"][0]["message"]["content"]
|
|
|
|
|
-
|
|
|
|
|
- def _call_gemini(self, user_prompt: str) -> str:
|
|
|
|
|
- """调用 Google Gemini API"""
|
|
|
|
|
- import requests
|
|
|
|
|
-
|
|
|
|
|
- model = self.model or "gemini-1.5-flash"
|
|
|
|
|
- url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self.api_key}"
|
|
|
|
|
-
|
|
|
|
|
- headers = {
|
|
|
|
|
- "Content-Type": "application/json",
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- payload = {
|
|
|
|
|
- "contents": [{
|
|
|
|
|
- "role": "user",
|
|
|
|
|
- "parts": [{"text": user_prompt}]
|
|
|
|
|
- }],
|
|
|
|
|
- "generationConfig": {
|
|
|
|
|
- "temperature": self.temperature,
|
|
|
|
|
- },
|
|
|
|
|
- "safetySettings": [
|
|
|
|
|
- {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
|
|
|
|
|
- {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
|
|
|
|
|
- {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
|
|
|
|
|
- {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
|
|
|
|
|
- ]
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if self.system_prompt:
|
|
|
|
|
- payload["system_instruction"] = {
|
|
|
|
|
- "parts": [{"text": self.system_prompt}]
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if self.max_tokens:
|
|
|
|
|
- payload["generationConfig"]["maxOutputTokens"] = self.max_tokens
|
|
|
|
|
-
|
|
|
|
|
- if self.extra_params:
|
|
|
|
|
- payload["generationConfig"].update(self.extra_params)
|
|
|
|
|
-
|
|
|
|
|
- response = requests.post(
|
|
|
|
|
- url,
|
|
|
|
|
- headers=headers,
|
|
|
|
|
- json=payload,
|
|
|
|
|
- timeout=self.timeout,
|
|
|
|
|
- )
|
|
|
|
|
- response.raise_for_status()
|
|
|
|
|
-
|
|
|
|
|
- data = response.json()
|
|
|
|
|
- return data["candidates"][0]["content"]["parts"][0]["text"]
|
|
|
|
|
-
|
|
|
|
|
def _parse_response(self, response: str) -> AIAnalysisResult:
|
|
def _parse_response(self, response: str) -> AIAnalysisResult:
|
|
|
"""解析 AI 响应"""
|
|
"""解析 AI 响应"""
|
|
|
result = AIAnalysisResult(raw_response=response)
|
|
result = AIAnalysisResult(raw_response=response)
|