# analyzer.py
# coding=utf-8
"""
AI analyzer module.

Calls a large language model to perform in-depth analysis of trending news.
Built on the unified LiteLLM interface, which supports 100+ AI providers.
"""
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from trendradar.ai.client import AIClient
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. # 新版 5 核心板块
  16. core_trends: str = "" # 核心热点与舆情态势
  17. sentiment_controversy: str = "" # 舆论风向与争议
  18. signals: str = "" # 异动与弱信号
  19. rss_insights: str = "" # RSS 深度洞察
  20. outlook_strategy: str = "" # 研判与策略建议
  21. standalone_summaries: Dict[str, str] = field(default_factory=dict) # 独立展示区概括 {源ID: 概括}
  22. # 基础元数据
  23. raw_response: str = "" # 原始响应
  24. success: bool = False # 是否成功
  25. error: str = "" # 错误信息
  26. # 新闻数量统计
  27. total_news: int = 0 # 总新闻数(热榜+RSS)
  28. analyzed_news: int = 0 # 实际分析的新闻数
  29. max_news_limit: int = 0 # 分析上限配置值
  30. hotlist_count: int = 0 # 热榜新闻数
  31. rss_count: int = 0 # RSS 新闻数
  32. ai_mode: str = "" # AI 分析使用的模式 (daily/current/incremental)
  33. class AIAnalyzer:
  34. """AI 分析器"""
  35. def __init__(
  36. self,
  37. ai_config: Dict[str, Any],
  38. analysis_config: Dict[str, Any],
  39. get_time_func: Callable,
  40. debug: bool = False,
  41. ):
  42. """
  43. 初始化 AI 分析器
  44. Args:
  45. ai_config: AI 模型配置(LiteLLM 格式)
  46. analysis_config: AI 分析功能配置(language, prompt_file 等)
  47. get_time_func: 获取当前时间的函数
  48. debug: 是否开启调试模式
  49. """
  50. self.ai_config = ai_config
  51. self.analysis_config = analysis_config
  52. self.get_time_func = get_time_func
  53. self.debug = debug
  54. # 创建 AI 客户端(基于 LiteLLM)
  55. self.client = AIClient(ai_config)
  56. # 验证配置
  57. valid, error = self.client.validate_config()
  58. if not valid:
  59. print(f"[AI] 配置警告: {error}")
  60. # 从分析配置获取功能参数
  61. self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  62. self.include_rss = analysis_config.get("INCLUDE_RSS", True)
  63. self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
  64. self.include_standalone = analysis_config.get("INCLUDE_STANDALONE", False)
  65. self.language = analysis_config.get("LANGUAGE", "Chinese")
  66. # 加载提示词模板
  67. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  68. analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
  69. )
  70. def _load_prompt_template(self, prompt_file: str) -> tuple:
  71. """加载提示词模板"""
  72. config_dir = Path(__file__).parent.parent.parent / "config"
  73. prompt_path = config_dir / prompt_file
  74. if not prompt_path.exists():
  75. print(f"[AI] 提示词文件不存在: {prompt_path}")
  76. return "", ""
  77. content = prompt_path.read_text(encoding="utf-8")
  78. # 解析 [system] 和 [user] 部分
  79. system_prompt = ""
  80. user_prompt = ""
  81. if "[system]" in content and "[user]" in content:
  82. parts = content.split("[user]")
  83. system_part = parts[0]
  84. user_part = parts[1] if len(parts) > 1 else ""
  85. # 提取 system 内容
  86. if "[system]" in system_part:
  87. system_prompt = system_part.split("[system]")[1].strip()
  88. user_prompt = user_part.strip()
  89. else:
  90. # 整个文件作为 user prompt
  91. user_prompt = content
  92. return system_prompt, user_prompt
  93. def analyze(
  94. self,
  95. stats: List[Dict],
  96. rss_stats: Optional[List[Dict]] = None,
  97. report_mode: str = "daily",
  98. report_type: str = "当日汇总",
  99. platforms: Optional[List[str]] = None,
  100. keywords: Optional[List[str]] = None,
  101. standalone_data: Optional[Dict] = None,
  102. ) -> AIAnalysisResult:
  103. """
  104. 执行 AI 分析
  105. Args:
  106. stats: 热榜统计数据
  107. rss_stats: RSS 统计数据
  108. report_mode: 报告模式
  109. report_type: 报告类型
  110. platforms: 平台列表
  111. keywords: 关键词列表
  112. Returns:
  113. AIAnalysisResult: 分析结果
  114. """
  115. # 打印配置信息方便调试
  116. model = self.ai_config.get("MODEL", "unknown")
  117. api_key = self.client.api_key or ""
  118. api_base = self.ai_config.get("API_BASE", "")
  119. masked_key = f"{api_key[:5]}******" if len(api_key) >= 5 else "******"
  120. model_display = model.replace("/", "/\u200b") if model else "unknown"
  121. print(f"[AI] 模型: {model_display}")
  122. print(f"[AI] Key : {masked_key}")
  123. if api_base:
  124. print(f"[AI] 接口: 存在自定义 API 端点")
  125. timeout = self.ai_config.get("TIMEOUT", 120)
  126. max_tokens = self.ai_config.get("MAX_TOKENS", 5000)
  127. print(f"[AI] 参数: timeout={timeout}, max_tokens={max_tokens}")
  128. if not self.client.api_key:
  129. return AIAnalysisResult(
  130. success=False,
  131. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  132. )
  133. # 准备新闻内容并获取统计数据
  134. news_content, rss_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  135. total_news = hotlist_total + rss_total
  136. if not news_content and not rss_content:
  137. return AIAnalysisResult(
  138. success=False,
  139. error="没有可分析的新闻内容",
  140. total_news=total_news,
  141. hotlist_count=hotlist_total,
  142. rss_count=rss_total,
  143. analyzed_news=0,
  144. max_news_limit=self.max_news
  145. )
  146. # 构建提示词
  147. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  148. # 提取关键词
  149. if not keywords:
  150. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  151. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  152. user_prompt = self.user_prompt_template
  153. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  154. user_prompt = user_prompt.replace("{report_type}", report_type)
  155. user_prompt = user_prompt.replace("{current_time}", current_time)
  156. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  157. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  158. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  159. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  160. user_prompt = user_prompt.replace("{news_content}", news_content)
  161. user_prompt = user_prompt.replace("{rss_content}", rss_content)
  162. user_prompt = user_prompt.replace("{language}", self.language)
  163. # 构建独立展示区内容
  164. standalone_content = ""
  165. if self.include_standalone and standalone_data:
  166. standalone_content = self._prepare_standalone_content(standalone_data)
  167. user_prompt = user_prompt.replace("{standalone_content}", standalone_content)
  168. if self.debug:
  169. print("\n" + "=" * 80)
  170. print("[AI 调试] 发送给 AI 的完整提示词")
  171. print("=" * 80)
  172. if self.system_prompt:
  173. print("\n--- System Prompt ---")
  174. print(self.system_prompt)
  175. print("\n--- User Prompt ---")
  176. print(user_prompt)
  177. print("=" * 80 + "\n")
  178. # 调用 AI API(使用 LiteLLM)
  179. try:
  180. response = self._call_ai(user_prompt)
  181. result = self._parse_response(response)
  182. # JSON 解析失败时的重试兜底(仅重试一次)
  183. if result.error and "JSON 解析错误" in result.error:
  184. print(f"[AI] JSON 解析失败,尝试让 AI 修复...")
  185. retry_result = self._retry_fix_json(response, result.error)
  186. if retry_result and retry_result.success and not retry_result.error:
  187. print("[AI] JSON 修复成功")
  188. retry_result.raw_response = response
  189. result = retry_result
  190. else:
  191. print("[AI] JSON 修复失败,使用原始文本兜底")
  192. # 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
  193. if not self.include_rss:
  194. result.rss_insights = ""
  195. # 如果配置未启用 standalone 分析,强制清空
  196. if not self.include_standalone:
  197. result.standalone_summaries = {}
  198. # 填充统计数据
  199. result.total_news = total_news
  200. result.hotlist_count = hotlist_total
  201. result.rss_count = rss_total
  202. result.analyzed_news = analyzed_count
  203. result.max_news_limit = self.max_news
  204. return result
  205. except Exception as e:
  206. error_type = type(e).__name__
  207. error_msg = str(e)
  208. # 截断过长的错误消息
  209. if len(error_msg) > 200:
  210. error_msg = error_msg[:200] + "..."
  211. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  212. return AIAnalysisResult(
  213. success=False,
  214. error=friendly_msg
  215. )
  216. def _prepare_news_content(
  217. self,
  218. stats: List[Dict],
  219. rss_stats: Optional[List[Dict]] = None,
  220. ) -> tuple:
  221. """
  222. 准备新闻内容文本(增强版)
  223. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  224. RSS 包含:来源、标题、发布时间
  225. Returns:
  226. tuple: (news_content, rss_content, hotlist_total, rss_total, analyzed_count)
  227. """
  228. news_lines = []
  229. rss_lines = []
  230. news_count = 0
  231. rss_count = 0
  232. # 计算总新闻数
  233. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  234. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  235. # 热榜内容
  236. if stats:
  237. for stat in stats:
  238. word = stat.get("word", "")
  239. titles = stat.get("titles", [])
  240. if word and titles:
  241. news_lines.append(f"\n**{word}** ({len(titles)}条)")
  242. for t in titles:
  243. if not isinstance(t, dict):
  244. continue
  245. title = t.get("title", "")
  246. if not title:
  247. continue
  248. # 来源
  249. source = t.get("source_name", t.get("source", ""))
  250. # 构建行
  251. if source:
  252. line = f"- [{source}] {title}"
  253. else:
  254. line = f"- {title}"
  255. # 始终显示简化格式:排名范围 + 时间范围 + 出现次数
  256. ranks = t.get("ranks", [])
  257. if ranks:
  258. min_rank = min(ranks)
  259. max_rank = max(ranks)
  260. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  261. else:
  262. rank_str = "-"
  263. first_time = t.get("first_time", "")
  264. last_time = t.get("last_time", "")
  265. time_str = self._format_time_range(first_time, last_time)
  266. appear_count = t.get("count", 1)
  267. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  268. # 开启完整时间线时,额外添加轨迹
  269. if self.include_rank_timeline:
  270. rank_timeline = t.get("rank_timeline", [])
  271. timeline_str = self._format_rank_timeline(rank_timeline)
  272. line += f" | 轨迹:{timeline_str}"
  273. news_lines.append(line)
  274. news_count += 1
  275. if news_count >= self.max_news:
  276. break
  277. if news_count >= self.max_news:
  278. break
  279. # RSS 内容(仅在启用时构建)
  280. if self.include_rss and rss_stats:
  281. remaining = self.max_news - news_count
  282. for stat in rss_stats:
  283. if rss_count >= remaining:
  284. break
  285. word = stat.get("word", "")
  286. titles = stat.get("titles", [])
  287. if word and titles:
  288. rss_lines.append(f"\n**{word}** ({len(titles)}条)")
  289. for t in titles:
  290. if not isinstance(t, dict):
  291. continue
  292. title = t.get("title", "")
  293. if not title:
  294. continue
  295. # 来源
  296. source = t.get("source_name", t.get("feed_name", ""))
  297. # 发布时间
  298. time_display = t.get("time_display", "")
  299. # 构建行:[来源] 标题 | 发布时间
  300. if source:
  301. line = f"- [{source}] {title}"
  302. else:
  303. line = f"- {title}"
  304. if time_display:
  305. line += f" | {time_display}"
  306. rss_lines.append(line)
  307. rss_count += 1
  308. if rss_count >= remaining:
  309. break
  310. news_content = "\n".join(news_lines) if news_lines else ""
  311. rss_content = "\n".join(rss_lines) if rss_lines else ""
  312. total_count = news_count + rss_count
  313. return news_content, rss_content, hotlist_total, rss_total, total_count
  314. def _call_ai(self, user_prompt: str) -> str:
  315. """调用 AI API(使用 LiteLLM)"""
  316. messages = []
  317. if self.system_prompt:
  318. messages.append({"role": "system", "content": self.system_prompt})
  319. messages.append({"role": "user", "content": user_prompt})
  320. return self.client.chat(messages)
  321. def _retry_fix_json(self, original_response: str, error_msg: str) -> Optional[AIAnalysisResult]:
  322. """
  323. JSON 解析失败时,请求 AI 修复 JSON(仅重试一次)
  324. 使用轻量 prompt,不重复原始分析的 system prompt,节省 token。
  325. Args:
  326. original_response: AI 原始响应(JSON 格式有误)
  327. error_msg: JSON 解析的错误信息
  328. Returns:
  329. 修复后的分析结果,失败时返回 None
  330. """
  331. messages = [
  332. {
  333. "role": "system",
  334. "content": (
  335. "你是一个 JSON 修复助手。用户会提供一段格式有误的 JSON 和错误信息,"
  336. "你需要修复 JSON 格式错误并返回正确的 JSON。\n"
  337. "常见问题:字符串值内的双引号未转义、缺少逗号、字符串未正确闭合等。\n"
  338. "只返回纯 JSON,不要包含 markdown 代码块标记(如 ```json)或任何说明文字。"
  339. ),
  340. },
  341. {
  342. "role": "user",
  343. "content": (
  344. f"以下 JSON 解析失败:\n\n"
  345. f"错误:{error_msg}\n\n"
  346. f"原始内容:\n{original_response}\n\n"
  347. f"请修复以上 JSON 中的格式问题(如值中的双引号改用中文引号「」或转义 \\\"、"
  348. f"缺少逗号、不完整的字符串等),保持原始内容语义不变,只修复格式。"
  349. f"直接返回修复后的纯 JSON。"
  350. ),
  351. },
  352. ]
  353. try:
  354. response = self.client.chat(messages)
  355. return self._parse_response(response)
  356. except Exception as e:
  357. print(f"[AI] 重试修复 JSON 异常: {type(e).__name__}: {e}")
  358. return None
  359. def _format_time_range(self, first_time: str, last_time: str) -> str:
  360. """格式化时间范围(简化显示,只保留时分)"""
  361. def extract_time(time_str: str) -> str:
  362. if not time_str:
  363. return "-"
  364. # 尝试提取 HH:MM 部分
  365. if " " in time_str:
  366. parts = time_str.split(" ")
  367. if len(parts) >= 2:
  368. time_part = parts[1]
  369. if ":" in time_part:
  370. return time_part[:5] # HH:MM
  371. elif ":" in time_str:
  372. return time_str[:5]
  373. # 处理 HH-MM 格式
  374. result = time_str[:5] if len(time_str) >= 5 else time_str
  375. if len(result) == 5 and result[2] == '-':
  376. result = result.replace('-', ':')
  377. return result
  378. first = extract_time(first_time)
  379. last = extract_time(last_time)
  380. if first == last or last == "-":
  381. return first
  382. return f"{first}~{last}"
  383. def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
  384. """格式化排名时间线"""
  385. if not rank_timeline:
  386. return "-"
  387. parts = []
  388. for item in rank_timeline:
  389. time_str = item.get("time", "")
  390. if len(time_str) == 5 and time_str[2] == '-':
  391. time_str = time_str.replace('-', ':')
  392. rank = item.get("rank")
  393. if rank is None:
  394. parts.append(f"0({time_str})")
  395. else:
  396. parts.append(f"{rank}({time_str})")
  397. return "→".join(parts)
  398. def _prepare_standalone_content(self, standalone_data: Dict) -> str:
  399. """
  400. 将独立展示区数据转为文本,注入 AI 分析 prompt
  401. Args:
  402. standalone_data: 独立展示区数据 {"platforms": [...], "rss_feeds": [...]}
  403. Returns:
  404. 格式化的文本内容
  405. """
  406. lines = []
  407. # 热榜平台
  408. for platform in standalone_data.get("platforms", []):
  409. platform_id = platform.get("id", "")
  410. platform_name = platform.get("name", platform_id)
  411. items = platform.get("items", [])
  412. if not items:
  413. continue
  414. lines.append(f"### [{platform_name}]")
  415. for item in items:
  416. title = item.get("title", "")
  417. if not title:
  418. continue
  419. line = f"- {title}"
  420. # 排名信息
  421. ranks = item.get("ranks", [])
  422. if ranks:
  423. min_rank = min(ranks)
  424. max_rank = max(ranks)
  425. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  426. line += f" | 排名:{rank_str}"
  427. # 时间范围
  428. first_time = item.get("first_time", "")
  429. last_time = item.get("last_time", "")
  430. if first_time:
  431. time_str = self._format_time_range(first_time, last_time)
  432. line += f" | 时间:{time_str}"
  433. # 出现次数
  434. count = item.get("count", 1)
  435. if count > 1:
  436. line += f" | 出现:{count}次"
  437. # 排名轨迹(如果启用)
  438. if self.include_rank_timeline:
  439. rank_timeline = item.get("rank_timeline", [])
  440. if rank_timeline:
  441. timeline_str = self._format_rank_timeline(rank_timeline)
  442. line += f" | 轨迹:{timeline_str}"
  443. lines.append(line)
  444. lines.append("")
  445. # RSS 源
  446. for feed in standalone_data.get("rss_feeds", []):
  447. feed_id = feed.get("id", "")
  448. feed_name = feed.get("name", feed_id)
  449. items = feed.get("items", [])
  450. if not items:
  451. continue
  452. lines.append(f"### [{feed_name}]")
  453. for item in items:
  454. title = item.get("title", "")
  455. if not title:
  456. continue
  457. line = f"- {title}"
  458. published_at = item.get("published_at", "")
  459. if published_at:
  460. line += f" | {published_at}"
  461. lines.append(line)
  462. lines.append("")
  463. return "\n".join(lines)
  464. def _parse_response(self, response: str) -> AIAnalysisResult:
  465. """解析 AI 响应"""
  466. result = AIAnalysisResult(raw_response=response)
  467. if not response or not response.strip():
  468. result.error = "AI 返回空响应"
  469. return result
  470. # 提取 JSON 文本(去掉 markdown 代码块标记)
  471. json_str = response
  472. if "```json" in response:
  473. parts = response.split("```json", 1)
  474. if len(parts) > 1:
  475. code_block = parts[1]
  476. end_idx = code_block.find("```")
  477. if end_idx != -1:
  478. json_str = code_block[:end_idx]
  479. else:
  480. json_str = code_block
  481. elif "```" in response:
  482. parts = response.split("```", 2)
  483. if len(parts) >= 2:
  484. json_str = parts[1]
  485. json_str = json_str.strip()
  486. if not json_str:
  487. result.error = "提取的 JSON 内容为空"
  488. result.core_trends = response[:500] + "..." if len(response) > 500 else response
  489. result.success = True
  490. return result
  491. # 第一步:标准 JSON 解析
  492. data = None
  493. parse_error = None
  494. try:
  495. data = json.loads(json_str)
  496. except json.JSONDecodeError as e:
  497. parse_error = e
  498. # 第二步:json_repair 本地修复
  499. if data is None:
  500. try:
  501. from json_repair import repair_json
  502. repaired = repair_json(json_str, return_objects=True)
  503. if isinstance(repaired, dict):
  504. data = repaired
  505. print("[AI] JSON 本地修复成功(json_repair)")
  506. except Exception:
  507. pass
  508. # 两步都失败,记录错误(后续由 analyze 方法的重试机制处理)
  509. if data is None:
  510. if parse_error:
  511. error_context = json_str[max(0, parse_error.pos - 30):parse_error.pos + 30] if json_str and parse_error.pos else ""
  512. result.error = f"JSON 解析错误 (位置 {parse_error.pos}): {parse_error.msg}"
  513. if error_context:
  514. result.error += f",上下文: ...{error_context}..."
  515. else:
  516. result.error = "JSON 解析失败"
  517. # 兜底:使用已提取的 json_str(不含 markdown 标记),避免推送中出现 ```json
  518. result.core_trends = json_str[:500] + "..." if len(json_str) > 500 else json_str
  519. result.success = True
  520. return result
  521. # 解析成功,提取字段
  522. try:
  523. result.core_trends = data.get("core_trends", "")
  524. result.sentiment_controversy = data.get("sentiment_controversy", "")
  525. result.signals = data.get("signals", "")
  526. result.rss_insights = data.get("rss_insights", "")
  527. result.outlook_strategy = data.get("outlook_strategy", "")
  528. # 解析独立展示区概括
  529. summaries = data.get("standalone_summaries", {})
  530. if isinstance(summaries, dict):
  531. result.standalone_summaries = {
  532. str(k): str(v) for k, v in summaries.items()
  533. }
  534. result.success = True
  535. except (KeyError, TypeError, AttributeError) as e:
  536. result.error = f"字段提取错误: {type(e).__name__}: {e}"
  537. result.core_trends = json_str[:500] + "..." if len(json_str) > 500 else json_str
  538. result.success = True
  539. return result