analyzer.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. # coding=utf-8
  2. """
  3. AI 分析器模块
  4. 调用 AI 大模型对热点新闻进行深度分析
  5. 支持 OpenAI、Google Gemini、Azure OpenAI 等兼容接口
  6. """
  7. import json
  8. import os
  9. from dataclasses import dataclass
  10. from pathlib import Path
  11. from typing import Any, Callable, Dict, List, Optional
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. summary: str = "" # 热点趋势概述
  16. keyword_analysis: str = "" # 关键词热度分析
  17. sentiment: str = "" # 情感倾向分析
  18. cross_platform: str = "" # 跨平台关联
  19. impact: str = "" # 潜在影响评估
  20. signals: str = "" # 值得关注的信号
  21. conclusion: str = "" # 总结与建议
  22. raw_response: str = "" # 原始响应
  23. success: bool = False # 是否成功
  24. error: str = "" # 错误信息
  25. # 新闻数量统计
  26. total_news: int = 0 # 总新闻数(热榜+RSS)
  27. analyzed_news: int = 0 # 实际分析的新闻数
  28. max_news_limit: int = 0 # 分析上限配置值
  29. hotlist_count: int = 0 # 热榜新闻数
  30. rss_count: int = 0 # RSS 新闻数
  31. class AIAnalyzer:
  32. """AI 分析器"""
  33. def __init__(self, config: Dict[str, Any], get_time_func: Callable):
  34. """
  35. 初始化 AI 分析器
  36. Args:
  37. config: AI 分析配置
  38. get_time_func: 获取当前时间的函数
  39. """
  40. self.config = config
  41. self.get_time_func = get_time_func
  42. # 从配置或环境变量获取 API Key
  43. self.api_key = config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
  44. self.provider = config.get("PROVIDER", "openai")
  45. self.model = config.get("MODEL", "gpt-4o-mini")
  46. self.base_url = config.get("BASE_URL", "")
  47. self.timeout = config.get("TIMEOUT", 90)
  48. self.max_news = config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  49. self.include_rss = config.get("INCLUDE_RSS", True)
  50. self.push_mode = config.get("PUSH_MODE", "both")
  51. # 加载提示词模板
  52. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  53. config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
  54. )
  55. def _load_prompt_template(self, prompt_file: str) -> tuple:
  56. """加载提示词模板"""
  57. config_dir = Path(__file__).parent.parent.parent / "config"
  58. prompt_path = config_dir / prompt_file
  59. if not prompt_path.exists():
  60. print(f"[AI] 提示词文件不存在: {prompt_path}")
  61. return "", ""
  62. content = prompt_path.read_text(encoding="utf-8")
  63. # 解析 [system] 和 [user] 部分
  64. system_prompt = ""
  65. user_prompt = ""
  66. if "[system]" in content and "[user]" in content:
  67. parts = content.split("[user]")
  68. system_part = parts[0]
  69. user_part = parts[1] if len(parts) > 1 else ""
  70. # 提取 system 内容
  71. if "[system]" in system_part:
  72. system_prompt = system_part.split("[system]")[1].strip()
  73. user_prompt = user_part.strip()
  74. else:
  75. # 整个文件作为 user prompt
  76. user_prompt = content
  77. return system_prompt, user_prompt
  78. def analyze(
  79. self,
  80. stats: List[Dict],
  81. rss_stats: Optional[List[Dict]] = None,
  82. report_mode: str = "daily",
  83. report_type: str = "当日汇总",
  84. platforms: Optional[List[str]] = None,
  85. keywords: Optional[List[str]] = None,
  86. ) -> AIAnalysisResult:
  87. """
  88. 执行 AI 分析
  89. Args:
  90. stats: 热榜统计数据
  91. rss_stats: RSS 统计数据
  92. report_mode: 报告模式
  93. report_type: 报告类型
  94. platforms: 平台列表
  95. keywords: 关键词列表
  96. Returns:
  97. AIAnalysisResult: 分析结果
  98. """
  99. if not self.api_key:
  100. return AIAnalysisResult(
  101. success=False,
  102. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  103. )
  104. # 准备新闻内容并获取统计数据
  105. news_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  106. total_news = hotlist_total + rss_total
  107. if not news_content:
  108. return AIAnalysisResult(
  109. success=False,
  110. error="没有可分析的新闻内容",
  111. total_news=total_news,
  112. hotlist_count=hotlist_total,
  113. rss_count=rss_total,
  114. analyzed_news=0,
  115. max_news_limit=self.max_news
  116. )
  117. # 构建提示词
  118. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  119. # 提取关键词
  120. if not keywords:
  121. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  122. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  123. user_prompt = self.user_prompt_template
  124. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  125. user_prompt = user_prompt.replace("{report_type}", report_type)
  126. user_prompt = user_prompt.replace("{current_time}", current_time)
  127. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  128. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  129. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  130. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  131. user_prompt = user_prompt.replace("{news_content}", news_content)
  132. # 调用 AI API
  133. try:
  134. response = self._call_ai_api(user_prompt)
  135. result = self._parse_response(response)
  136. # 填充统计数据
  137. result.total_news = total_news
  138. result.hotlist_count = hotlist_total
  139. result.rss_count = rss_total
  140. result.analyzed_news = analyzed_count
  141. result.max_news_limit = self.max_news
  142. return result
  143. except Exception as e:
  144. import requests
  145. error_type = type(e).__name__
  146. error_msg = str(e)
  147. # 针对不同错误类型提供更友好的提示
  148. if isinstance(e, requests.exceptions.Timeout):
  149. friendly_msg = f"AI API 请求超时({self.timeout}秒),请检查网络或增加超时时间"
  150. elif isinstance(e, requests.exceptions.ConnectionError):
  151. friendly_msg = f"无法连接到 AI API ({self.base_url or self.provider}),请检查网络和 API 地址"
  152. elif isinstance(e, requests.exceptions.HTTPError):
  153. status_code = e.response.status_code if hasattr(e, 'response') and e.response else "未知"
  154. if status_code == 401:
  155. friendly_msg = "AI API 认证失败,请检查 API Key 是否正确"
  156. elif status_code == 429:
  157. friendly_msg = "AI API 请求频率过高,请稍后重试"
  158. elif status_code == 500:
  159. friendly_msg = "AI API 服务器内部错误,请稍后重试"
  160. else:
  161. friendly_msg = f"AI API 返回错误 (HTTP {status_code}): {error_msg[:100]}"
  162. else:
  163. # 截断过长的错误消息
  164. if len(error_msg) > 150:
  165. error_msg = error_msg[:150] + "..."
  166. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  167. return AIAnalysisResult(
  168. success=False,
  169. error=friendly_msg
  170. )
  171. def _prepare_news_content(
  172. self,
  173. stats: List[Dict],
  174. rss_stats: Optional[List[Dict]] = None,
  175. ) -> tuple:
  176. """
  177. 准备新闻内容文本(增强版)
  178. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  179. RSS 包含:来源、标题、发布时间
  180. Returns:
  181. tuple: (content_str, hotlist_total, rss_total, analyzed_count)
  182. """
  183. lines = []
  184. count = 0
  185. # 计算总新闻数
  186. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  187. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  188. # 热榜内容
  189. if stats:
  190. lines.append("### 热榜新闻")
  191. lines.append("格式: [来源] 标题 | 排名:最高-最低 | 时间:首次~末次 | 出现:N次")
  192. for stat in stats:
  193. word = stat.get("word", "")
  194. titles = stat.get("titles", [])
  195. if word and titles:
  196. lines.append(f"\n**{word}** ({len(titles)}条)")
  197. for t in titles:
  198. if not isinstance(t, dict):
  199. continue
  200. title = t.get("title", "")
  201. if not title:
  202. continue
  203. # 来源
  204. source = t.get("source_name", t.get("source", ""))
  205. # 排名范围
  206. ranks = t.get("ranks", [])
  207. if ranks:
  208. min_rank = min(ranks)
  209. max_rank = max(ranks)
  210. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  211. else:
  212. rank_str = "-"
  213. # 时间范围(简化显示)
  214. first_time = t.get("first_time", "")
  215. last_time = t.get("last_time", "")
  216. time_str = self._format_time_range(first_time, last_time)
  217. # 出现次数
  218. appear_count = t.get("count", 1)
  219. # 构建行:[来源] 标题 | 排名:X-Y | 时间:首次~末次 | 出现:N次
  220. if source:
  221. line = f"- [{source}] {title}"
  222. else:
  223. line = f"- {title}"
  224. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  225. lines.append(line)
  226. count += 1
  227. if count >= self.max_news:
  228. break
  229. if count >= self.max_news:
  230. break
  231. # RSS 内容(仅在启用时提交)
  232. if self.include_rss and rss_stats and count < self.max_news:
  233. lines.append("\n### RSS 订阅")
  234. lines.append("格式: [来源] 标题 | 发布时间")
  235. for stat in rss_stats:
  236. word = stat.get("word", "")
  237. titles = stat.get("titles", [])
  238. if word and titles:
  239. lines.append(f"\n**{word}** ({len(titles)}条)")
  240. for t in titles:
  241. if not isinstance(t, dict):
  242. continue
  243. title = t.get("title", "")
  244. if not title:
  245. continue
  246. # 来源
  247. source = t.get("source_name", t.get("feed_name", ""))
  248. # 发布时间
  249. time_display = t.get("time_display", "")
  250. # 构建行:[来源] 标题 | 发布时间
  251. if source:
  252. line = f"- [{source}] {title}"
  253. else:
  254. line = f"- {title}"
  255. if time_display:
  256. line += f" | {time_display}"
  257. lines.append(line)
  258. count += 1
  259. if count >= self.max_news:
  260. break
  261. if count >= self.max_news:
  262. break
  263. return "\n".join(lines), hotlist_total, rss_total, count
  264. def _format_time_range(self, first_time: str, last_time: str) -> str:
  265. """格式化时间范围(简化显示,只保留时分)"""
  266. def extract_time(time_str: str) -> str:
  267. if not time_str:
  268. return "-"
  269. # 尝试提取 HH:MM 部分
  270. # 格式可能是 "2026-01-04 12:30:00" 或 "12:30" 等
  271. if " " in time_str:
  272. parts = time_str.split(" ")
  273. if len(parts) >= 2:
  274. time_part = parts[1]
  275. if ":" in time_part:
  276. return time_part[:5] # HH:MM
  277. elif ":" in time_str:
  278. return time_str[:5]
  279. return time_str[:5] if len(time_str) >= 5 else time_str
  280. first = extract_time(first_time)
  281. last = extract_time(last_time)
  282. if first == last or last == "-":
  283. return first
  284. return f"{first}~{last}"
  285. def _call_ai_api(self, user_prompt: str) -> str:
  286. """调用 AI API"""
  287. if self.provider == "gemini":
  288. return self._call_gemini(user_prompt)
  289. return self._call_openai_compatible(user_prompt)
  290. def _get_api_url(self) -> str:
  291. """获取完整 API URL"""
  292. if self.base_url:
  293. return self.base_url
  294. # 预设完整端点
  295. urls = {
  296. "deepseek": "https://api.deepseek.com/v1/chat/completions",
  297. "openai": "https://api.openai.com/v1/chat/completions",
  298. }
  299. url = urls.get(self.provider)
  300. if not url:
  301. raise ValueError(f"{self.provider} 需要配置 base_url(完整 API 地址)")
  302. return url
  303. def _call_openai_compatible(self, user_prompt: str) -> str:
  304. """调用 OpenAI 兼容接口"""
  305. import requests
  306. url = self._get_api_url()
  307. headers = {
  308. "Authorization": f"Bearer {self.api_key}",
  309. "Content-Type": "application/json",
  310. }
  311. messages = []
  312. if self.system_prompt:
  313. messages.append({"role": "system", "content": self.system_prompt})
  314. messages.append({"role": "user", "content": user_prompt})
  315. payload = {
  316. "model": self.model,
  317. "messages": messages,
  318. "temperature": 0.7,
  319. "max_tokens": 2000,
  320. }
  321. response = requests.post(
  322. url,
  323. headers=headers,
  324. json=payload,
  325. timeout=self.timeout,
  326. )
  327. response.raise_for_status()
  328. data = response.json()
  329. return data["choices"][0]["message"]["content"]
  330. def _call_gemini(self, user_prompt: str) -> str:
  331. """调用 Google Gemini API"""
  332. import requests
  333. # Gemini API URL 格式: https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent
  334. model = self.model or "gemini-1.5-flash"
  335. url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self.api_key}"
  336. headers = {
  337. "Content-Type": "application/json",
  338. }
  339. # 构建 Gemini 格式的消息
  340. contents = []
  341. if self.system_prompt:
  342. contents.append({
  343. "role": "user",
  344. "parts": [{"text": f"System instruction: {self.system_prompt}"}]
  345. })
  346. contents.append({
  347. "role": "model",
  348. "parts": [{"text": "Understood. I will follow these instructions."}]
  349. })
  350. contents.append({
  351. "role": "user",
  352. "parts": [{"text": user_prompt}]
  353. })
  354. payload = {
  355. "contents": contents,
  356. "generationConfig": {
  357. "temperature": 0.7,
  358. "maxOutputTokens": 2000,
  359. }
  360. }
  361. response = requests.post(
  362. url,
  363. headers=headers,
  364. json=payload,
  365. timeout=self.timeout,
  366. )
  367. response.raise_for_status()
  368. data = response.json()
  369. return data["candidates"][0]["content"]["parts"][0]["text"]
  370. def _parse_response(self, response: str) -> AIAnalysisResult:
  371. """解析 AI 响应"""
  372. result = AIAnalysisResult(raw_response=response)
  373. if not response or not response.strip():
  374. result.error = "AI 返回空响应"
  375. return result
  376. # 尝试解析 JSON
  377. try:
  378. # 提取 JSON 部分
  379. json_str = response
  380. # 尝试提取 ```json ... ``` 代码块
  381. if "```json" in response:
  382. parts = response.split("```json", 1)
  383. if len(parts) > 1:
  384. code_block = parts[1]
  385. # 查找结束的 ```
  386. end_idx = code_block.find("```")
  387. if end_idx != -1:
  388. json_str = code_block[:end_idx]
  389. else:
  390. json_str = code_block # 没有结束标记,使用剩余内容
  391. # 尝试提取 ``` ... ``` 代码块
  392. elif "```" in response:
  393. parts = response.split("```", 2) # 最多分割2次
  394. if len(parts) >= 2:
  395. json_str = parts[1]
  396. # 清理 JSON 字符串
  397. json_str = json_str.strip()
  398. if not json_str:
  399. raise ValueError("提取的 JSON 内容为空")
  400. data = json.loads(json_str)
  401. result.summary = data.get("summary", "")
  402. result.keyword_analysis = data.get("keyword_analysis", "")
  403. result.sentiment = data.get("sentiment", "")
  404. result.cross_platform = data.get("cross_platform", "")
  405. result.impact = data.get("impact", "")
  406. result.signals = data.get("signals", "")
  407. result.conclusion = data.get("conclusion", "")
  408. result.success = True
  409. except json.JSONDecodeError as e:
  410. # JSON 解析失败,记录详细错误但仍使用原始文本
  411. error_context = json_str[max(0, e.pos - 30):e.pos + 30] if json_str and e.pos else ""
  412. result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
  413. if error_context:
  414. result.error += f",上下文: ...{error_context}..."
  415. # 使用原始响应作为 summary
  416. result.summary = response[:1000] if len(response) > 1000 else response
  417. result.success = True # 仍标记为成功,因为有内容可展示
  418. except (IndexError, KeyError, TypeError, ValueError) as e:
  419. # 其他解析错误
  420. result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
  421. result.summary = response[:1000] if len(response) > 1000 else response
  422. result.success = True
  423. except Exception as e:
  424. # 未知错误
  425. result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
  426. result.summary = response[:1000] if len(response) > 1000 else response
  427. result.success = True
  428. return result