analyzer.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617
  1. # coding=utf-8
  2. """
  3. AI 分析器模块
  4. 调用 AI 大模型对热点新闻进行深度分析
  5. 基于 LiteLLM 统一接口,支持 100+ AI 提供商
  6. """
  7. import json
  8. from dataclasses import dataclass, field
  9. from typing import Any, Callable, Dict, List, Optional
  10. from trendradar.ai.client import AIClient
  11. from trendradar.ai.prompt_loader import load_prompt_template
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. # 新版 5 核心板块
  16. core_trends: str = "" # 核心热点与舆情态势
  17. sentiment_controversy: str = "" # 舆论风向与争议
  18. signals: str = "" # 异动与弱信号
  19. rss_insights: str = "" # RSS 深度洞察
  20. outlook_strategy: str = "" # 研判与策略建议
  21. standalone_summaries: Dict[str, str] = field(default_factory=dict) # 独立展示区概括 {源ID: 概括}
  22. # 基础元数据
  23. raw_response: str = "" # 原始响应
  24. success: bool = False # 是否成功
  25. error: str = "" # 错误信息
  26. # 新闻数量统计
  27. total_news: int = 0 # 总新闻数(热榜+RSS)
  28. analyzed_news: int = 0 # 实际分析的新闻数
  29. max_news_limit: int = 0 # 分析上限配置值
  30. hotlist_count: int = 0 # 热榜新闻数
  31. rss_count: int = 0 # RSS 新闻数
  32. ai_mode: str = "" # AI 分析使用的模式 (daily/current/incremental)
  33. class AIAnalyzer:
  34. """AI 分析器"""
  35. def __init__(
  36. self,
  37. ai_config: Dict[str, Any],
  38. analysis_config: Dict[str, Any],
  39. get_time_func: Callable,
  40. debug: bool = False,
  41. ):
  42. """
  43. 初始化 AI 分析器
  44. Args:
  45. ai_config: AI 模型配置(LiteLLM 格式)
  46. analysis_config: AI 分析功能配置(language, prompt_file 等)
  47. get_time_func: 获取当前时间的函数
  48. debug: 是否开启调试模式
  49. """
  50. self.ai_config = ai_config
  51. self.analysis_config = analysis_config
  52. self.get_time_func = get_time_func
  53. self.debug = debug
  54. # 创建 AI 客户端(基于 LiteLLM)
  55. self.client = AIClient(ai_config)
  56. # 验证配置
  57. valid, error = self.client.validate_config()
  58. if not valid:
  59. print(f"[AI] 配置警告: {error}")
  60. # 从分析配置获取功能参数
  61. self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  62. self.include_rss = analysis_config.get("INCLUDE_RSS", True)
  63. self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
  64. self.include_standalone = analysis_config.get("INCLUDE_STANDALONE", False)
  65. self.language = analysis_config.get("LANGUAGE", "Chinese")
  66. # 加载提示词模板
  67. self.system_prompt, self.user_prompt_template = load_prompt_template(
  68. analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt"),
  69. label="AI",
  70. )
  71. def analyze(
  72. self,
  73. stats: List[Dict],
  74. rss_stats: Optional[List[Dict]] = None,
  75. report_mode: str = "daily",
  76. report_type: str = "当日汇总",
  77. platforms: Optional[List[str]] = None,
  78. keywords: Optional[List[str]] = None,
  79. standalone_data: Optional[Dict] = None,
  80. ) -> AIAnalysisResult:
  81. """
  82. 执行 AI 分析
  83. Args:
  84. stats: 热榜统计数据
  85. rss_stats: RSS 统计数据
  86. report_mode: 报告模式
  87. report_type: 报告类型
  88. platforms: 平台列表
  89. keywords: 关键词列表
  90. Returns:
  91. AIAnalysisResult: 分析结果
  92. """
  93. # 打印配置信息方便调试
  94. model = self.ai_config.get("MODEL", "unknown")
  95. api_key = self.client.api_key or ""
  96. api_base = self.ai_config.get("API_BASE", "")
  97. masked_key = f"{api_key[:5]}******" if len(api_key) >= 5 else "******"
  98. model_display = model.replace("/", "/\u200b") if model else "unknown"
  99. print(f"[AI] 模型: {model_display}")
  100. print(f"[AI] Key : {masked_key}")
  101. if api_base:
  102. print(f"[AI] 接口: 存在自定义 API 端点")
  103. timeout = self.ai_config.get("TIMEOUT", 120)
  104. max_tokens = self.ai_config.get("MAX_TOKENS", 5000)
  105. print(f"[AI] 参数: timeout={timeout}, max_tokens={max_tokens}")
  106. if not self.client.api_key:
  107. return AIAnalysisResult(
  108. success=False,
  109. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  110. )
  111. # 准备新闻内容并获取统计数据
  112. news_content, rss_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  113. total_news = hotlist_total + rss_total
  114. if not news_content and not rss_content:
  115. return AIAnalysisResult(
  116. success=False,
  117. error="没有可分析的新闻内容",
  118. total_news=total_news,
  119. hotlist_count=hotlist_total,
  120. rss_count=rss_total,
  121. analyzed_news=0,
  122. max_news_limit=self.max_news
  123. )
  124. # 构建提示词
  125. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  126. # 提取关键词
  127. if not keywords:
  128. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  129. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  130. user_prompt = self.user_prompt_template
  131. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  132. user_prompt = user_prompt.replace("{report_type}", report_type)
  133. user_prompt = user_prompt.replace("{current_time}", current_time)
  134. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  135. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  136. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  137. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  138. user_prompt = user_prompt.replace("{news_content}", news_content)
  139. user_prompt = user_prompt.replace("{rss_content}", rss_content)
  140. user_prompt = user_prompt.replace("{language}", self.language)
  141. # 构建独立展示区内容
  142. standalone_content = ""
  143. if self.include_standalone and standalone_data:
  144. standalone_content = self._prepare_standalone_content(standalone_data)
  145. user_prompt = user_prompt.replace("{standalone_content}", standalone_content)
  146. if self.debug:
  147. print("\n" + "=" * 80)
  148. print("[AI 调试] 发送给 AI 的完整提示词")
  149. print("=" * 80)
  150. if self.system_prompt:
  151. print("\n--- System Prompt ---")
  152. print(self.system_prompt)
  153. print("\n--- User Prompt ---")
  154. print(user_prompt)
  155. print("=" * 80 + "\n")
  156. # 调用 AI API(使用 LiteLLM)
  157. try:
  158. response = self._call_ai(user_prompt)
  159. result = self._parse_response(response)
  160. # JSON 解析失败时的重试兜底(仅重试一次)
  161. if result.error and "JSON 解析错误" in result.error:
  162. print(f"[AI] JSON 解析失败,尝试让 AI 修复...")
  163. retry_result = self._retry_fix_json(response, result.error)
  164. if retry_result and retry_result.success and not retry_result.error:
  165. print("[AI] JSON 修复成功")
  166. retry_result.raw_response = response
  167. result = retry_result
  168. else:
  169. print("[AI] JSON 修复失败,使用原始文本兜底")
  170. # 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
  171. if not self.include_rss:
  172. result.rss_insights = ""
  173. # 如果配置未启用 standalone 分析,强制清空
  174. if not self.include_standalone:
  175. result.standalone_summaries = {}
  176. # 填充统计数据
  177. result.total_news = total_news
  178. result.hotlist_count = hotlist_total
  179. result.rss_count = rss_total
  180. result.analyzed_news = analyzed_count
  181. result.max_news_limit = self.max_news
  182. return result
  183. except Exception as e:
  184. error_type = type(e).__name__
  185. error_msg = str(e)
  186. # 截断过长的错误消息
  187. if len(error_msg) > 200:
  188. error_msg = error_msg[:200] + "..."
  189. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  190. return AIAnalysisResult(
  191. success=False,
  192. error=friendly_msg
  193. )
  194. def _prepare_news_content(
  195. self,
  196. stats: List[Dict],
  197. rss_stats: Optional[List[Dict]] = None,
  198. ) -> tuple:
  199. """
  200. 准备新闻内容文本(增强版)
  201. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  202. RSS 包含:来源、标题、发布时间
  203. Returns:
  204. tuple: (news_content, rss_content, hotlist_total, rss_total, analyzed_count)
  205. """
  206. news_lines = []
  207. rss_lines = []
  208. news_count = 0
  209. rss_count = 0
  210. # 计算总新闻数
  211. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  212. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  213. # 热榜内容
  214. if stats:
  215. for stat in stats:
  216. word = stat.get("word", "")
  217. titles = stat.get("titles", [])
  218. if word and titles:
  219. news_lines.append(f"\n**{word}** ({len(titles)}条)")
  220. for t in titles:
  221. if not isinstance(t, dict):
  222. continue
  223. title = t.get("title", "")
  224. if not title:
  225. continue
  226. # 来源
  227. source = t.get("source_name", t.get("source", ""))
  228. # 构建行
  229. if source:
  230. line = f"- [{source}] {title}"
  231. else:
  232. line = f"- {title}"
  233. # 始终显示简化格式:排名范围 + 时间范围 + 出现次数
  234. ranks = t.get("ranks", [])
  235. if ranks:
  236. min_rank = min(ranks)
  237. max_rank = max(ranks)
  238. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  239. else:
  240. rank_str = "-"
  241. first_time = t.get("first_time", "")
  242. last_time = t.get("last_time", "")
  243. time_str = self._format_time_range(first_time, last_time)
  244. appear_count = t.get("count", 1)
  245. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  246. # 开启完整时间线时,额外添加轨迹
  247. if self.include_rank_timeline:
  248. rank_timeline = t.get("rank_timeline", [])
  249. timeline_str = self._format_rank_timeline(rank_timeline)
  250. line += f" | 轨迹:{timeline_str}"
  251. news_lines.append(line)
  252. news_count += 1
  253. if news_count >= self.max_news:
  254. break
  255. if news_count >= self.max_news:
  256. break
  257. # RSS 内容(仅在启用时构建)
  258. if self.include_rss and rss_stats:
  259. remaining = self.max_news - news_count
  260. for stat in rss_stats:
  261. if rss_count >= remaining:
  262. break
  263. word = stat.get("word", "")
  264. titles = stat.get("titles", [])
  265. if word and titles:
  266. rss_lines.append(f"\n**{word}** ({len(titles)}条)")
  267. for t in titles:
  268. if not isinstance(t, dict):
  269. continue
  270. title = t.get("title", "")
  271. if not title:
  272. continue
  273. # 来源
  274. source = t.get("source_name", t.get("feed_name", ""))
  275. # 发布时间
  276. time_display = t.get("time_display", "")
  277. # 构建行:[来源] 标题 | 发布时间
  278. if source:
  279. line = f"- [{source}] {title}"
  280. else:
  281. line = f"- {title}"
  282. if time_display:
  283. line += f" | {time_display}"
  284. rss_lines.append(line)
  285. rss_count += 1
  286. if rss_count >= remaining:
  287. break
  288. news_content = "\n".join(news_lines) if news_lines else ""
  289. rss_content = "\n".join(rss_lines) if rss_lines else ""
  290. total_count = news_count + rss_count
  291. return news_content, rss_content, hotlist_total, rss_total, total_count
  292. def _call_ai(self, user_prompt: str) -> str:
  293. """调用 AI API(使用 LiteLLM)"""
  294. messages = []
  295. if self.system_prompt:
  296. messages.append({"role": "system", "content": self.system_prompt})
  297. messages.append({"role": "user", "content": user_prompt})
  298. return self.client.chat(messages)
  299. def _retry_fix_json(self, original_response: str, error_msg: str) -> Optional[AIAnalysisResult]:
  300. """
  301. JSON 解析失败时,请求 AI 修复 JSON(仅重试一次)
  302. 使用轻量 prompt,不重复原始分析的 system prompt,节省 token。
  303. Args:
  304. original_response: AI 原始响应(JSON 格式有误)
  305. error_msg: JSON 解析的错误信息
  306. Returns:
  307. 修复后的分析结果,失败时返回 None
  308. """
  309. messages = [
  310. {
  311. "role": "system",
  312. "content": (
  313. "你是一个 JSON 修复助手。用户会提供一段格式有误的 JSON 和错误信息,"
  314. "你需要修复 JSON 格式错误并返回正确的 JSON。\n"
  315. "常见问题:字符串值内的双引号未转义、缺少逗号、字符串未正确闭合等。\n"
  316. "只返回纯 JSON,不要包含 markdown 代码块标记(如 ```json)或任何说明文字。"
  317. ),
  318. },
  319. {
  320. "role": "user",
  321. "content": (
  322. f"以下 JSON 解析失败:\n\n"
  323. f"错误:{error_msg}\n\n"
  324. f"原始内容:\n{original_response}\n\n"
  325. f"请修复以上 JSON 中的格式问题(如值中的双引号改用中文引号「」或转义 \\\"、"
  326. f"缺少逗号、不完整的字符串等),保持原始内容语义不变,只修复格式。"
  327. f"直接返回修复后的纯 JSON。"
  328. ),
  329. },
  330. ]
  331. try:
  332. response = self.client.chat(messages)
  333. return self._parse_response(response)
  334. except Exception as e:
  335. print(f"[AI] 重试修复 JSON 异常: {type(e).__name__}: {e}")
  336. return None
  337. def _format_time_range(self, first_time: str, last_time: str) -> str:
  338. """格式化时间范围(简化显示,只保留时分)"""
  339. def extract_time(time_str: str) -> str:
  340. if not time_str:
  341. return "-"
  342. # 尝试提取 HH:MM 部分
  343. if " " in time_str:
  344. parts = time_str.split(" ")
  345. if len(parts) >= 2:
  346. time_part = parts[1]
  347. if ":" in time_part:
  348. return time_part[:5] # HH:MM
  349. elif ":" in time_str:
  350. return time_str[:5]
  351. # 处理 HH-MM 格式
  352. result = time_str[:5] if len(time_str) >= 5 else time_str
  353. if len(result) == 5 and result[2] == '-':
  354. result = result.replace('-', ':')
  355. return result
  356. first = extract_time(first_time)
  357. last = extract_time(last_time)
  358. if first == last or last == "-":
  359. return first
  360. return f"{first}~{last}"
  361. def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
  362. """格式化排名时间线"""
  363. if not rank_timeline:
  364. return "-"
  365. parts = []
  366. for item in rank_timeline:
  367. time_str = item.get("time", "")
  368. if len(time_str) == 5 and time_str[2] == '-':
  369. time_str = time_str.replace('-', ':')
  370. rank = item.get("rank")
  371. if rank is None:
  372. parts.append(f"0({time_str})")
  373. else:
  374. parts.append(f"{rank}({time_str})")
  375. return "→".join(parts)
  376. def _prepare_standalone_content(self, standalone_data: Dict) -> str:
  377. """
  378. 将独立展示区数据转为文本,注入 AI 分析 prompt
  379. Args:
  380. standalone_data: 独立展示区数据 {"platforms": [...], "rss_feeds": [...]}
  381. Returns:
  382. 格式化的文本内容
  383. """
  384. lines = []
  385. # 热榜平台
  386. for platform in standalone_data.get("platforms", []):
  387. platform_id = platform.get("id", "")
  388. platform_name = platform.get("name", platform_id)
  389. items = platform.get("items", [])
  390. if not items:
  391. continue
  392. lines.append(f"### [{platform_name}]")
  393. for item in items:
  394. title = item.get("title", "")
  395. if not title:
  396. continue
  397. line = f"- {title}"
  398. # 排名信息
  399. ranks = item.get("ranks", [])
  400. if ranks:
  401. min_rank = min(ranks)
  402. max_rank = max(ranks)
  403. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  404. line += f" | 排名:{rank_str}"
  405. # 时间范围
  406. first_time = item.get("first_time", "")
  407. last_time = item.get("last_time", "")
  408. if first_time:
  409. time_str = self._format_time_range(first_time, last_time)
  410. line += f" | 时间:{time_str}"
  411. # 出现次数
  412. count = item.get("count", 1)
  413. if count > 1:
  414. line += f" | 出现:{count}次"
  415. # 排名轨迹(如果启用)
  416. if self.include_rank_timeline:
  417. rank_timeline = item.get("rank_timeline", [])
  418. if rank_timeline:
  419. timeline_str = self._format_rank_timeline(rank_timeline)
  420. line += f" | 轨迹:{timeline_str}"
  421. lines.append(line)
  422. lines.append("")
  423. # RSS 源
  424. for feed in standalone_data.get("rss_feeds", []):
  425. feed_id = feed.get("id", "")
  426. feed_name = feed.get("name", feed_id)
  427. items = feed.get("items", [])
  428. if not items:
  429. continue
  430. lines.append(f"### [{feed_name}]")
  431. for item in items:
  432. title = item.get("title", "")
  433. if not title:
  434. continue
  435. line = f"- {title}"
  436. published_at = item.get("published_at", "")
  437. if published_at:
  438. line += f" | {published_at}"
  439. lines.append(line)
  440. lines.append("")
  441. return "\n".join(lines)
  442. def _parse_response(self, response: str) -> AIAnalysisResult:
  443. """解析 AI 响应"""
  444. result = AIAnalysisResult(raw_response=response)
  445. if not response or not response.strip():
  446. result.error = "AI 返回空响应"
  447. return result
  448. # 提取 JSON 文本(去掉 markdown 代码块标记)
  449. json_str = response
  450. if "```json" in response:
  451. parts = response.split("```json", 1)
  452. if len(parts) > 1:
  453. code_block = parts[1]
  454. end_idx = code_block.find("```")
  455. if end_idx != -1:
  456. json_str = code_block[:end_idx]
  457. else:
  458. json_str = code_block
  459. elif "```" in response:
  460. parts = response.split("```", 2)
  461. if len(parts) >= 2:
  462. json_str = parts[1]
  463. json_str = json_str.strip()
  464. if not json_str:
  465. result.error = "提取的 JSON 内容为空"
  466. result.core_trends = response[:500] + "..." if len(response) > 500 else response
  467. result.success = True
  468. return result
  469. # 第一步:标准 JSON 解析
  470. data = None
  471. parse_error = None
  472. try:
  473. data = json.loads(json_str)
  474. except json.JSONDecodeError as e:
  475. parse_error = e
  476. # 第二步:json_repair 本地修复
  477. if data is None:
  478. try:
  479. from json_repair import repair_json
  480. repaired = repair_json(json_str, return_objects=True)
  481. if isinstance(repaired, dict):
  482. data = repaired
  483. print("[AI] JSON 本地修复成功(json_repair)")
  484. except Exception:
  485. pass
  486. # 两步都失败,记录错误(后续由 analyze 方法的重试机制处理)
  487. if data is None:
  488. if parse_error:
  489. error_context = json_str[max(0, parse_error.pos - 30):parse_error.pos + 30] if json_str and parse_error.pos else ""
  490. result.error = f"JSON 解析错误 (位置 {parse_error.pos}): {parse_error.msg}"
  491. if error_context:
  492. result.error += f",上下文: ...{error_context}..."
  493. else:
  494. result.error = "JSON 解析失败"
  495. # 兜底:使用已提取的 json_str(不含 markdown 标记),避免推送中出现 ```json
  496. result.core_trends = json_str[:500] + "..." if len(json_str) > 500 else json_str
  497. result.success = True
  498. return result
  499. # 解析成功,提取字段
  500. try:
  501. result.core_trends = data.get("core_trends", "")
  502. result.sentiment_controversy = data.get("sentiment_controversy", "")
  503. result.signals = data.get("signals", "")
  504. result.rss_insights = data.get("rss_insights", "")
  505. result.outlook_strategy = data.get("outlook_strategy", "")
  506. # 解析独立展示区概括
  507. summaries = data.get("standalone_summaries", {})
  508. if isinstance(summaries, dict):
  509. result.standalone_summaries = {
  510. str(k): str(v) for k, v in summaries.items()
  511. }
  512. result.success = True
  513. except (KeyError, TypeError, AttributeError) as e:
  514. result.error = f"字段提取错误: {type(e).__name__}: {e}"
  515. result.core_trends = json_str[:500] + "..." if len(json_str) > 500 else json_str
  516. result.success = True
  517. return result