| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503 |
- # coding=utf-8
- """
- AI 分析器模块
- 调用 AI 大模型对热点新闻进行深度分析
- 支持 OpenAI、Google Gemini、Azure OpenAI 等兼容接口
- """
- import json
- import os
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional
@dataclass
class AIAnalysisResult:
    """Outcome of a single AI analysis run.

    Holds the seven text sections extracted from the model's JSON reply, the
    raw reply and success/error status, plus counters describing how many news
    items were available and how many were actually sent to the model.
    """

    # --- analysis sections (one per JSON key requested from the model) ---
    summary: str = ""           # overview of trending topics
    keyword_analysis: str = ""  # keyword popularity analysis
    sentiment: str = ""         # sentiment / tone analysis
    cross_platform: str = ""    # correlations across platforms
    impact: str = ""            # assessment of potential impact
    signals: str = ""           # signals worth watching
    conclusion: str = ""        # conclusion and recommendations

    # --- call outcome ---
    raw_response: str = ""      # unparsed text returned by the model
    success: bool = False       # True once there is content to present
    error: str = ""             # error / parse-warning message, if any

    # --- news counters ---
    total_news: int = 0         # hotlist + RSS items available
    analyzed_news: int = 0      # items actually included in the prompt
    max_news_limit: int = 0     # configured cap on analyzed items
    hotlist_count: int = 0      # hotlist items available
    rss_count: int = 0          # RSS items available
class AIAnalyzer:
    """AI analyzer for trending news.

    Builds a prompt from hotlist / RSS statistics, sends it to the configured
    LLM provider and parses the reply into an ``AIAnalysisResult``.
    ``PROVIDER == "gemini"`` uses the Google Gemini REST API; every other
    provider is called through an OpenAI-compatible chat-completions endpoint.
    """

    def __init__(self, config: Dict[str, Any], get_time_func: Callable):
        """
        Initialize the analyzer.

        Args:
            config: AI analysis configuration. Keys read here: API_KEY,
                PROVIDER, MODEL, BASE_URL, TIMEOUT, MAX_NEWS_FOR_ANALYSIS,
                INCLUDE_RSS, PUSH_MODE, PROMPT_FILE.
            get_time_func: Zero-argument callable returning the current time;
                the returned object must support ``strftime``.
        """
        self.config = config
        self.get_time_func = get_time_func

        # The config value wins; fall back to the AI_API_KEY environment variable.
        self.api_key = config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
        self.provider = config.get("PROVIDER", "openai")
        self.model = config.get("MODEL", "gpt-4o-mini")
        self.base_url = config.get("BASE_URL", "")
        self.timeout = config.get("TIMEOUT", 90)  # seconds, passed to requests
        self.max_news = config.get("MAX_NEWS_FOR_ANALYSIS", 50)
        self.include_rss = config.get("INCLUDE_RSS", True)
        # Not read inside this class; presumably consumed by callers — verify.
        self.push_mode = config.get("PUSH_MODE", "both")

        self.system_prompt, self.user_prompt_template = self._load_prompt_template(
            config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
        )

    # ------------------------------------------------------------------ prompt

    def _load_prompt_template(self, prompt_file: str) -> tuple:
        """Load ``(system_prompt, user_prompt_template)`` from the config directory.

        The file is looked up in ``config/`` three directory levels above this
        module. A missing file is reported on stdout and yields ``("", "")``.
        """
        config_dir = Path(__file__).parent.parent.parent / "config"
        prompt_path = config_dir / prompt_file
        if not prompt_path.exists():
            print(f"[AI] 提示词文件不存在: {prompt_path}")
            return "", ""
        return self._parse_prompt_content(prompt_path.read_text(encoding="utf-8"))

    @staticmethod
    def _parse_prompt_content(content: str) -> tuple:
        """Split prompt-file text into ``(system_prompt, user_prompt)``.

        When both ``[system]`` and ``[user]`` markers are present, the text
        after ``[system]`` up to the first ``[user]`` is the system prompt and
        everything after the first ``[user]`` is the user prompt (both
        stripped). Otherwise the whole text is the user prompt.
        """
        if "[system]" not in content or "[user]" not in content:
            return "", content
        # maxsplit=1: a later literal "[user]" must remain part of the user
        # prompt instead of silently truncating it.
        system_part, user_part = content.split("[user]", 1)
        system_prompt = ""
        if "[system]" in system_part:
            system_prompt = system_part.split("[system]", 1)[1].strip()
        return system_prompt, user_part.strip()

    def _build_user_prompt(
        self,
        news_content: str,
        hotlist_total: int,
        rss_total: int,
        report_mode: str,
        report_type: str,
        platforms: Optional[List[str]],
        keywords: Optional[List[str]],
    ) -> str:
        """Fill the ``{placeholder}`` fields of the user prompt template.

        Plain ``str.replace`` is used instead of ``str.format`` so other braces
        in the template (e.g. a JSON example) are left alone. ``{news_count}``
        receives the hotlist total only. ``{news_content}`` is substituted last,
        so placeholder-like text inside news titles is never expanded.
        """
        current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")
        substitutions = (
            ("{report_mode}", report_mode),
            ("{report_type}", report_type),
            ("{current_time}", current_time),
            ("{news_count}", str(hotlist_total)),
            ("{rss_count}", str(rss_total)),
            ("{platforms}", ", ".join(platforms) if platforms else "多平台"),
            ("{keywords}", ", ".join(keywords[:20]) if keywords else "无"),
            ("{news_content}", news_content),
        )
        prompt = self.user_prompt_template
        for placeholder, value in substitutions:
            prompt = prompt.replace(placeholder, value)
        return prompt

    # ---------------------------------------------------------------- analysis

    def analyze(
        self,
        stats: List[Dict],
        rss_stats: Optional[List[Dict]] = None,
        report_mode: str = "daily",
        report_type: str = "当日汇总",
        platforms: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
    ) -> "AIAnalysisResult":
        """
        Run the AI analysis.

        Args:
            stats: Hotlist statistics; each entry has a ``word`` and a list of
                ``titles`` dicts.
            rss_stats: RSS statistics in the same shape, or None.
            report_mode: Value substituted for ``{report_mode}``.
            report_type: Value substituted for ``{report_type}``.
            platforms: Platform names for ``{platforms}``; "多平台" when empty.
            keywords: Keywords for ``{keywords}`` (first 20 used); derived from
                ``stats[*]["word"]`` when not given.

        Returns:
            AIAnalysisResult. Failures (no API key, nothing to analyze, empty
            prompt template, API/network errors) are reported with
            ``success=False`` and a message in ``error``; they never raise.
            Once news has been prepared, the news counters are filled in on
            both success and failure.
        """
        if not self.api_key:
            return AIAnalysisResult(
                success=False,
                error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置"
            )

        news_content, hotlist_total, rss_total, analyzed_count = self._prepare_news_content(
            stats, rss_stats
        )
        counters = dict(
            total_news=hotlist_total + rss_total,
            hotlist_count=hotlist_total,
            rss_count=rss_total,
            max_news_limit=self.max_news,
        )

        if not news_content:
            return AIAnalysisResult(
                success=False,
                error="没有可分析的新闻内容",
                analyzed_news=0,
                **counters,
            )

        # Without a template the news would never reach the model.
        if not self.user_prompt_template:
            return AIAnalysisResult(
                success=False,
                error="AI 提示词模板为空,请检查 PROMPT_FILE 配置的提示词文件",
                analyzed_news=0,
                **counters,
            )

        if not keywords:
            keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []

        user_prompt = self._build_user_prompt(
            news_content, hotlist_total, rss_total,
            report_mode, report_type, platforms, keywords,
        )

        try:
            response = self._call_ai_api(user_prompt)
            result = self._parse_response(response)
        except Exception as e:  # top-level boundary: report, never propagate
            return AIAnalysisResult(
                success=False,
                error=self._describe_error(e),
                analyzed_news=0,
                **counters,
            )

        result.total_news = counters["total_news"]
        result.hotlist_count = hotlist_total
        result.rss_count = rss_total
        result.analyzed_news = analyzed_count
        result.max_news_limit = self.max_news
        return result

    def _describe_error(self, e: Exception) -> str:
        """Turn an exception raised while calling the AI API into a user-facing message.

        The configured API key is redacted from the exception text before any
        part of it is included in the message.
        """
        try:
            import requests
        except ImportError:  # requests missing: report the failure generically
            requests = None

        error_type = type(e).__name__
        error_msg = str(e)
        if self.api_key:
            # Exception texts may echo URLs/headers; never forward the key.
            error_msg = error_msg.replace(str(self.api_key), "***")

        if requests is not None:
            if isinstance(e, requests.exceptions.Timeout):
                return f"AI API 请求超时({self.timeout}秒),请检查网络或增加超时时间"
            if isinstance(e, requests.exceptions.ConnectionError):
                return f"无法连接到 AI API ({self.base_url or self.provider}),请检查网络和 API 地址"
            if isinstance(e, requests.exceptions.HTTPError):
                response = getattr(e, "response", None)
                # Response.__bool__ is Response.ok, so an error response is
                # falsy: compare with None instead of testing truthiness.
                status_code = response.status_code if response is not None else "未知"
                if status_code == 401:
                    return "AI API 认证失败,请检查 API Key 是否正确"
                if status_code == 429:
                    return "AI API 请求频率过高,请稍后重试"
                if status_code == 500:
                    return "AI API 服务器内部错误,请稍后重试"
                return f"AI API 返回错误 (HTTP {status_code}): {error_msg[:100]}"

        if len(error_msg) > 150:
            error_msg = error_msg[:150] + "..."
        return f"AI 分析失败 ({error_type}): {error_msg}"

    # ------------------------------------------------------------ news content

    def _prepare_news_content(
        self,
        stats: List[Dict],
        rss_stats: Optional[List[Dict]] = None,
    ) -> tuple:
        """
        Render the news items into the text block sent to the model.

        Hotlist lines carry source, title, rank range, time range and
        occurrence count; RSS lines carry source, title and publish time. RSS
        is included only when ``include_rss`` is enabled and the
        ``max_news`` budget is not yet exhausted.

        Returns:
            tuple: (content_str, hotlist_total, rss_total, analyzed_count).
            The totals count every title entry in the input; ``content_str``
            is "" when no usable item was rendered.
        """
        hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
        rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0

        lines: List[str] = []
        count = 0

        if stats:
            lines.append("### 热榜新闻")
            lines.append("格式: [来源] 标题 | 排名:最高-最低 | 时间:首次~末次 | 出现:N次")
            count = self._append_news_groups(lines, stats, self._format_hotlist_line, count)

        if self.include_rss and rss_stats and count < self.max_news:
            lines.append("\n### RSS 订阅")
            lines.append("格式: [来源] 标题 | 发布时间")
            count = self._append_news_groups(lines, rss_stats, self._format_rss_line, count)

        # Section headers alone are not news; report "nothing to analyze".
        if count == 0:
            return "", hotlist_total, rss_total, 0
        return "\n".join(lines), hotlist_total, rss_total, count

    def _append_news_groups(
        self,
        lines: List[str],
        groups: List[Dict],
        format_line: Callable[[Any], Optional[str]],
        count: int,
    ) -> int:
        """Append keyword groups and their item lines to ``lines``.

        Items for which ``format_line`` returns None are skipped and not
        counted. Stops before ``count`` would exceed ``self.max_news`` and
        returns the updated count.
        """
        for group in groups:
            if count >= self.max_news:
                break
            word = group.get("word", "")
            titles = group.get("titles", [])
            if not word or not titles:
                continue
            lines.append(f"\n**{word}** ({len(titles)}条)")
            for item in titles:
                if count >= self.max_news:
                    break
                line = format_line(item)
                if line is None:
                    continue
                lines.append(line)
                count += 1
        return count

    def _format_hotlist_line(self, item: Any) -> Optional[str]:
        """Format one hotlist item, or return None when it is not a dict or has no title.

        Layout: ``- [source] title | 排名:min-max | 时间:first~last | 出现:N次``.
        """
        if not isinstance(item, dict):
            return None
        title = item.get("title", "")
        if not title:
            return None

        source = item.get("source_name", item.get("source", ""))
        ranks = item.get("ranks", [])
        if ranks:
            best, worst = min(ranks), max(ranks)
            rank_str = f"{best}" if best == worst else f"{best}-{worst}"
        else:
            rank_str = "-"
        time_str = self._format_time_range(item.get("first_time", ""), item.get("last_time", ""))
        appear_count = item.get("count", 1)

        head = f"- [{source}] {title}" if source else f"- {title}"
        return f"{head} | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"

    @staticmethod
    def _format_rss_line(item: Any) -> Optional[str]:
        """Format one RSS item, or return None when it is not a dict or has no title.

        Layout: ``- [source] title | time_display`` (time part omitted when empty).
        """
        if not isinstance(item, dict):
            return None
        title = item.get("title", "")
        if not title:
            return None

        source = item.get("source_name", item.get("feed_name", ""))
        time_display = item.get("time_display", "")
        line = f"- [{source}] {title}" if source else f"- {title}"
        if time_display:
            line += f" | {time_display}"
        return line

    def _format_time_range(self, first_time: str, last_time: str) -> str:
        """Render a ``first~last`` time range, keeping only HH:MM of each end.

        Collapses to a single value when both ends are equal or ``last`` is
        missing; a missing end is shown as "-".
        """

        def extract_time(time_str: str) -> str:
            if not time_str:
                return "-"
            # Inputs look like "2026-01-04 12:30:00" or "12:30".
            if " " in time_str:
                time_part = time_str.split(" ")[1]
                if ":" in time_part:
                    return time_part[:5]
            return time_str[:5]

        first = extract_time(first_time)
        last = extract_time(last_time)
        if first == last or last == "-":
            return first
        return f"{first}~{last}"

    # ------------------------------------------------------------- API calls

    def _call_ai_api(self, user_prompt: str) -> str:
        """Dispatch to the Gemini client or the OpenAI-compatible client."""
        if self.provider == "gemini":
            return self._call_gemini(user_prompt)
        return self._call_openai_compatible(user_prompt)

    def _get_api_url(self) -> str:
        """Return the full chat-completions URL.

        ``base_url`` takes precedence; otherwise a preset endpoint is used for
        known providers. Raises ValueError for other providers without one.
        """
        if self.base_url:
            return self.base_url

        preset_urls = {
            "deepseek": "https://api.deepseek.com/v1/chat/completions",
            "openai": "https://api.openai.com/v1/chat/completions",
        }
        url = preset_urls.get(self.provider)
        if not url:
            raise ValueError(f"{self.provider} 需要配置 base_url(完整 API 地址)")
        return url

    def _call_openai_compatible(self, user_prompt: str) -> str:
        """POST a chat-completions request and return the first choice's content.

        Raises requests exceptions on network/HTTP errors (via
        ``raise_for_status``) and KeyError/IndexError on unexpected payloads.
        """
        import requests

        url = self._get_api_url()
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_prompt})

        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 2000,
        }
        response = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        return data["choices"][0]["message"]["content"]

    def _call_gemini(self, user_prompt: str) -> str:
        """Call Gemini ``generateContent`` and return the first candidate's text.

        The system prompt is sent as a leading user/model exchange.
        """
        import requests

        model = self.model or "gemini-1.5-flash"
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            # Key goes in a header, not a ?key= query parameter: requests'
            # exception messages include the request URL, which would expose it.
            "x-goog-api-key": self.api_key,
        }

        contents = []
        if self.system_prompt:
            contents.append({
                "role": "user",
                "parts": [{"text": f"System instruction: {self.system_prompt}"}]
            })
            contents.append({
                "role": "model",
                "parts": [{"text": "Understood. I will follow these instructions."}]
            })
        contents.append({
            "role": "user",
            "parts": [{"text": user_prompt}]
        })

        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": 0.7,
                "maxOutputTokens": 2000,
            }
        }
        response = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
        response.raise_for_status()
        data = response.json()
        return data["candidates"][0]["content"]["parts"][0]["text"]

    # ---------------------------------------------------------------- parsing

    @staticmethod
    def _extract_json_text(response: str) -> str:
        """Pull the JSON payload out of a model reply.

        Prefers a ```json fenced block (an unterminated fence uses the rest of
        the text), then any ``` fenced block, then the whole reply. If the
        stripped result does not start with ``{`` or ``[``, falls back to the
        outermost ``{...}`` span, which covers prose-wrapped objects and fences
        tagged with other labels (e.g. ```JSON).
        """
        text = response
        if "```json" in response:
            block = response.split("```json", 1)[1]
            end = block.find("```")
            text = block if end == -1 else block[:end]
        elif "```" in response:
            text = response.split("```", 2)[1]
        text = text.strip()

        if text and not text.startswith(("{", "[")):
            start, end = text.find("{"), text.rfind("}")
            if start != -1 and end > start:
                text = text[start:end + 1]
        return text

    def _parse_response(self, response: str) -> "AIAnalysisResult":
        """Parse the model reply into an ``AIAnalysisResult``.

        Expects a JSON object with the keys summary, keyword_analysis,
        sentiment, cross_platform, impact, signals and conclusion. An empty
        reply yields ``success=False``. Any parse failure still yields
        ``success=True`` (there is content to show): ``summary`` holds the
        first 1000 characters of the raw reply and ``error`` describes the
        problem.
        """
        result = AIAnalysisResult(raw_response=response)
        if not response or not response.strip():
            result.error = "AI 返回空响应"
            return result

        fallback_summary = response[:1000]
        json_str = ""
        try:
            json_str = self._extract_json_text(response)
            if not json_str:
                raise ValueError("提取的 JSON 内容为空")
            data = json.loads(json_str)
            for key in (
                "summary", "keyword_analysis", "sentiment", "cross_platform",
                "impact", "signals", "conclusion",
            ):
                setattr(result, key, data.get(key, ""))
            result.success = True
        except json.JSONDecodeError as e:
            result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
            # e.pos may legitimately be 0, so do not test it for truthiness.
            error_context = json_str[max(0, e.pos - 30):e.pos + 30]
            if error_context:
                result.error += f",上下文: ...{error_context}..."
            result.summary = fallback_summary
            result.success = True
        except (IndexError, KeyError, TypeError, ValueError) as e:
            result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
            result.summary = fallback_summary
            result.success = True
        except Exception as e:  # e.g. AttributeError when the JSON is not an object
            result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
            result.summary = fallback_summary
            result.success = True
        return result
|