# analyzer.py (17 KB)
# coding=utf-8
"""
AI analyzer module.

Calls a large language model to perform in-depth analysis of trending news.
Built on the unified LiteLLM interface, which supports 100+ AI providers.
"""
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from trendradar.ai.client import AIClient
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. # 新版 5 核心板块
  16. core_trends: str = "" # 核心热点与舆情态势
  17. sentiment_controversy: str = "" # 舆论风向与争议
  18. signals: str = "" # 异动与弱信号
  19. rss_insights: str = "" # RSS 深度洞察
  20. outlook_strategy: str = "" # 研判与策略建议
  21. # 基础元数据
  22. raw_response: str = "" # 原始响应
  23. success: bool = False # 是否成功
  24. error: str = "" # 错误信息
  25. # 新闻数量统计
  26. total_news: int = 0 # 总新闻数(热榜+RSS)
  27. analyzed_news: int = 0 # 实际分析的新闻数
  28. max_news_limit: int = 0 # 分析上限配置值
  29. hotlist_count: int = 0 # 热榜新闻数
  30. rss_count: int = 0 # RSS 新闻数
  31. class AIAnalyzer:
  32. """AI 分析器"""
  33. def __init__(
  34. self,
  35. ai_config: Dict[str, Any],
  36. analysis_config: Dict[str, Any],
  37. get_time_func: Callable,
  38. debug: bool = False,
  39. ):
  40. """
  41. 初始化 AI 分析器
  42. Args:
  43. ai_config: AI 模型配置(LiteLLM 格式)
  44. analysis_config: AI 分析功能配置(language, prompt_file 等)
  45. get_time_func: 获取当前时间的函数
  46. debug: 是否开启调试模式
  47. """
  48. self.ai_config = ai_config
  49. self.analysis_config = analysis_config
  50. self.get_time_func = get_time_func
  51. self.debug = debug
  52. # 创建 AI 客户端(基于 LiteLLM)
  53. self.client = AIClient(ai_config)
  54. # 验证配置
  55. valid, error = self.client.validate_config()
  56. if not valid:
  57. print(f"[AI] 配置警告: {error}")
  58. # 从分析配置获取功能参数
  59. self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  60. self.include_rss = analysis_config.get("INCLUDE_RSS", True)
  61. self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
  62. self.language = analysis_config.get("LANGUAGE", "Chinese")
  63. # 加载提示词模板
  64. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  65. analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
  66. )
  67. def _load_prompt_template(self, prompt_file: str) -> tuple:
  68. """加载提示词模板"""
  69. config_dir = Path(__file__).parent.parent.parent / "config"
  70. prompt_path = config_dir / prompt_file
  71. if not prompt_path.exists():
  72. print(f"[AI] 提示词文件不存在: {prompt_path}")
  73. return "", ""
  74. content = prompt_path.read_text(encoding="utf-8")
  75. # 解析 [system] 和 [user] 部分
  76. system_prompt = ""
  77. user_prompt = ""
  78. if "[system]" in content and "[user]" in content:
  79. parts = content.split("[user]")
  80. system_part = parts[0]
  81. user_part = parts[1] if len(parts) > 1 else ""
  82. # 提取 system 内容
  83. if "[system]" in system_part:
  84. system_prompt = system_part.split("[system]")[1].strip()
  85. user_prompt = user_part.strip()
  86. else:
  87. # 整个文件作为 user prompt
  88. user_prompt = content
  89. return system_prompt, user_prompt
  90. def analyze(
  91. self,
  92. stats: List[Dict],
  93. rss_stats: Optional[List[Dict]] = None,
  94. report_mode: str = "daily",
  95. report_type: str = "当日汇总",
  96. platforms: Optional[List[str]] = None,
  97. keywords: Optional[List[str]] = None,
  98. ) -> AIAnalysisResult:
  99. """
  100. 执行 AI 分析
  101. Args:
  102. stats: 热榜统计数据
  103. rss_stats: RSS 统计数据
  104. report_mode: 报告模式
  105. report_type: 报告类型
  106. platforms: 平台列表
  107. keywords: 关键词列表
  108. Returns:
  109. AIAnalysisResult: 分析结果
  110. """
  111. if not self.client.api_key:
  112. return AIAnalysisResult(
  113. success=False,
  114. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  115. )
  116. # 准备新闻内容并获取统计数据
  117. news_content, rss_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  118. total_news = hotlist_total + rss_total
  119. if not news_content and not rss_content:
  120. return AIAnalysisResult(
  121. success=False,
  122. error="没有可分析的新闻内容",
  123. total_news=total_news,
  124. hotlist_count=hotlist_total,
  125. rss_count=rss_total,
  126. analyzed_news=0,
  127. max_news_limit=self.max_news
  128. )
  129. # 构建提示词
  130. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  131. # 提取关键词
  132. if not keywords:
  133. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  134. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  135. user_prompt = self.user_prompt_template
  136. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  137. user_prompt = user_prompt.replace("{report_type}", report_type)
  138. user_prompt = user_prompt.replace("{current_time}", current_time)
  139. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  140. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  141. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  142. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  143. user_prompt = user_prompt.replace("{news_content}", news_content)
  144. user_prompt = user_prompt.replace("{rss_content}", rss_content)
  145. user_prompt = user_prompt.replace("{language}", self.language)
  146. if self.debug:
  147. print("\n" + "=" * 80)
  148. print("[AI 调试] 发送给 AI 的完整提示词")
  149. print("=" * 80)
  150. if self.system_prompt:
  151. print("\n--- System Prompt ---")
  152. print(self.system_prompt)
  153. print("\n--- User Prompt ---")
  154. print(user_prompt)
  155. print("=" * 80 + "\n")
  156. # 调用 AI API(使用 LiteLLM)
  157. try:
  158. response = self._call_ai(user_prompt)
  159. result = self._parse_response(response)
  160. # 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
  161. if not self.include_rss:
  162. result.rss_insights = ""
  163. # 填充统计数据
  164. result.total_news = total_news
  165. result.hotlist_count = hotlist_total
  166. result.rss_count = rss_total
  167. result.analyzed_news = analyzed_count
  168. result.max_news_limit = self.max_news
  169. return result
  170. except Exception as e:
  171. error_type = type(e).__name__
  172. error_msg = str(e)
  173. # 截断过长的错误消息
  174. if len(error_msg) > 200:
  175. error_msg = error_msg[:200] + "..."
  176. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  177. return AIAnalysisResult(
  178. success=False,
  179. error=friendly_msg
  180. )
  181. def _prepare_news_content(
  182. self,
  183. stats: List[Dict],
  184. rss_stats: Optional[List[Dict]] = None,
  185. ) -> tuple:
  186. """
  187. 准备新闻内容文本(增强版)
  188. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  189. RSS 包含:来源、标题、发布时间
  190. Returns:
  191. tuple: (news_content, rss_content, hotlist_total, rss_total, analyzed_count)
  192. """
  193. news_lines = []
  194. rss_lines = []
  195. news_count = 0
  196. rss_count = 0
  197. # 计算总新闻数
  198. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  199. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  200. # 热榜内容
  201. if stats:
  202. for stat in stats:
  203. word = stat.get("word", "")
  204. titles = stat.get("titles", [])
  205. if word and titles:
  206. news_lines.append(f"\n**{word}** ({len(titles)}条)")
  207. for t in titles:
  208. if not isinstance(t, dict):
  209. continue
  210. title = t.get("title", "")
  211. if not title:
  212. continue
  213. # 来源
  214. source = t.get("source_name", t.get("source", ""))
  215. # 构建行
  216. if source:
  217. line = f"- [{source}] {title}"
  218. else:
  219. line = f"- {title}"
  220. # 始终显示简化格式:排名范围 + 时间范围 + 出现次数
  221. ranks = t.get("ranks", [])
  222. if ranks:
  223. min_rank = min(ranks)
  224. max_rank = max(ranks)
  225. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  226. else:
  227. rank_str = "-"
  228. first_time = t.get("first_time", "")
  229. last_time = t.get("last_time", "")
  230. time_str = self._format_time_range(first_time, last_time)
  231. appear_count = t.get("count", 1)
  232. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  233. # 开启完整时间线时,额外添加轨迹
  234. if self.include_rank_timeline:
  235. rank_timeline = t.get("rank_timeline", [])
  236. timeline_str = self._format_rank_timeline(rank_timeline)
  237. line += f" | 轨迹:{timeline_str}"
  238. news_lines.append(line)
  239. news_count += 1
  240. if news_count >= self.max_news:
  241. break
  242. if news_count >= self.max_news:
  243. break
  244. # RSS 内容(仅在启用时构建)
  245. if self.include_rss and rss_stats:
  246. remaining = self.max_news - news_count
  247. for stat in rss_stats:
  248. if rss_count >= remaining:
  249. break
  250. word = stat.get("word", "")
  251. titles = stat.get("titles", [])
  252. if word and titles:
  253. rss_lines.append(f"\n**{word}** ({len(titles)}条)")
  254. for t in titles:
  255. if not isinstance(t, dict):
  256. continue
  257. title = t.get("title", "")
  258. if not title:
  259. continue
  260. # 来源
  261. source = t.get("source_name", t.get("feed_name", ""))
  262. # 发布时间
  263. time_display = t.get("time_display", "")
  264. # 构建行:[来源] 标题 | 发布时间
  265. if source:
  266. line = f"- [{source}] {title}"
  267. else:
  268. line = f"- {title}"
  269. if time_display:
  270. line += f" | {time_display}"
  271. rss_lines.append(line)
  272. rss_count += 1
  273. if rss_count >= remaining:
  274. break
  275. news_content = "\n".join(news_lines) if news_lines else ""
  276. rss_content = "\n".join(rss_lines) if rss_lines else ""
  277. total_count = news_count + rss_count
  278. return news_content, rss_content, hotlist_total, rss_total, total_count
  279. def _call_ai(self, user_prompt: str) -> str:
  280. """调用 AI API(使用 LiteLLM)"""
  281. messages = []
  282. if self.system_prompt:
  283. messages.append({"role": "system", "content": self.system_prompt})
  284. messages.append({"role": "user", "content": user_prompt})
  285. return self.client.chat(messages)
  286. def _format_time_range(self, first_time: str, last_time: str) -> str:
  287. """格式化时间范围(简化显示,只保留时分)"""
  288. def extract_time(time_str: str) -> str:
  289. if not time_str:
  290. return "-"
  291. # 尝试提取 HH:MM 部分
  292. if " " in time_str:
  293. parts = time_str.split(" ")
  294. if len(parts) >= 2:
  295. time_part = parts[1]
  296. if ":" in time_part:
  297. return time_part[:5] # HH:MM
  298. elif ":" in time_str:
  299. return time_str[:5]
  300. # 处理 HH-MM 格式
  301. result = time_str[:5] if len(time_str) >= 5 else time_str
  302. if len(result) == 5 and result[2] == '-':
  303. result = result.replace('-', ':')
  304. return result
  305. first = extract_time(first_time)
  306. last = extract_time(last_time)
  307. if first == last or last == "-":
  308. return first
  309. return f"{first}~{last}"
  310. def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
  311. """格式化排名时间线"""
  312. if not rank_timeline:
  313. return "-"
  314. parts = []
  315. for item in rank_timeline:
  316. time_str = item.get("time", "")
  317. if len(time_str) == 5 and time_str[2] == '-':
  318. time_str = time_str.replace('-', ':')
  319. rank = item.get("rank")
  320. if rank is None:
  321. parts.append(f"0({time_str})")
  322. else:
  323. parts.append(f"{rank}({time_str})")
  324. return "→".join(parts)
  325. def _parse_response(self, response: str) -> AIAnalysisResult:
  326. """解析 AI 响应"""
  327. result = AIAnalysisResult(raw_response=response)
  328. if not response or not response.strip():
  329. result.error = "AI 返回空响应"
  330. return result
  331. # 尝试解析 JSON
  332. try:
  333. # 提取 JSON 部分
  334. json_str = response
  335. if "```json" in response:
  336. parts = response.split("```json", 1)
  337. if len(parts) > 1:
  338. code_block = parts[1]
  339. end_idx = code_block.find("```")
  340. if end_idx != -1:
  341. json_str = code_block[:end_idx]
  342. else:
  343. json_str = code_block
  344. elif "```" in response:
  345. parts = response.split("```", 2)
  346. if len(parts) >= 2:
  347. json_str = parts[1]
  348. json_str = json_str.strip()
  349. if not json_str:
  350. raise ValueError("提取的 JSON 内容为空")
  351. data = json.loads(json_str)
  352. # 新版字段解析
  353. result.core_trends = data.get("core_trends", "")
  354. result.sentiment_controversy = data.get("sentiment_controversy", "")
  355. result.signals = data.get("signals", "")
  356. result.rss_insights = data.get("rss_insights", "")
  357. result.outlook_strategy = data.get("outlook_strategy", "")
  358. result.success = True
  359. except json.JSONDecodeError as e:
  360. error_context = json_str[max(0, e.pos - 30):e.pos + 30] if json_str and e.pos else ""
  361. result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
  362. if error_context:
  363. result.error += f",上下文: ...{error_context}..."
  364. # 使用原始响应填充 core_trends,确保有输出
  365. result.core_trends = response[:500] + "..." if len(response) > 500 else response
  366. result.success = True
  367. except (IndexError, KeyError, TypeError, ValueError) as e:
  368. result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
  369. result.core_trends = response[:500] if len(response) > 500 else response
  370. result.success = True
  371. except Exception as e:
  372. result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
  373. result.core_trends = response[:500] if len(response) > 500 else response
  374. result.success = True
  375. return result