analyzer.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581
  1. # coding=utf-8
  2. """
  3. AI 分析器模块
  4. 调用 AI 大模型对热点新闻进行深度分析
  5. 支持 OpenAI、Google Gemini、Azure OpenAI 等兼容接口
  6. """
  7. import json
  8. import os
  9. from dataclasses import dataclass
  10. from pathlib import Path
  11. from typing import Any, Callable, Dict, List, Optional
  12. @dataclass
  13. class AIAnalysisResult:
  14. """AI 分析结果"""
  15. # 新版 5 核心板块
  16. core_trends: str = "" # 核心热点与舆情态势
  17. sentiment_controversy: str = "" # 舆论风向与争议
  18. signals: str = "" # 异动与弱信号
  19. rss_insights: str = "" # RSS 深度洞察
  20. outlook_strategy: str = "" # 研判与策略建议
  21. # 基础元数据
  22. raw_response: str = "" # 原始响应
  23. success: bool = False # 是否成功
  24. error: str = "" # 错误信息
  25. # 新闻数量统计
  26. total_news: int = 0 # 总新闻数(热榜+RSS)
  27. analyzed_news: int = 0 # 实际分析的新闻数
  28. max_news_limit: int = 0 # 分析上限配置值
  29. hotlist_count: int = 0 # 热榜新闻数
  30. rss_count: int = 0 # RSS 新闻数
  31. class AIAnalyzer:
  32. """AI 分析器"""
  33. def __init__(
  34. self,
  35. ai_config: Dict[str, Any],
  36. analysis_config: Dict[str, Any],
  37. get_time_func: Callable,
  38. debug: bool = False,
  39. ):
  40. """
  41. 初始化 AI 分析器
  42. Args:
  43. ai_config: AI 模型共享配置(provider, api_key, model 等)
  44. analysis_config: AI 分析功能配置(language, prompt_file 等)
  45. get_time_func: 获取当前时间的函数
  46. debug: 是否开启调试模式
  47. """
  48. self.ai_config = ai_config
  49. self.analysis_config = analysis_config
  50. self.get_time_func = get_time_func
  51. self.debug = debug
  52. # 从共享配置获取模型参数
  53. self.api_key = ai_config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
  54. self.provider = ai_config.get("PROVIDER", "deepseek")
  55. self.model = ai_config.get("MODEL", "deepseek-chat")
  56. self.base_url = ai_config.get("BASE_URL", "")
  57. self.timeout = ai_config.get("TIMEOUT", 90)
  58. self.temperature = ai_config.get("TEMPERATURE", 1.0)
  59. self.max_tokens = ai_config.get("MAX_TOKENS", 5000)
  60. # 从分析配置获取功能参数
  61. self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
  62. self.include_rss = analysis_config.get("INCLUDE_RSS", True)
  63. self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
  64. self.language = analysis_config.get("LANGUAGE", "Chinese")
  65. # 额外的自定义参数(支持字典或 JSON 字符串)
  66. self.extra_params = ai_config.get("EXTRA_PARAMS", {})
  67. if isinstance(self.extra_params, str) and self.extra_params.strip():
  68. try:
  69. self.extra_params = json.loads(self.extra_params)
  70. except json.JSONDecodeError:
  71. print(f"[AI] 解析 extra_params 失败,将忽略: {self.extra_params}")
  72. self.extra_params = {}
  73. if not isinstance(self.extra_params, dict):
  74. self.extra_params = {}
  75. # 加载提示词模板
  76. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  77. analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
  78. )
  79. def _load_prompt_template(self, prompt_file: str) -> tuple:
  80. """加载提示词模板"""
  81. config_dir = Path(__file__).parent.parent.parent / "config"
  82. prompt_path = config_dir / prompt_file
  83. if not prompt_path.exists():
  84. print(f"[AI] 提示词文件不存在: {prompt_path}")
  85. return "", ""
  86. content = prompt_path.read_text(encoding="utf-8")
  87. # 解析 [system] 和 [user] 部分
  88. system_prompt = ""
  89. user_prompt = ""
  90. if "[system]" in content and "[user]" in content:
  91. parts = content.split("[user]")
  92. system_part = parts[0]
  93. user_part = parts[1] if len(parts) > 1 else ""
  94. # 提取 system 内容
  95. if "[system]" in system_part:
  96. system_prompt = system_part.split("[system]")[1].strip()
  97. user_prompt = user_part.strip()
  98. else:
  99. # 整个文件作为 user prompt
  100. user_prompt = content
  101. return system_prompt, user_prompt
  102. def analyze(
  103. self,
  104. stats: List[Dict],
  105. rss_stats: Optional[List[Dict]] = None,
  106. report_mode: str = "daily",
  107. report_type: str = "当日汇总",
  108. platforms: Optional[List[str]] = None,
  109. keywords: Optional[List[str]] = None,
  110. ) -> AIAnalysisResult:
  111. """
  112. 执行 AI 分析
  113. Args:
  114. stats: 热榜统计数据
  115. rss_stats: RSS 统计数据
  116. report_mode: 报告模式
  117. report_type: 报告类型
  118. platforms: 平台列表
  119. keywords: 关键词列表
  120. Returns:
  121. AIAnalysisResult: 分析结果
  122. """
  123. if not self.api_key:
  124. return AIAnalysisResult(
  125. success=False,
  126. error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
  127. )
  128. # 准备新闻内容并获取统计数据
  129. news_content, rss_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(stats, rss_stats)
  130. total_news = hotlist_total + rss_total
  131. if not news_content and not rss_content:
  132. return AIAnalysisResult(
  133. success=False,
  134. error="没有可分析的新闻内容",
  135. total_news=total_news,
  136. hotlist_count=hotlist_total,
  137. rss_count=rss_total,
  138. analyzed_news=0,
  139. max_news_limit=self.max_news
  140. )
  141. # 构建提示词
  142. current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
  143. # 提取关键词
  144. if not keywords:
  145. keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []
  146. # 使用安全的字符串替换,避免模板中其他花括号(如 JSON 示例)被误解析
  147. user_prompt = self.user_prompt_template
  148. user_prompt = user_prompt.replace("{report_mode}", report_mode)
  149. user_prompt = user_prompt.replace("{report_type}", report_type)
  150. user_prompt = user_prompt.replace("{current_time}", current_time)
  151. user_prompt = user_prompt.replace("{news_count}", str(hotlist_total))
  152. user_prompt = user_prompt.replace("{rss_count}", str(rss_total))
  153. user_prompt = user_prompt.replace("{platforms}", ", ".join(platforms) if platforms else "多平台")
  154. user_prompt = user_prompt.replace("{keywords}", ", ".join(keywords[:20]) if keywords else "无")
  155. user_prompt = user_prompt.replace("{news_content}", news_content)
  156. user_prompt = user_prompt.replace("{rss_content}", rss_content)
  157. user_prompt = user_prompt.replace("{language}", self.language)
  158. if self.debug:
  159. print("\n" + "=" * 80)
  160. print("[AI 调试] 发送给 AI 的完整提示词")
  161. print("=" * 80)
  162. if self.system_prompt:
  163. print("\n--- System Prompt ---")
  164. print(self.system_prompt)
  165. print("\n--- User Prompt ---")
  166. print(user_prompt)
  167. print("=" * 80 + "\n")
  168. # 调用 AI API
  169. try:
  170. response = self._call_ai_api(user_prompt)
  171. result = self._parse_response(response)
  172. # 如果配置未启用 RSS 分析,强制清空 AI 返回的 RSS 洞察
  173. if not self.include_rss:
  174. result.rss_insights = ""
  175. # 填充统计数据
  176. result.total_news = total_news
  177. result.hotlist_count = hotlist_total
  178. result.rss_count = rss_total
  179. result.analyzed_news = analyzed_count
  180. result.max_news_limit = self.max_news
  181. return result
  182. except Exception as e:
  183. import requests
  184. error_type = type(e).__name__
  185. error_msg = str(e)
  186. # 针对不同错误类型提供更友好的提示
  187. if isinstance(e, requests.exceptions.Timeout):
  188. friendly_msg = f"AI API 请求超时({self.timeout}秒),请检查网络或增加超时时间"
  189. elif isinstance(e, requests.exceptions.ConnectionError):
  190. friendly_msg = f"无法连接到 AI API ({self.base_url or self.provider}),请检查网络和 API 地址"
  191. elif isinstance(e, requests.exceptions.HTTPError):
  192. status_code = e.response.status_code if hasattr(e, 'response') and e.response else "未知"
  193. if status_code == 401:
  194. friendly_msg = "AI API 认证失败,请检查 API Key 是否正确"
  195. elif status_code == 429:
  196. friendly_msg = "AI API 请求频率过高,请稍后重试"
  197. elif status_code == 500:
  198. friendly_msg = "AI API 服务器内部错误,请稍后重试"
  199. else:
  200. friendly_msg = f"AI API 返回错误 (HTTP {status_code}): {error_msg[:100]}"
  201. else:
  202. # 截断过长的错误消息
  203. if len(error_msg) > 150:
  204. error_msg = error_msg[:150] + "..."
  205. friendly_msg = f"AI 分析失败 ({error_type}): {error_msg}"
  206. return AIAnalysisResult(
  207. success=False,
  208. error=friendly_msg
  209. )
  210. def _prepare_news_content(
  211. self,
  212. stats: List[Dict],
  213. rss_stats: Optional[List[Dict]] = None,
  214. ) -> tuple:
  215. """
  216. 准备新闻内容文本(增强版)
  217. 热榜新闻包含:来源、标题、排名范围、时间范围、出现次数
  218. RSS 包含:来源、标题、发布时间
  219. Returns:
  220. tuple: (news_content, rss_content, hotlist_total, rss_total, analyzed_count)
  221. """
  222. news_lines = []
  223. rss_lines = []
  224. news_count = 0
  225. rss_count = 0
  226. # 计算总新闻数
  227. hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
  228. rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0
  229. # 热榜内容
  230. if stats:
  231. for stat in stats:
  232. word = stat.get("word", "")
  233. titles = stat.get("titles", [])
  234. if word and titles:
  235. news_lines.append(f"\n**{word}** ({len(titles)}条)")
  236. for t in titles:
  237. if not isinstance(t, dict):
  238. continue
  239. title = t.get("title", "")
  240. if not title:
  241. continue
  242. # 来源
  243. source = t.get("source_name", t.get("source", ""))
  244. # 构建行
  245. if source:
  246. line = f"- [{source}] {title}"
  247. else:
  248. line = f"- {title}"
  249. # 始终显示简化格式:排名范围 + 时间范围 + 出现次数
  250. ranks = t.get("ranks", [])
  251. if ranks:
  252. min_rank = min(ranks)
  253. max_rank = max(ranks)
  254. rank_str = f"{min_rank}" if min_rank == max_rank else f"{min_rank}-{max_rank}"
  255. else:
  256. rank_str = "-"
  257. first_time = t.get("first_time", "")
  258. last_time = t.get("last_time", "")
  259. time_str = self._format_time_range(first_time, last_time)
  260. appear_count = t.get("count", 1)
  261. line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"
  262. # 开启完整时间线时,额外添加轨迹
  263. if self.include_rank_timeline:
  264. rank_timeline = t.get("rank_timeline", [])
  265. timeline_str = self._format_rank_timeline(rank_timeline)
  266. line += f" | 轨迹:{timeline_str}"
  267. news_lines.append(line)
  268. news_count += 1
  269. if news_count >= self.max_news:
  270. break
  271. if news_count >= self.max_news:
  272. break
  273. # RSS 内容(仅在启用时构建)
  274. if self.include_rss and rss_stats:
  275. remaining = self.max_news - news_count
  276. for stat in rss_stats:
  277. if rss_count >= remaining:
  278. break
  279. word = stat.get("word", "")
  280. titles = stat.get("titles", [])
  281. if word and titles:
  282. rss_lines.append(f"\n**{word}** ({len(titles)}条)")
  283. for t in titles:
  284. if not isinstance(t, dict):
  285. continue
  286. title = t.get("title", "")
  287. if not title:
  288. continue
  289. # 来源
  290. source = t.get("source_name", t.get("feed_name", ""))
  291. # 发布时间
  292. time_display = t.get("time_display", "")
  293. # 构建行:[来源] 标题 | 发布时间
  294. if source:
  295. line = f"- [{source}] {title}"
  296. else:
  297. line = f"- {title}"
  298. if time_display:
  299. line += f" | {time_display}"
  300. rss_lines.append(line)
  301. rss_count += 1
  302. if rss_count >= remaining:
  303. break
  304. news_content = "\n".join(news_lines) if news_lines else ""
  305. rss_content = "\n".join(rss_lines) if rss_lines else ""
  306. total_count = news_count + rss_count
  307. return news_content, rss_content, hotlist_total, rss_total, total_count
  308. def _format_time_range(self, first_time: str, last_time: str) -> str:
  309. """格式化时间范围(简化显示,只保留时分)"""
  310. def extract_time(time_str: str) -> str:
  311. if not time_str:
  312. return "-"
  313. # 尝试提取 HH:MM 部分
  314. if " " in time_str:
  315. parts = time_str.split(" ")
  316. if len(parts) >= 2:
  317. time_part = parts[1]
  318. if ":" in time_part:
  319. return time_part[:5] # HH:MM
  320. elif ":" in time_str:
  321. return time_str[:5]
  322. # 处理 HH-MM 格式
  323. result = time_str[:5] if len(time_str) >= 5 else time_str
  324. if len(result) == 5 and result[2] == '-':
  325. result = result.replace('-', ':')
  326. return result
  327. first = extract_time(first_time)
  328. last = extract_time(last_time)
  329. if first == last or last == "-":
  330. return first
  331. return f"{first}~{last}"
  332. def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
  333. """格式化排名时间线"""
  334. if not rank_timeline:
  335. return "-"
  336. parts = []
  337. for item in rank_timeline:
  338. time_str = item.get("time", "")
  339. if len(time_str) == 5 and time_str[2] == '-':
  340. time_str = time_str.replace('-', ':')
  341. rank = item.get("rank")
  342. if rank is None:
  343. parts.append(f"0({time_str})")
  344. else:
  345. parts.append(f"{rank}({time_str})")
  346. return "→".join(parts)
  347. def _call_ai_api(self, user_prompt: str) -> str:
  348. """调用 AI API"""
  349. if self.provider == "gemini":
  350. return self._call_gemini(user_prompt)
  351. return self._call_openai_compatible(user_prompt)
  352. def _get_api_url(self) -> str:
  353. """获取完整 API URL"""
  354. if self.base_url:
  355. return self.base_url
  356. # 预设完整端点
  357. urls = {
  358. "deepseek": "https://api.deepseek.com/v1/chat/completions",
  359. "openai": "https://api.openai.com/v1/chat/completions",
  360. }
  361. url = urls.get(self.provider)
  362. if not url:
  363. raise ValueError(f"{self.provider} 需要配置 base_url(完整 API 地址)")
  364. return url
  365. def _call_openai_compatible(self, user_prompt: str) -> str:
  366. """调用 OpenAI 兼容接口"""
  367. import requests
  368. url = self._get_api_url()
  369. headers = {
  370. "Authorization": f"Bearer {self.api_key}",
  371. "Content-Type": "application/json",
  372. }
  373. messages = []
  374. if self.system_prompt:
  375. messages.append({"role": "system", "content": self.system_prompt})
  376. messages.append({"role": "user", "content": user_prompt})
  377. payload = {
  378. "model": self.model,
  379. "messages": messages,
  380. "temperature": self.temperature,
  381. }
  382. # 某些 API 不支持 max_tokens
  383. if self.max_tokens:
  384. payload["max_tokens"] = self.max_tokens
  385. if self.extra_params:
  386. payload.update(self.extra_params)
  387. response = requests.post(
  388. url,
  389. headers=headers,
  390. json=payload,
  391. timeout=self.timeout,
  392. )
  393. response.raise_for_status()
  394. data = response.json()
  395. return data["choices"][0]["message"]["content"]
  396. def _call_gemini(self, user_prompt: str) -> str:
  397. """调用 Google Gemini API"""
  398. import requests
  399. model = self.model or "gemini-1.5-flash"
  400. url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self.api_key}"
  401. headers = {
  402. "Content-Type": "application/json",
  403. }
  404. payload = {
  405. "contents": [{
  406. "role": "user",
  407. "parts": [{"text": user_prompt}]
  408. }],
  409. "generationConfig": {
  410. "temperature": self.temperature,
  411. },
  412. "safetySettings": [
  413. {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
  414. {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
  415. {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
  416. {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
  417. ]
  418. }
  419. if self.system_prompt:
  420. payload["system_instruction"] = {
  421. "parts": [{"text": self.system_prompt}]
  422. }
  423. if self.max_tokens:
  424. payload["generationConfig"]["maxOutputTokens"] = self.max_tokens
  425. if self.extra_params:
  426. payload["generationConfig"].update(self.extra_params)
  427. response = requests.post(
  428. url,
  429. headers=headers,
  430. json=payload,
  431. timeout=self.timeout,
  432. )
  433. response.raise_for_status()
  434. data = response.json()
  435. return data["candidates"][0]["content"]["parts"][0]["text"]
  436. def _parse_response(self, response: str) -> AIAnalysisResult:
  437. """解析 AI 响应"""
  438. result = AIAnalysisResult(raw_response=response)
  439. if not response or not response.strip():
  440. result.error = "AI 返回空响应"
  441. return result
  442. # 尝试解析 JSON
  443. try:
  444. # 提取 JSON 部分
  445. json_str = response
  446. if "```json" in response:
  447. parts = response.split("```json", 1)
  448. if len(parts) > 1:
  449. code_block = parts[1]
  450. end_idx = code_block.find("```")
  451. if end_idx != -1:
  452. json_str = code_block[:end_idx]
  453. else:
  454. json_str = code_block
  455. elif "```" in response:
  456. parts = response.split("```", 2)
  457. if len(parts) >= 2:
  458. json_str = parts[1]
  459. json_str = json_str.strip()
  460. if not json_str:
  461. raise ValueError("提取的 JSON 内容为空")
  462. data = json.loads(json_str)
  463. # 新版字段解析
  464. result.core_trends = data.get("core_trends", "")
  465. result.sentiment_controversy = data.get("sentiment_controversy", "")
  466. result.signals = data.get("signals", "")
  467. result.rss_insights = data.get("rss_insights", "")
  468. result.outlook_strategy = data.get("outlook_strategy", "")
  469. result.success = True
  470. except json.JSONDecodeError as e:
  471. error_context = json_str[max(0, e.pos - 30):e.pos + 30] if json_str and e.pos else ""
  472. result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
  473. if error_context:
  474. result.error += f",上下文: ...{error_context}..."
  475. # 使用原始响应填充 core_trends,确保有输出
  476. result.core_trends = response[:500] + "..." if len(response) > 500 else response
  477. result.success = True
  478. except (IndexError, KeyError, TypeError, ValueError) as e:
  479. result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
  480. result.core_trends = response[:500] if len(response) > 500 else response
  481. result.success = True
  482. except Exception as e:
  483. result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
  484. result.core_trends = response[:500] if len(response) > 500 else response
  485. result.success = True
  486. return result