filter.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. # coding=utf-8
  2. """
  3. AI 智能筛选模块
  4. 通过 AI 对新闻进行标签分类:
  5. 1. 阶段 A:从用户兴趣描述中提取结构化标签
  6. 2. 阶段 B:对新闻标题按标签进行批量分类
  7. """
import hashlib
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from trendradar.ai.client import AIClient
  14. @dataclass
  15. class AIFilterResult:
  16. """AI 筛选结果,传给报告和通知模块"""
  17. tags: List[Dict] = field(default_factory=list)
  18. # [{"tag": str, "description": str, "count": int, "items": [
  19. # {"title": str, "source_id": str, "source_name": str,
  20. # "url": str, "mobile_url": str, "rank": int, "ranks": [...],
  21. # "first_time": str, "last_time": str, "count": int,
  22. # "relevance_score": float, "source_type": str}
  23. # ]}]
  24. total_matched: int = 0 # 匹配新闻总数
  25. total_processed: int = 0 # 处理新闻总数
  26. success: bool = False
  27. error: str = ""
  28. class AIFilter:
  29. """AI 智能筛选器"""
  30. def __init__(
  31. self,
  32. ai_config: Dict[str, Any],
  33. filter_config: Dict[str, Any],
  34. get_time_func: Callable,
  35. debug: bool = False,
  36. ):
  37. self.client = AIClient(ai_config)
  38. self.filter_config = filter_config
  39. self.batch_size = filter_config.get("BATCH_SIZE", 200)
  40. self.get_time_func = get_time_func
  41. self.debug = debug
  42. # 加载提示词模板
  43. self.classify_system, self.classify_user = self._load_prompt(
  44. filter_config.get("PROMPT_FILE", "ai_filter_prompt.txt")
  45. )
  46. self.extract_system, self.extract_user = self._load_prompt(
  47. filter_config.get("EXTRACT_PROMPT_FILE", "ai_filter_extract_prompt.txt")
  48. )
  49. self.update_tags_system, self.update_tags_user = self._load_prompt(
  50. filter_config.get("UPDATE_TAGS_PROMPT_FILE", "update_tags_prompt.txt")
  51. )
  52. def _load_prompt(self, filename: str) -> tuple:
  53. """加载提示词文件,返回 (system_prompt, user_prompt_template)"""
  54. config_dir = Path(__file__).parent.parent.parent / "config" / "ai_filter"
  55. prompt_path = config_dir / filename
  56. if not prompt_path.exists():
  57. print(f"[AI筛选] 提示词文件不存在: {prompt_path}")
  58. return "", ""
  59. content = prompt_path.read_text(encoding="utf-8")
  60. system_prompt = ""
  61. user_prompt = ""
  62. if "[system]" in content and "[user]" in content:
  63. parts = content.split("[user]")
  64. system_part = parts[0]
  65. user_part = parts[1] if len(parts) > 1 else ""
  66. if "[system]" in system_part:
  67. system_prompt = system_part.split("[system]")[1].strip()
  68. user_prompt = user_part.strip()
  69. else:
  70. user_prompt = content
  71. return system_prompt, user_prompt
  72. def compute_interests_hash(self, interests_content: str, filename: str = "ai_interests.txt") -> str:
  73. """计算兴趣描述的 hash,格式为 filename:md5"""
  74. # 去除前后空白和注释行,确保内容变化才改变 hash
  75. lines = []
  76. for line in interests_content.strip().splitlines():
  77. line = line.strip()
  78. if line and not line.startswith("#"):
  79. lines.append(line)
  80. normalized = "\n".join(lines)
  81. content_hash = hashlib.md5(normalized.encode("utf-8")).hexdigest()
  82. return f"{filename}:{content_hash}"
  83. def load_interests_content(self, interests_file: Optional[str] = None) -> Optional[str]:
  84. """加载兴趣描述文件内容
  85. 解析逻辑:
  86. - interests_file 为 None:使用默认 config/ai_interests.txt
  87. - interests_file 有值:仅查 config/custom/ai/{filename}
  88. 注意:调用方(context.py)已完成 config/timeline 的合并决策,
  89. 此处不再二次读取 filter_config,避免语义冲突。
  90. """
  91. config_dir = Path(__file__).parent.parent.parent / "config"
  92. configured_file = interests_file
  93. if configured_file:
  94. # 自定义兴趣文件:仅查 custom/ai 目录
  95. filename = configured_file
  96. interests_path = config_dir / "custom" / "ai" / filename
  97. if not interests_path.exists():
  98. print(f"[AI筛选] 自定义兴趣描述文件不存在: {filename}")
  99. print(f"[AI筛选] 已查找: {interests_path}")
  100. return None
  101. else:
  102. # 默认兴趣文件:固定使用 config/ai_interests.txt
  103. filename = "ai_interests.txt"
  104. interests_path = config_dir / filename
  105. if not interests_path.exists():
  106. print(f"[AI筛选] 默认兴趣描述文件不存在: {filename}")
  107. print(f"[AI筛选] 已查找: {interests_path}")
  108. return None
  109. if not interests_path.exists():
  110. print(f"[AI筛选] 兴趣描述文件不存在: {interests_path}")
  111. return None
  112. content = interests_path.read_text(encoding="utf-8").strip()
  113. if not content:
  114. print("[AI筛选] 兴趣描述文件为空")
  115. return None
  116. return content
  117. def extract_tags(self, interests_content: str) -> List[Dict]:
  118. """
  119. 阶段 A:从兴趣描述中提取结构化标签
  120. Args:
  121. interests_content: 用户的兴趣描述文本
  122. Returns:
  123. [{"tag": str, "description": str}, ...]
  124. """
  125. if not self.extract_user:
  126. print("[AI筛选] 标签提取提示词模板为空")
  127. return []
  128. user_prompt = self.extract_user.replace("{interests_content}", interests_content)
  129. messages = []
  130. if self.extract_system:
  131. messages.append({"role": "system", "content": self.extract_system})
  132. messages.append({"role": "user", "content": user_prompt})
  133. if self.debug:
  134. print(f"\n[AI筛选][DEBUG] === 标签提取 Prompt ===")
  135. for m in messages:
  136. print(f"[{m['role']}]\n{m['content']}")
  137. print(f"[AI筛选][DEBUG] === Prompt 结束 ===")
  138. try:
  139. response = self.client.chat(messages)
  140. if self.debug:
  141. print(f"\n[AI筛选][DEBUG] === 标签提取 AI 原始响应 ===")
  142. # 尝试格式化 JSON 便于阅读
  143. self._print_formatted_json(response)
  144. print(f"[AI筛选][DEBUG] === 响应结束 ===")
  145. tags = self._parse_tags_response(response)
  146. print(f"[AI筛选] 提取到 {len(tags)} 个标签")
  147. for t in tags:
  148. print(f" {t['tag']}: {t.get('description', '')}")
  149. if self.debug:
  150. json_str = self._extract_json(response)
  151. if not json_str:
  152. print(f"[AI筛选][DEBUG] 无法从响应中提取 JSON")
  153. else:
  154. raw_data = json.loads(json_str)
  155. raw_tags = raw_data.get("tags", [])
  156. skipped = len(raw_tags) - len(tags)
  157. if skipped > 0:
  158. print(f"[AI筛选][DEBUG] 原始标签 {len(raw_tags)} 个,有效 {len(tags)} 个,跳过 {skipped} 个(缺少 tag 字段或格式无效)")
  159. return tags
  160. except json.JSONDecodeError as e:
  161. print(f"[AI筛选] 标签提取失败: JSON 解析错误: {e}")
  162. if self.debug:
  163. print(f"[AI筛选][DEBUG] 尝试解析的 JSON 内容: {self._extract_json(response) if response else '(空响应)'}")
  164. return []
  165. except Exception as e:
  166. print(f"[AI筛选] 标签提取失败: {type(e).__name__}: {e}")
  167. return []
  168. def update_tags(self, old_tags: List[Dict], interests_content: str) -> Optional[Dict]:
  169. """
  170. 阶段 A':AI 对比旧标签和新兴趣描述,给出更新方案
  171. Args:
  172. old_tags: [{"tag": str, "description": str, "id": int}, ...]
  173. interests_content: 新的兴趣描述文本
  174. Returns:
  175. {"keep": [{"tag": str, "description": str}],
  176. "add": [{"tag": str, "description": str}],
  177. "remove": [str],
  178. "change_ratio": float}
  179. 失败返回 None
  180. """
  181. if not self.update_tags_user:
  182. print("[AI筛选] 标签更新提示词模板为空,回退到重新提取")
  183. return None
  184. # 构造旧标签 JSON
  185. old_tags_json = json.dumps(
  186. [{"tag": t["tag"], "description": t.get("description", "")} for t in old_tags],
  187. ensure_ascii=False, indent=2
  188. )
  189. user_prompt = self.update_tags_user.replace(
  190. "{old_tags_json}", old_tags_json
  191. ).replace(
  192. "{interests_content}", interests_content
  193. )
  194. messages = []
  195. if self.update_tags_system:
  196. messages.append({"role": "system", "content": self.update_tags_system})
  197. messages.append({"role": "user", "content": user_prompt})
  198. if self.debug:
  199. print(f"\n[AI筛选][DEBUG] === 标签更新 Prompt ===")
  200. for m in messages:
  201. print(f"[{m['role']}]\n{m['content']}")
  202. print(f"[AI筛选][DEBUG] === Prompt 结束 ===")
  203. try:
  204. response = self.client.chat(messages)
  205. if self.debug:
  206. print(f"\n[AI筛选][DEBUG] === 标签更新 AI 原始响应 ===")
  207. self._print_formatted_json(response)
  208. print(f"[AI筛选][DEBUG] === 响应结束 ===")
  209. result = self._parse_update_tags_response(response)
  210. if result is None:
  211. return None
  212. keep_count = len(result.get("keep", []))
  213. add_count = len(result.get("add", []))
  214. remove_count = len(result.get("remove", []))
  215. ratio = result.get("change_ratio", 0)
  216. print(f"[AI筛选] AI 标签更新方案: 保留 {keep_count}, 新增 {add_count}, 移除 {remove_count}, change_ratio={ratio:.2f}")
  217. return result
  218. except Exception as e:
  219. print(f"[AI筛选] 标签更新失败: {type(e).__name__}: {e}")
  220. return None
  221. def _parse_update_tags_response(self, response: str) -> Optional[Dict]:
  222. """解析标签更新的 AI 响应"""
  223. json_str = self._extract_json(response)
  224. if not json_str:
  225. print("[AI筛选] 无法从标签更新响应中提取 JSON")
  226. return None
  227. data = json.loads(json_str)
  228. # 校验必需字段
  229. keep = data.get("keep", [])
  230. add = data.get("add", [])
  231. remove = data.get("remove", [])
  232. change_ratio = float(data.get("change_ratio", 0))
  233. # 校验 keep/add 格式
  234. validated_keep = []
  235. for t in keep:
  236. if isinstance(t, dict) and "tag" in t:
  237. validated_keep.append({
  238. "tag": str(t["tag"]).strip(),
  239. "description": str(t.get("description", "")).strip(),
  240. })
  241. validated_add = []
  242. for t in add:
  243. if isinstance(t, dict) and "tag" in t:
  244. validated_add.append({
  245. "tag": str(t["tag"]).strip(),
  246. "description": str(t.get("description", "")).strip(),
  247. })
  248. validated_remove = [str(r).strip() for r in remove if r]
  249. # change_ratio 限制在 0~1
  250. change_ratio = max(0.0, min(1.0, change_ratio))
  251. return {
  252. "keep": validated_keep,
  253. "add": validated_add,
  254. "remove": validated_remove,
  255. "change_ratio": change_ratio,
  256. }
  257. def _parse_tags_response(self, response: str) -> List[Dict]:
  258. """解析标签提取的 AI 响应"""
  259. json_str = self._extract_json(response)
  260. if not json_str:
  261. return []
  262. data = json.loads(json_str)
  263. tags_raw = data.get("tags", [])
  264. tags = []
  265. for t in tags_raw:
  266. if not isinstance(t, dict) or "tag" not in t:
  267. continue
  268. tags.append({
  269. "tag": str(t["tag"]).strip(),
  270. "description": str(t.get("description", "")).strip(),
  271. })
  272. return tags
  273. def classify_batch(
  274. self,
  275. titles: List[Dict],
  276. tags: List[Dict],
  277. interests_content: str = "",
  278. ) -> List[Dict]:
  279. """
  280. 阶段 B:对一批新闻标题做分类
  281. Args:
  282. titles: [{"id": news_item_id, "title": str, "source": str}]
  283. tags: [{"id": tag_id, "tag": str, "description": str}]
  284. interests_content: 用户的兴趣描述(含质量过滤要求)
  285. Returns:
  286. [{"news_item_id": int, "tag_id": int, "relevance_score": float}, ...]
  287. """
  288. if not titles or not tags:
  289. return []
  290. if not self.classify_user:
  291. print("[AI筛选] 分类提示词模板为空")
  292. return []
  293. # 构建标签列表文本
  294. tags_list = "\n".join(
  295. f"{t['id']}. {t['tag']}: {t.get('description', '')}"
  296. for t in tags
  297. )
  298. # 构建新闻列表文本
  299. news_list = "\n".join(
  300. f"{t['id']}. [{t.get('source', '')}] {t['title']}"
  301. for t in titles
  302. )
  303. # 填充模板
  304. user_prompt = self.classify_user
  305. user_prompt = user_prompt.replace("{interests_content}", interests_content)
  306. user_prompt = user_prompt.replace("{tags_list}", tags_list)
  307. user_prompt = user_prompt.replace("{news_count}", str(len(titles)))
  308. user_prompt = user_prompt.replace("{news_list}", news_list)
  309. messages = []
  310. if self.classify_system:
  311. messages.append({"role": "system", "content": self.classify_system})
  312. messages.append({"role": "user", "content": user_prompt})
  313. if self.debug:
  314. print(f"\n[AI筛选][DEBUG] === 分类 Prompt (标题数={len(titles)}, 标签={len(tags)}) ===")
  315. for m in messages:
  316. role = m['role']
  317. content = m['content']
  318. # 截断过长的新闻列表:只显示前5条和后5条
  319. lines = content.split('\n')
  320. # 找到新闻列表区域并截断
  321. if len(lines) > 30:
  322. # 显示前15行 + 省略提示 + 后10行
  323. head = lines[:15]
  324. tail = lines[-10:]
  325. omitted = len(lines) - 25
  326. truncated = '\n'.join(head) + f'\n... (省略 {omitted} 行) ...\n' + '\n'.join(tail)
  327. print(f"[{role}]\n{truncated}")
  328. else:
  329. print(f"[{role}]\n{content}")
  330. print(f"[AI筛选][DEBUG] === Prompt 结束 (长度: {sum(len(m['content']) for m in messages)} 字符) ===")
  331. try:
  332. response = self.client.chat(messages)
  333. return self._parse_classify_response(response, titles, tags)
  334. except Exception as e:
  335. print(f"[AI筛选] 分类请求失败: {type(e).__name__}: {e}")
  336. return []
  337. def _parse_classify_response(
  338. self,
  339. response: str,
  340. titles: List[Dict],
  341. tags: List[Dict],
  342. ) -> List[Dict]:
  343. """解析分类的 AI 响应
  344. 支持两种 JSON 格式:
  345. - 新格式(扁平): [{"id": 1, "tag_id": 1, "score": 0.9}, ...]
  346. - 旧格式(嵌套): [{"id": 1, "tags": [{"tag_id": 1, "score": 0.9}]}, ...]
  347. 每条新闻只保留一个最高分的 tag,杜绝同一条出现在多个标签下。
  348. """
  349. json_str = self._extract_json(response)
  350. if not json_str:
  351. if self.debug:
  352. print(f"[AI筛选][DEBUG] 无法从分类响应中提取 JSON,原始响应前 500 字符: {(response or '')[:500]}")
  353. return []
  354. try:
  355. data = json.loads(json_str)
  356. except json.JSONDecodeError as e:
  357. if self.debug:
  358. print(f"[AI筛选][DEBUG] 分类响应 JSON 解析失败: {e}")
  359. print(f"[AI筛选][DEBUG] 提取的 JSON 文本前 500 字符: {json_str[:500]}")
  360. return []
  361. if not isinstance(data, list):
  362. if self.debug:
  363. print(f"[AI筛选][DEBUG] 分类响应顶层不是数组,实际类型: {type(data).__name__}")
  364. return []
  365. # 构建 id 映射
  366. title_ids = {t["id"] for t in titles}
  367. title_map = {t["id"]: t["title"] for t in titles}
  368. tag_id_set = {t["id"] for t in tags}
  369. tag_name_map = {t["id"]: t["tag"] for t in tags}
  370. # 每条新闻只保留一个最高分的 tag
  371. best_per_news: Dict[int, Dict] = {} # news_id -> {"tag_id": ..., "score": ...}
  372. skipped_news_ids = 0
  373. skipped_tag_ids = 0
  374. skipped_empty = 0
  375. for item in data:
  376. if not isinstance(item, dict):
  377. continue
  378. news_id = item.get("id")
  379. if news_id not in title_ids:
  380. skipped_news_ids += 1
  381. continue
  382. # 收集此条新闻的所有候选 tag
  383. candidates = []
  384. if "tag_id" in item:
  385. # 新格式(扁平): {"id": 1, "tag_id": 1, "score": 0.9}
  386. candidates.append({"tag_id": item["tag_id"], "score": item.get("score", 0.5)})
  387. elif "tags" in item:
  388. # 旧格式(嵌套): {"id": 1, "tags": [{"tag_id": 1, "score": 0.9}]}
  389. matched_tags = item.get("tags", [])
  390. if isinstance(matched_tags, list):
  391. if not matched_tags:
  392. skipped_empty += 1
  393. continue
  394. candidates.extend(matched_tags)
  395. if not candidates:
  396. skipped_empty += 1
  397. continue
  398. # 取最高分的有效 tag
  399. best_tag_id = None
  400. best_score = -1.0
  401. for tag_match in candidates:
  402. if not isinstance(tag_match, dict):
  403. continue
  404. tag_id = tag_match.get("tag_id")
  405. if tag_id not in tag_id_set:
  406. skipped_tag_ids += 1
  407. continue
  408. score = tag_match.get("score", 0.5)
  409. try:
  410. score = float(score)
  411. score = max(0.0, min(1.0, score))
  412. except (ValueError, TypeError):
  413. score = 0.5
  414. if score > best_score:
  415. best_score = score
  416. best_tag_id = tag_id
  417. if best_tag_id is not None:
  418. # 如果同一条新闻被多次返回,只保留分数更高的
  419. existing = best_per_news.get(news_id)
  420. if existing is None or best_score > existing["relevance_score"]:
  421. best_per_news[news_id] = {
  422. "news_item_id": news_id,
  423. "tag_id": best_tag_id,
  424. "relevance_score": best_score,
  425. }
  426. results = list(best_per_news.values())
  427. if self.debug:
  428. ai_returned = len(data)
  429. print(f"[AI筛选][DEBUG] --- 分类解析结果 ---")
  430. print(f"[AI筛选][DEBUG] AI 返回 {ai_returned} 条, 有效 {len(results)} 条 (每条新闻仅保留最高分 tag)")
  431. if skipped_empty > 0:
  432. print(f"[AI筛选][DEBUG] 跳过空 tags: {skipped_empty} 条")
  433. if skipped_news_ids > 0:
  434. print(f"[AI筛选][DEBUG] !! 跳过无效 news_id: {skipped_news_ids} 条")
  435. if skipped_tag_ids > 0:
  436. print(f"[AI筛选][DEBUG] !! 跳过无效 tag_id: {skipped_tag_ids} 条")
  437. # 按标签汇总
  438. tag_summary: Dict[int, List[str]] = {}
  439. for r in results:
  440. tid = r["tag_id"]
  441. if tid not in tag_summary:
  442. tag_summary[tid] = []
  443. tag_summary[tid].append(
  444. f" [{r['news_item_id']}] {title_map.get(r['news_item_id'], '?')[:40]} (score={r['relevance_score']:.2f})"
  445. )
  446. for tid, items in tag_summary.items():
  447. tname = tag_name_map.get(tid, f"tag_{tid}")
  448. print(f"[AI筛选][DEBUG] 标签「{tname}」匹配 {len(items)} 条:")
  449. for line in items:
  450. print(line)
  451. return results
  452. def _extract_json(self, response: str) -> Optional[str]:
  453. """从 AI 响应中提取 JSON 字符串"""
  454. if not response or not response.strip():
  455. return None
  456. json_str = response.strip()
  457. if "```json" in json_str:
  458. parts = json_str.split("```json", 1)
  459. if len(parts) > 1:
  460. code_block = parts[1]
  461. end_idx = code_block.find("```")
  462. json_str = code_block[:end_idx] if end_idx != -1 else code_block
  463. elif "```" in json_str:
  464. parts = json_str.split("```", 2)
  465. if len(parts) >= 2:
  466. json_str = parts[1]
  467. json_str = json_str.strip()
  468. return json_str if json_str else None
  469. def _print_formatted_json(self, response: str) -> None:
  470. """格式化打印 AI 响应中的 JSON,便于 debug 阅读"""
  471. if not response:
  472. print("(空响应)")
  473. return
  474. json_str = self._extract_json(response)
  475. if json_str:
  476. try:
  477. data = json.loads(json_str)
  478. if isinstance(data, list):
  479. # 数组:每个元素压成一行
  480. lines = [json.dumps(item, ensure_ascii=False) for item in data]
  481. print("[\n " + ",\n ".join(lines) + "\n]")
  482. else:
  483. print(json.dumps(data, ensure_ascii=False, indent=2))
  484. return
  485. except json.JSONDecodeError:
  486. pass
  487. # JSON 解析失败,直接打印原始响应
  488. print(response)