analyzer.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570
  1. # coding=utf-8
  2. """
  3. AI 分析器模块
  4. 调用 AI 大模型对热点新闻进行深度分析
  5. 基于 LiteLLM 统一接口,支持 100+ AI 提供商
  6. """
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from trendradar.ai.client import AIClient
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. # 新版 5 核心板块
  16. core_trends: str = "" # 核心热点与舆情态势
  17. sentiment_controversy: str = "" # 舆论风向与争议
  18. signals: str = "" # 异动与弱信号
  19. rss_insights: str = "" # RSS 深度洞察
  20. outlook_strategy: str = "" # 研判与策略建议
  21. standalone_summaries: Dict[str, str] = field(default_factory=dict) # 独立展示区概括 {源ID: 概括}
  22. # 基础元数据
  23. raw_response: str = "" # 原始响应
  24. success: bool = False # 是否成功
  25. error: str = "" # 错误信息
  26. # 新闻数量统计
  27. total_news: int = 0 # 总新闻数(热榜+RSS)
  28. analyzed_news: int = 0 # 实际分析的新闻数
  29. max_news_limit: int = 0 # 分析上限配置值
  30. hotlist_count: int = 0 # 热榜新闻数
  31. rss_count: int = 0 # RSS 新闻数
  32. ai_mode: str = "" # AI 分析使用的模式 (daily/current/incremental)
  33. class AIAnalyzer:
  34. """AI 分析器"""
  35. def __init__(
  36. self,
  37. ai_config: Dict[str, Any],
  38. analysis_config: Dict[str, Any],
  39. get_time_func: Callable,
  40. debug: bool = False,
  41. ):
  42. """
  43. 初始化 AI 分析器
  44. Args:
  45. ai_config: AI 模型配置(LiteLLM 格式)
  46. analysis_config: AI 分析功能配置(language, prompt_file 等)
  47. get_time_func: 获取当前时间的函数
  48. debug: 是否开启调试模式
  49. """
  50. self.ai_config = ai_config
  51. self.analysis_config = analysis_config
  52. self.get_time_func = get_time_func
  53. self.debug = debug
  54. # 创建 AI 客户端(基于 LiteLLM)
  55. self.client = AIClient(ai_config)
  56. # 验证配置
  57. valid, error = self.client.validate_config()
  58. if not valid:
  59. print(f"[AI] 配置警告: {error}")
  60. # 从分析配置获取功能参数
  61. self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  62. self.include_rss = analysis_config.get("INCLUDE_RSS", True)
  63. self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
  64. self.include_standalone = analysis_config.get("INCLUDE_STANDALONE", False)
  65. self.language = analysis_config.get("LANGUAGE", "Chinese")
  66. # 加载提示词模板
  67. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  68. analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
  69. )
  70. def _load_prompt_template(self, prompt_file: str) -> tuple:
  71. """加载提示词模板"""
  72. config_dir = Path(__file__).parent.parent.parent / "config"
  73. prompt_path = config_dir / prompt_file
  74. if not prompt_path.exists():
  75. print(f"[AI] 提示词文件不存在: {prompt_path}")
  76. return "", ""
  77. content = prompt_path.read_text(encoding="utf-8")
  78. # 解析 [system] 和 [user] 部分
  79. system_prompt = ""
  80. user_prompt = ""
  81. if "[system]" in content and "[user]" in content:
  82. parts = content.split("[user]")
  83. system_part = parts[0]
  84. user_part = parts[1] if len(parts) > 1 else ""
  85. # 提取 system 内容
  86. if "[system]" in system_part:
  87. system_prompt = system_part.split("[system]")[1].strip()
  88. user_prompt = user_part.strip()
  89. else:
  90. # 整个文件作为 user prompt
  91. user_prompt = content
  92. return system_prompt, user_prompt
  93. def analyze(
  94. self,
  95. stats: List[Dict],
  96. rss_stats: Optional[List[Dict]] = None,
  97. report_mode: str = "daily",
  98. report_type: str = "当日汇总",
  99. platforms: Optional[List[str]] = None,
  100. keywords: Optional[List[str]] = None,
  101. standalone_data: Optional[Dict] = None,
  102. ) -> AIAnalysisResult:
  103. """
  104. 执行 AI 分析
  105. Args:
  106. stats: 热榜统计数据
  107. rss_stats: RSS 统计数据
  108. report_mode: 报告模式
  109. report_type: 报告类型
  110. platforms: 平台列表
  111. keywords: 关键词列表
  112. Returns:
  113. AIAnalysisResult: 分析结果
  114. """
  115. # 打印配置信息方便调试
  116. model = self.ai_config.get("MODEL", "unknown")
  117. api_key = self.client.api_key or ""
  118. api_base = self.ai_config.get("API_BASE", "")
  119. masked_key = f"{api_key[:5]}******" if len(api_key) >= 5 else "******"
  120. model_display = model.replace("/", "/\u200b") if model else "unknown"
  121. print(f"[AI] 模型: {model_display}")
  122. print(f"[AI] Key : {masked_key}")
  123. if api_base:
  124. print(f"[AI] 接口: 存在自定义 API 端点")
  125. timeout = self.ai_config.get("TIMEOUT", 120)
  126. max_tokens = self.ai_config.get("MAX_TOKENS", 5000)
  127. print(f"[AI] 参数: timeout={timeout}, max_tokens={max_tokens}")
  128. if not self.client.api_key:
  129. return AIAnalysisResult(
  130. success=False,
  131. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  132. )
  133. # 准备新闻内容并获取统计数据
  134. news_content, rss_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  135. total_news = hotlist_total + rss_total
  136. if not news_content and not rss_content:
  137. return AIAnalysisResult(
  138. success=False,
  139. error="没有可分析的新闻内容",
  140. total_news=total_news,
  141. hotlist_count=hotlist_total,
  142. rss_count=rss_total,
  143. analyzed_news=0,
  144. max_news_limit=self.max_news
  145. )
  146. # 构建提示词
  147. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  148. # 提取关键词
  149. if not keywords:
  150. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  151. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  152. user_prompt = self.user_prompt_template
  153. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  154. user_prompt = user_prompt.replace("{report_type}", report_type)
  155. user_prompt = user_prompt.replace("{current_time}", current_time)
  156. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  157. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  158. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  159. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  160. user_prompt = user_prompt.replace("{news_content}", news_content)
  161. user_prompt = user_prompt.replace("{rss_content}", rss_content)
  162. user_prompt = user_prompt.replace("{language}", self.language)
  163. # 构建独立展示区内容
  164. standalone_content = ""
  165. if self.include_standalone and standalone_data:
  166. standalone_content = self._prepare_standalone_content(standalone_data)
  167. user_prompt = user_prompt.replace("{standalone_content}", standalone_content)
  168. if self.debug:
  169. print("\n" + "=" * 80)
  170. print("[AI 调试] 发送给 AI 的完整提示词")
  171. print("=" * 80)
  172. if self.system_prompt:
  173. print("\n--- System Prompt ---")
  174. print(self.system_prompt)
  175. print("\n--- User Prompt ---")
  176. print(user_prompt)
  177. print("=" * 80 + "\n")
  178. # 调用 AI API(使用 LiteLLM)
  179. try:
  180. response = self._call_ai(user_prompt)
  181. result = self._parse_response(response)
  182. # 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
  183. if not self.include_rss:
  184. result.rss_insights = ""
  185. # 如果配置未启用 standalone 分析,强制清空
  186. if not self.include_standalone:
  187. result.standalone_summaries = {}
  188. # 填充统计数据
  189. result.total_news = total_news
  190. result.hotlist_count = hotlist_total
  191. result.rss_count = rss_total
  192. result.analyzed_news = analyzed_count
  193. result.max_news_limit = self.max_news
  194. return result
  195. except Exception as e:
  196. error_type = type(e).__name__
  197. error_msg = str(e)
  198. # 截断过长的错误消息
  199. if len(error_msg) > 200:
  200. error_msg = error_msg[:200] + "..."
  201. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  202. return AIAnalysisResult(
  203. success=False,
  204. error=friendly_msg
  205. )
  206. def _prepare_news_content(
  207. self,
  208. stats: List[Dict],
  209. rss_stats: Optional[List[Dict]] = None,
  210. ) -> tuple:
  211. """
  212. 准备新闻内容文本(增强版)
  213. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  214. RSS 包含:来源、标题、发布时间
  215. Returns:
  216. tuple: (news_content, rss_content, hotlist_total, rss_total, analyzed_count)
  217. """
  218. news_lines = []
  219. rss_lines = []
  220. news_count = 0
  221. rss_count = 0
  222. # 计算总新闻数
  223. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  224. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  225. # 热榜内容
  226. if stats:
  227. for stat in stats:
  228. word = stat.get("word", "")
  229. titles = stat.get("titles", [])
  230. if word and titles:
  231. news_lines.append(f"\n**{word}** ({len(titles)}条)")
  232. for t in titles:
  233. if not isinstance(t, dict):
  234. continue
  235. title = t.get("title", "")
  236. if not title:
  237. continue
  238. # 来源
  239. source = t.get("source_name", t.get("source", ""))
  240. # 构建行
  241. if source:
  242. line = f"- [{source}] {title}"
  243. else:
  244. line = f"- {title}"
  245. # 始终显示简化格式:排名范围 + 时间范围 + 出现次数
  246. ranks = t.get("ranks", [])
  247. if ranks:
  248. min_rank = min(ranks)
  249. max_rank = max(ranks)
  250. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  251. else:
  252. rank_str = "-"
  253. first_time = t.get("first_time", "")
  254. last_time = t.get("last_time", "")
  255. time_str = self._format_time_range(first_time, last_time)
  256. appear_count = t.get("count", 1)
  257. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  258. # 开启完整时间线时,额外添加轨迹
  259. if self.include_rank_timeline:
  260. rank_timeline = t.get("rank_timeline", [])
  261. timeline_str = self._format_rank_timeline(rank_timeline)
  262. line += f" | 轨迹:{timeline_str}"
  263. news_lines.append(line)
  264. news_count += 1
  265. if news_count >= self.max_news:
  266. break
  267. if news_count >= self.max_news:
  268. break
  269. # RSS 内容(仅在启用时构建)
  270. if self.include_rss and rss_stats:
  271. remaining = self.max_news - news_count
  272. for stat in rss_stats:
  273. if rss_count >= remaining:
  274. break
  275. word = stat.get("word", "")
  276. titles = stat.get("titles", [])
  277. if word and titles:
  278. rss_lines.append(f"\n**{word}** ({len(titles)}条)")
  279. for t in titles:
  280. if not isinstance(t, dict):
  281. continue
  282. title = t.get("title", "")
  283. if not title:
  284. continue
  285. # 来源
  286. source = t.get("source_name", t.get("feed_name", ""))
  287. # 发布时间
  288. time_display = t.get("time_display", "")
  289. # 构建行:[来源] 标题 | 发布时间
  290. if source:
  291. line = f"- [{source}] {title}"
  292. else:
  293. line = f"- {title}"
  294. if time_display:
  295. line += f" | {time_display}"
  296. rss_lines.append(line)
  297. rss_count += 1
  298. if rss_count >= remaining:
  299. break
  300. news_content = "\n".join(news_lines) if news_lines else ""
  301. rss_content = "\n".join(rss_lines) if rss_lines else ""
  302. total_count = news_count + rss_count
  303. return news_content, rss_content, hotlist_total, rss_total, total_count
  304. def _call_ai(self, user_prompt: str) -> str:
  305. """调用 AI API(使用 LiteLLM)"""
  306. messages = []
  307. if self.system_prompt:
  308. messages.append({"role": "system", "content": self.system_prompt})
  309. messages.append({"role": "user", "content": user_prompt})
  310. return self.client.chat(messages)
  311. def _format_time_range(self, first_time: str, last_time: str) -> str:
  312. """格式化时间范围(简化显示,只保留时分)"""
  313. def extract_time(time_str: str) -> str:
  314. if not time_str:
  315. return "-"
  316. # 尝试提取 HH:MM 部分
  317. if " " in time_str:
  318. parts = time_str.split(" ")
  319. if len(parts) >= 2:
  320. time_part = parts[1]
  321. if ":" in time_part:
  322. return time_part[:5] # HH:MM
  323. elif ":" in time_str:
  324. return time_str[:5]
  325. # 处理 HH-MM 格式
  326. result = time_str[:5] if len(time_str) >= 5 else time_str
  327. if len(result) == 5 and result[2] == '-':
  328. result = result.replace('-', ':')
  329. return result
  330. first = extract_time(first_time)
  331. last = extract_time(last_time)
  332. if first == last or last == "-":
  333. return first
  334. return f"{first}~{last}"
  335. def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
  336. """格式化排名时间线"""
  337. if not rank_timeline:
  338. return "-"
  339. parts = []
  340. for item in rank_timeline:
  341. time_str = item.get("time", "")
  342. if len(time_str) == 5 and time_str[2] == '-':
  343. time_str = time_str.replace('-', ':')
  344. rank = item.get("rank")
  345. if rank is None:
  346. parts.append(f"0({time_str})")
  347. else:
  348. parts.append(f"{rank}({time_str})")
  349. return "→".join(parts)
  350. def _prepare_standalone_content(self, standalone_data: Dict) -> str:
  351. """
  352. 将独立展示区数据转为文本,注入 AI 分析 prompt
  353. Args:
  354. standalone_data: 独立展示区数据 {"platforms": [...], "rss_feeds": [...]}
  355. Returns:
  356. 格式化的文本内容
  357. """
  358. lines = []
  359. # 热榜平台
  360. for platform in standalone_data.get("platforms", []):
  361. platform_id = platform.get("id", "")
  362. platform_name = platform.get("name", platform_id)
  363. items = platform.get("items", [])
  364. if not items:
  365. continue
  366. lines.append(f"### [{platform_name}]")
  367. for item in items:
  368. title = item.get("title", "")
  369. if not title:
  370. continue
  371. line = f"- {title}"
  372. # 排名信息
  373. ranks = item.get("ranks", [])
  374. if ranks:
  375. min_rank = min(ranks)
  376. max_rank = max(ranks)
  377. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  378. line += f" | 排名:{rank_str}"
  379. # 时间范围
  380. first_time = item.get("first_time", "")
  381. last_time = item.get("last_time", "")
  382. if first_time:
  383. time_str = self._format_time_range(first_time, last_time)
  384. line += f" | 时间:{time_str}"
  385. # 出现次数
  386. count = item.get("count", 1)
  387. if count > 1:
  388. line += f" | 出现:{count}次"
  389. # 排名轨迹(如果启用)
  390. if self.include_rank_timeline:
  391. rank_timeline = item.get("rank_timeline", [])
  392. if rank_timeline:
  393. timeline_str = self._format_rank_timeline(rank_timeline)
  394. line += f" | 轨迹:{timeline_str}"
  395. lines.append(line)
  396. lines.append("")
  397. # RSS 源
  398. for feed in standalone_data.get("rss_feeds", []):
  399. feed_id = feed.get("id", "")
  400. feed_name = feed.get("name", feed_id)
  401. items = feed.get("items", [])
  402. if not items:
  403. continue
  404. lines.append(f"### [{feed_name}]")
  405. for item in items:
  406. title = item.get("title", "")
  407. if not title:
  408. continue
  409. line = f"- {title}"
  410. published_at = item.get("published_at", "")
  411. if published_at:
  412. line += f" | {published_at}"
  413. lines.append(line)
  414. lines.append("")
  415. return "\n".join(lines)
  416. def _parse_response(self, response: str) -> AIAnalysisResult:
  417. """解析 AI 响应"""
  418. result = AIAnalysisResult(raw_response=response)
  419. if not response or not response.strip():
  420. result.error = "AI 返回空响应"
  421. return result
  422. try:
  423. json_str = response
  424. if "```json" in response:
  425. parts = response.split("```json", 1)
  426. if len(parts) > 1:
  427. code_block = parts[1]
  428. end_idx = code_block.find("```")
  429. if end_idx != -1:
  430. json_str = code_block[:end_idx]
  431. else:
  432. json_str = code_block
  433. elif "```" in response:
  434. parts = response.split("```", 2)
  435. if len(parts) >= 2:
  436. json_str = parts[1]
  437. json_str = json_str.strip()
  438. if not json_str:
  439. raise ValueError("提取的 JSON 内容为空")
  440. data = json.loads(json_str)
  441. # 新版字段解析
  442. result.core_trends = data.get("core_trends", "")
  443. result.sentiment_controversy = data.get("sentiment_controversy", "")
  444. result.signals = data.get("signals", "")
  445. result.rss_insights = data.get("rss_insights", "")
  446. result.outlook_strategy = data.get("outlook_strategy", "")
  447. # 解析独立展示区概括
  448. summaries = data.get("standalone_summaries", {})
  449. if isinstance(summaries, dict):
  450. result.standalone_summaries = {
  451. str(k): str(v) for k, v in summaries.items()
  452. }
  453. result.success = True
  454. except json.JSONDecodeError as e:
  455. error_context = json_str[max(0, e.pos - 30):e.pos + 30] if json_str and e.pos else ""
  456. result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
  457. if error_context:
  458. result.error += f",上下文: ...{error_context}..."
  459. # 使用原始响应填充 core_trends,确保有输出
  460. result.core_trends = response[:500] + "..." if len(response) > 500 else response
  461. result.success = True
  462. except (IndexError, KeyError, TypeError, ValueError) as e:
  463. result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
  464. result.core_trends = response[:500] if len(response) > 500 else response
  465. result.success = True
  466. except Exception as e:
  467. result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
  468. result.core_trends = response[:500] if len(response) > 500 else response
  469. result.success = True
  470. return result