| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451 |
- # coding=utf-8
- """
- AI 分析器模块
- 调用 AI 大模型对热点新闻进行深度分析
- 基于 LiteLLM 统一接口,支持 100+ AI 提供商
- """
- import json
- from dataclasses import dataclass
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional
- from trendradar.ai.client import AIClient
@dataclass
class AIAnalysisResult:
    """Container for the outcome of one AI analysis run.

    Every field has a default, so an empty instance represents
    "nothing analysed yet / failed" (``success`` is ``False``).
    """

    # --- Five core report sections (current response format) ---
    # Core hot topics and the overall public-opinion situation.
    core_trends: str = ""
    # Direction of public sentiment and points of controversy.
    sentiment_controversy: str = ""
    # Anomalies and weak signals.
    signals: str = ""
    # In-depth insights drawn from RSS items.
    rss_insights: str = ""
    # Assessment and strategy recommendations.
    outlook_strategy: str = ""

    # --- Basic metadata ---
    # Raw text returned by the model.
    raw_response: str = ""
    # Whether the analysis is considered successful.
    success: bool = False
    # Error description (empty when there is none).
    error: str = ""

    # --- News count statistics ---
    # Total number of news items (hot list + RSS).
    total_news: int = 0
    # Number of items actually included in the analysis.
    analyzed_news: int = 0
    # Configured upper bound on analysed items.
    max_news_limit: int = 0
    # Number of hot-list items.
    hotlist_count: int = 0
    # Number of RSS items.
    rss_count: int = 0
class AIAnalyzer:
    """Send hot-list / RSS statistics to an AI model and parse its report.

    The analyzer builds a text digest of the news items, fills a prompt
    template, calls the model through ``AIClient`` and converts the JSON
    answer into an ``AIAnalysisResult``.
    """

    def __init__(
        self,
        ai_config: Dict[str, Any],
        analysis_config: Dict[str, Any],
        get_time_func: Callable,
        debug: bool = False,
    ):
        """Initialise the analyzer.

        Args:
            ai_config: AI model configuration, passed to ``AIClient``.
            analysis_config: Feature configuration. Keys read here:
                ``MAX_NEWS_FOR_ANALYSIS`` (default 50), ``INCLUDE_RSS``
                (default True), ``INCLUDE_RANK_TIMELINE`` (default False),
                ``LANGUAGE`` (default "Chinese") and ``PROMPT_FILE``
                (default "ai_analysis_prompt.txt").
            get_time_func: Callable returning the current time; its result
                must support ``strftime``.
            debug: When True, the full prompt is printed before the call.
        """
        self.ai_config = ai_config
        self.analysis_config = analysis_config
        self.get_time_func = get_time_func
        self.debug = debug

        # Create the AI client and validate its configuration. An invalid
        # configuration only produces a warning; it is not fatal here.
        self.client = AIClient(ai_config)
        valid, error = self.client.validate_config()
        if not valid:
            print(f"[AI] 配置警告: {error}")

        # Feature parameters from the analysis configuration.
        self.max_news = analysis_config.get("MAX_NEWS_FOR_ANALYSIS", 50)
        self.include_rss = analysis_config.get("INCLUDE_RSS", True)
        self.include_rank_timeline = analysis_config.get("INCLUDE_RANK_TIMELINE", False)
        self.language = analysis_config.get("LANGUAGE", "Chinese")

        # Load the prompt template (system part + user part).
        self.system_prompt, self.user_prompt_template = self._load_prompt_template(
            analysis_config.get("PROMPT_FILE", "ai_analysis_prompt.txt")
        )

    def _load_prompt_template(self, prompt_file: str) -> tuple:
        """Load the prompt file and split it into system and user prompts.

        The file is looked up in the ``config`` directory three levels above
        this module.

        Args:
            prompt_file: File name inside the ``config`` directory.

        Returns:
            tuple: ``(system_prompt, user_prompt)``; ``("", "")`` when the
            file does not exist (a message is printed in that case).
        """
        config_dir = Path(__file__).parent.parent.parent / "config"
        prompt_path = config_dir / prompt_file
        if not prompt_path.exists():
            print(f"[AI] 提示词文件不存在: {prompt_path}")
            return "", ""
        return self._split_prompt_sections(prompt_path.read_text(encoding="utf-8"))

    @staticmethod
    def _split_prompt_sections(content: str) -> tuple:
        """Split template text into ``(system_prompt, user_prompt)``.

        When both ``[system]`` and ``[user]`` markers are present, the text
        is split at the FIRST ``[user]`` marker only, so a later literal
        ``[user]`` inside the user prompt is preserved. Without both
        markers the whole text is returned, unstripped, as the user prompt.
        """
        if "[system]" not in content or "[user]" not in content:
            return "", content

        head, _, user_part = content.partition("[user]")
        system_prompt = ""
        # "[system]" may appear only after "[user]"; then there is no
        # system prompt and the marker stays part of the user text.
        if "[system]" in head:
            system_prompt = head.split("[system]", 1)[1].strip()
        return system_prompt, user_part.strip()

    @staticmethod
    def _render_prompt(template: str, values: Dict[str, str]) -> str:
        """Replace ``{name}`` placeholders in a single pass.

        Only placeholders whose name is a key of ``values`` are touched, so
        other braces in the template (e.g. a JSON example) stay as they
        are. Because the template is scanned once, placeholder-like text
        inside an inserted value is never substituted again.
        """
        import re  # local import: not among the module-level imports

        if not values:
            return template
        pattern = re.compile("|".join(re.escape("{" + key + "}") for key in values))
        # A function replacement is inserted literally (no backslash
        # escape processing), which is required for arbitrary news text.
        return pattern.sub(lambda match: values[match.group(0)[1:-1]], template)

    def analyze(
        self,
        stats: List[Dict],
        rss_stats: Optional[List[Dict]] = None,
        report_mode: str = "daily",
        report_type: str = "当日汇总",
        platforms: Optional[List[str]] = None,
        keywords: Optional[List[str]] = None,
    ) -> "AIAnalysisResult":
        """Run the AI analysis.

        Args:
            stats: Hot-list statistics; each entry is a dict with ``word``
                and ``titles``.
            rss_stats: RSS statistics in the same layout.
            report_mode: Report mode inserted into the prompt.
            report_type: Report type inserted into the prompt.
            platforms: Platform names; "多平台" is used when empty.
            keywords: Keywords; derived from ``stats`` when empty. At most
                the first 20 are inserted into the prompt.

        Returns:
            AIAnalysisResult: The parsed result. Missing API key, missing
            content, or an exception from the AI call / response parsing
            are reported through ``success=False`` and ``error`` rather
            than raised.
        """
        if not self.client.api_key:
            return AIAnalysisResult(
                success=False,
                error="未配置 AI API Key,请在 config.yaml 或环境变量 AI_API_KEY 中设置",
            )

        # Build the news digest and collect the counters.
        (
            news_content,
            rss_content,
            hotlist_total,
            rss_total,
            analyzed_count,
        ) = self._prepare_news_content(stats, rss_stats)
        total_news = hotlist_total + rss_total

        # Counters shared by every result produced from here on.
        counters = {
            "total_news": total_news,
            "hotlist_count": hotlist_total,
            "rss_count": rss_total,
            "max_news_limit": self.max_news,
        }

        if not news_content and not rss_content:
            return AIAnalysisResult(
                success=False,
                error="没有可分析的新闻内容",
                analyzed_news=0,
                **counters,
            )

        current_time = self.get_time_func().strftime("%Y-%m-%d %H:%M:%S")

        # Fall back to the words of the hot-list groups as keywords.
        if not keywords:
            keywords = [s.get("word", "") for s in stats if s.get("word")] if stats else []

        # Single-pass substitution: unknown braces in the template are left
        # alone and inserted news text is never re-scanned for placeholders.
        user_prompt = self._render_prompt(
            self.user_prompt_template,
            {
                "report_mode": report_mode,
                "report_type": report_type,
                "current_time": current_time,
                "news_count": str(hotlist_total),
                "rss_count": str(rss_total),
                "platforms": ", ".join(platforms) if platforms else "多平台",
                "keywords": ", ".join(keywords[:20]) if keywords else "无",
                "news_content": news_content,
                "rss_content": rss_content,
                "language": self.language,
            },
        )

        if self.debug:
            print("\n" + "=" * 80)
            print("[AI 调试] 发送给 AI 的完整提示词")
            print("=" * 80)
            if self.system_prompt:
                print("\n--- System Prompt ---")
                print(self.system_prompt)
            print("\n--- User Prompt ---")
            print(user_prompt)
            print("=" * 80 + "\n")

        # Top-level boundary: any failure of the call or the parsing is
        # converted into a failed result instead of propagating.
        try:
            response = self._call_ai(user_prompt)
            result = self._parse_response(response)
        except Exception as e:
            error_type = type(e).__name__
            error_msg = str(e)
            # Truncate overly long error messages.
            if len(error_msg) > 200:
                error_msg = error_msg[:200] + "..."
            return AIAnalysisResult(
                success=False,
                error=f"AI 分析失败 ({error_type}): {error_msg}",
                **counters,
            )

        # When RSS analysis is disabled, drop any RSS insight the model
        # produced anyway.
        if not self.include_rss:
            result.rss_insights = ""

        result.total_news = total_news
        result.hotlist_count = hotlist_total
        result.rss_count = rss_total
        result.analyzed_news = analyzed_count
        result.max_news_limit = self.max_news
        return result

    def _prepare_news_content(
        self,
        stats: List[Dict],
        rss_stats: Optional[List[Dict]] = None,
    ) -> tuple:
        """Build the text digests that are inserted into the prompt.

        Hot-list lines contain source, title, rank range, time range and
        appearance count (plus the rank trajectory when
        ``include_rank_timeline`` is set). RSS lines contain source, title
        and publication time. Hot-list items are taken first; RSS items
        fill whatever is left of ``max_news``. A non-positive ``max_news``
        yields no items at all.

        Returns:
            tuple: ``(news_content, rss_content, hotlist_total, rss_total,
            analyzed_count)`` where the totals count all titles in the
            input and ``analyzed_count`` counts the lines actually emitted.
        """
        # Totals are computed over the full input, independent of limits.
        hotlist_total = sum(len(s.get("titles", [])) for s in stats) if stats else 0
        rss_total = sum(len(s.get("titles", [])) for s in rss_stats) if rss_stats else 0

        news_lines: List[str] = []
        news_count = 0
        for stat in stats or []:
            # Check the limit before emitting a group header so that the
            # hot-list branch behaves like the RSS branch below.
            if news_count >= self.max_news:
                break
            word = stat.get("word", "")
            titles = stat.get("titles", [])
            if not (word and titles):
                continue
            news_lines.append(f"\n**{word}** ({len(titles)}条)")
            for item in titles:
                if not isinstance(item, dict) or not item.get("title", ""):
                    continue
                news_lines.append(self._format_hotlist_line(item))
                news_count += 1
                if news_count >= self.max_news:
                    break

        rss_lines: List[str] = []
        rss_count = 0
        # RSS content is only built when enabled.
        if self.include_rss and rss_stats:
            remaining = self.max_news - news_count
            for stat in rss_stats:
                if rss_count >= remaining:
                    break
                word = stat.get("word", "")
                titles = stat.get("titles", [])
                if not (word and titles):
                    continue
                rss_lines.append(f"\n**{word}** ({len(titles)}条)")
                for item in titles:
                    if not isinstance(item, dict) or not item.get("title", ""):
                        continue
                    rss_lines.append(self._format_rss_line(item))
                    rss_count += 1
                    if rss_count >= remaining:
                        break

        news_content = "\n".join(news_lines)
        rss_content = "\n".join(rss_lines)
        return news_content, rss_content, hotlist_total, rss_total, news_count + rss_count

    def _format_hotlist_line(self, item: Dict) -> str:
        """Format one hot-list title dict as a single digest line."""
        title = item.get("title", "")
        source = item.get("source_name", item.get("source", ""))
        line = f"- [{source}] {title}" if source else f"- {title}"

        # Always shown: rank range + time range + appearance count.
        ranks = item.get("ranks", [])
        if ranks:
            lowest, highest = min(ranks), max(ranks)
            rank_str = f"{lowest}" if lowest == highest else f"{lowest}-{highest}"
        else:
            rank_str = "-"
        time_str = self._format_time_range(
            item.get("first_time", ""), item.get("last_time", "")
        )
        appear_count = item.get("count", 1)
        line += f" | 排名:{rank_str} | 时间:{time_str} | 出现:{appear_count}次"

        # Optional full trajectory.
        if self.include_rank_timeline:
            timeline_str = self._format_rank_timeline(item.get("rank_timeline", []))
            line += f" | 轨迹:{timeline_str}"
        return line

    @staticmethod
    def _format_rss_line(item: Dict) -> str:
        """Format one RSS title dict as ``- [source] title | time``."""
        title = item.get("title", "")
        source = item.get("source_name", item.get("feed_name", ""))
        line = f"- [{source}] {title}" if source else f"- {title}"
        time_display = item.get("time_display", "")
        if time_display:
            line += f" | {time_display}"
        return line

    def _call_ai(self, user_prompt: str) -> str:
        """Send the prompt to the AI client and return its reply.

        The system prompt is included only when it is non-empty.
        """
        messages = []
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})
        messages.append({"role": "user", "content": user_prompt})
        return self.client.chat(messages)

    def _format_time_range(self, first_time: str, last_time: str) -> str:
        """Format a first/last time pair compactly, keeping only ``HH:MM``.

        Returns the single time when both ends are equal or the end is
        missing, otherwise ``"first~last"``. A missing time becomes ``"-"``.
        """

        def extract_time(time_str: str) -> str:
            if not time_str:
                return "-"
            if " " in time_str:
                # "<date> <time>": look at the part after the first space.
                time_part = time_str.split(" ")[1]
                if ":" in time_part:
                    return time_part[:5]  # HH:MM
                # "<date> HH-MM": convert the time part, not the date.
                if len(time_part) >= 5 and time_part[2] == "-":
                    return time_part[:5].replace("-", ":")
            elif ":" in time_str:
                return time_str[:5]
            # Bare "HH-MM" (or anything else): keep at most five characters.
            result = time_str[:5] if len(time_str) >= 5 else time_str
            if len(result) == 5 and result[2] == "-":
                result = result.replace("-", ":")
            return result

        first = extract_time(first_time)
        last = extract_time(last_time)
        if first == last or last == "-":
            return first
        return f"{first}~{last}"

    def _format_rank_timeline(self, rank_timeline: List[Dict]) -> str:
        """Format a rank trajectory as ``rank(time)→rank(time)…``.

        A ``None`` rank is rendered as ``0``. Entries that are not dicts are
        skipped. Returns ``"-"`` when there is nothing to show.
        """
        if not rank_timeline:
            return "-"

        parts = []
        for entry in rank_timeline:
            if not isinstance(entry, dict):
                continue
            time_str = entry.get("time", "")
            if len(time_str) == 5 and time_str[2] == "-":
                time_str = time_str.replace("-", ":")
            rank = entry.get("rank")
            parts.append(f"{0 if rank is None else rank}({time_str})")
        return "→".join(parts) if parts else "-"

    @staticmethod
    def _extract_json_text(response: str) -> str:
        """Return the stripped JSON candidate found in a model reply.

        Prefers a ```` ```json ```` fenced block, then the first generic
        fenced block, and finally the whole reply.
        """
        if "```json" in response:
            code_block = response.split("```json", 1)[1]
            end_idx = code_block.find("```")
            # An unterminated fence keeps everything after the opener.
            text = code_block[:end_idx] if end_idx != -1 else code_block
        elif "```" in response:
            text = response.split("```", 2)[1]
        else:
            text = response
        return text.strip()

    @staticmethod
    def _as_text(value: Any) -> str:
        """Coerce a decoded JSON value to the ``str`` the result fields hold.

        ``None`` becomes ``""``, strings pass through, anything else (list,
        dict, number, bool) is re-serialised as JSON text.
        """
        if value is None:
            return ""
        if isinstance(value, str):
            return value
        return json.dumps(value, ensure_ascii=False)

    def _parse_response(self, response: str) -> "AIAnalysisResult":
        """Parse the model reply into an ``AIAnalysisResult``.

        An empty reply yields a failed result. When the reply cannot be
        parsed, ``error`` describes the problem but the result is still
        marked successful with (at most 500 characters of) the raw reply in
        ``core_trends``, so that there is always something to output.
        """
        result = AIAnalysisResult(raw_response=response)
        if not response or not response.strip():
            result.error = "AI 返回空响应"
            return result

        json_str = response
        try:
            json_str = self._extract_json_text(response)
            if not json_str:
                raise ValueError("提取的 JSON 内容为空")
            data = json.loads(json_str)

            # Current-format fields.
            result.core_trends = self._as_text(data.get("core_trends", ""))
            result.sentiment_controversy = self._as_text(data.get("sentiment_controversy", ""))
            result.signals = self._as_text(data.get("signals", ""))
            result.rss_insights = self._as_text(data.get("rss_insights", ""))
            result.outlook_strategy = self._as_text(data.get("outlook_strategy", ""))
            result.success = True
        except json.JSONDecodeError as e:
            error_context = json_str[max(0, e.pos - 30):e.pos + 30] if json_str and e.pos else ""
            result.error = f"JSON 解析错误 (位置 {e.pos}): {e.msg}"
            if error_context:
                result.error += f",上下文: ...{error_context}..."
            # Fall back to the raw reply so there is still some output.
            result.core_trends = response[:500] + "..." if len(response) > 500 else response
            result.success = True
        except (IndexError, KeyError, TypeError, ValueError) as e:
            result.error = f"响应解析错误: {type(e).__name__}: {str(e)}"
            result.core_trends = response[:500] if len(response) > 500 else response
            result.success = True
        except Exception as e:
            # E.g. valid JSON that is not an object (no ``.get``).
            result.error = f"解析时发生未知错误: {type(e).__name__}: {str(e)}"
            result.core_trends = response[:500] if len(response) > 500 else response
            result.success = True
        return result
|