filter.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
# coding=utf-8
"""
AI 智能筛选模块
通过 AI 对新闻进行标签分类:
1. 阶段 A:从用户兴趣描述中提取结构化标签
2. 阶段 B:对新闻标题按标签进行批量分类
"""
import hashlib
import json
import math
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from trendradar.ai.client import AIClient
from trendradar.ai.prompt_loader import load_prompt_template
  15. @dataclass
  16. class AIFilterResult:
  17. """AI 筛选结果,传给报告和通知模块"""
  18. tags: List[Dict] = field(default_factory=list)
  19. # [{"tag": str, "description": str, "count": int, "items": [
  20. # {"title": str, "source_id": str, "source_name": str,
  21. # "url": str, "mobile_url": str, "rank": int, "ranks": [...],
  22. # "first_time": str, "last_time": str, "count": int,
  23. # "relevance_score": float, "source_type": str}
  24. # ]}]
  25. total_matched: int = 0 # 匹配新闻总数
  26. total_processed: int = 0 # 处理新闻总数
  27. success: bool = False
  28. error: str = ""
  29. class AIFilter:
  30. """AI 智能筛选器"""
  31. def __init__(
  32. self,
  33. ai_config: Dict[str, Any],
  34. filter_config: Dict[str, Any],
  35. get_time_func: Callable,
  36. debug: bool = False,
  37. ):
  38. self.client = AIClient(ai_config)
  39. self.filter_config = filter_config
  40. self.batch_size = filter_config.get("BATCH_SIZE", 200)
  41. self.get_time_func = get_time_func
  42. self.debug = debug
  43. # 加载提示词模板
  44. self.classify_system, self.classify_user = load_prompt_template(
  45. filter_config.get("PROMPT_FILE", "ai_filter_prompt.txt"),
  46. config_subdir="ai_filter", label="AI筛选",
  47. )
  48. self.extract_system, self.extract_user = load_prompt_template(
  49. filter_config.get("EXTRACT_PROMPT_FILE", "ai_filter_extract_prompt.txt"),
  50. config_subdir="ai_filter", label="AI筛选",
  51. )
  52. self.update_tags_system, self.update_tags_user = load_prompt_template(
  53. filter_config.get("UPDATE_TAGS_PROMPT_FILE", "update_tags_prompt.txt"),
  54. config_subdir="ai_filter", label="AI筛选",
  55. )
  56. def compute_interests_hash(self, interests_content: str, filename: str = "ai_interests.txt") -> str:
  57. """计算兴趣描述的 hash,格式为 filename:md5"""
  58. # 去除前后空白和注释行,确保内容变化才改变 hash
  59. lines = []
  60. for line in interests_content.strip().splitlines():
  61. line = line.strip()
  62. if line and not line.startswith("#"):
  63. lines.append(line)
  64. normalized = "\n".join(lines)
  65. content_hash = hashlib.md5(normalized.encode("utf-8")).hexdigest()
  66. return f"{filename}:{content_hash}"
  67. def load_interests_content(self, interests_file: Optional[str] = None) -> Optional[str]:
  68. """加载兴趣描述文件内容
  69. 解析逻辑:
  70. - interests_file 为 None:使用默认 config/ai_interests.txt
  71. - interests_file 有值:仅查 config/custom/ai/{filename}
  72. 注意:调用方(context.py)已完成 config/timeline 的合并决策,
  73. 此处不再二次读取 filter_config,避免语义冲突。
  74. """
  75. config_dir = Path(__file__).parent.parent.parent / "config"
  76. configured_file = interests_file
  77. if configured_file:
  78. # 自定义兴趣文件:仅查 custom/ai 目录
  79. filename = configured_file
  80. interests_path = config_dir / "custom" / "ai" / filename
  81. if not interests_path.exists():
  82. print(f"[AI筛选] 自定义兴趣描述文件不存在: {filename}")
  83. print(f"[AI筛选] 已查找: {interests_path}")
  84. return None
  85. else:
  86. # 默认兴趣文件:固定使用 config/ai_interests.txt
  87. filename = "ai_interests.txt"
  88. interests_path = config_dir / filename
  89. if not interests_path.exists():
  90. print(f"[AI筛选] 默认兴趣描述文件不存在: {filename}")
  91. print(f"[AI筛选] 已查找: {interests_path}")
  92. return None
  93. if not interests_path.exists():
  94. print(f"[AI筛选] 兴趣描述文件不存在: {interests_path}")
  95. return None
  96. content = interests_path.read_text(encoding="utf-8").strip()
  97. if not content:
  98. print("[AI筛选] 兴趣描述文件为空")
  99. return None
  100. return content
  101. def extract_tags(self, interests_content: str) -> List[Dict]:
  102. """
  103. 阶段 A:从兴趣描述中提取结构化标签
  104. Args:
  105. interests_content: 用户的兴趣描述文本
  106. Returns:
  107. [{"tag": str, "description": str}, ...]
  108. """
  109. if not self.extract_user:
  110. print("[AI筛选] 标签提取提示词模板为空")
  111. return []
  112. user_prompt = self.extract_user.replace("{interests_content}", interests_content)
  113. messages = []
  114. if self.extract_system:
  115. messages.append({"role": "system", "content": self.extract_system})
  116. messages.append({"role": "user", "content": user_prompt})
  117. if self.debug:
  118. print(f"\n[AI筛选][DEBUG] === 标签提取 Prompt ===")
  119. for m in messages:
  120. print(f"[{m['role']}]\n{m['content']}")
  121. print(f"[AI筛选][DEBUG] === Prompt 结束 ===")
  122. try:
  123. response = self.client.chat(messages)
  124. if self.debug:
  125. print(f"\n[AI筛选][DEBUG] === 标签提取 AI 原始响应 ===")
  126. # 尝试格式化 JSON 便于阅读
  127. self._print_formatted_json(response)
  128. print(f"[AI筛选][DEBUG] === 响应结束 ===")
  129. tags = self._parse_tags_response(response)
  130. print(f"[AI筛选] 提取到 {len(tags)} 个标签")
  131. for t in tags:
  132. print(f" {t['tag']}: {t.get('description', '')}")
  133. if self.debug:
  134. json_str = self._extract_json(response)
  135. if not json_str:
  136. print(f"[AI筛选][DEBUG] 无法从响应中提取 JSON")
  137. else:
  138. raw_data = json.loads(json_str)
  139. raw_tags = raw_data.get("tags", [])
  140. skipped = len(raw_tags) - len(tags)
  141. if skipped > 0:
  142. print(f"[AI筛选][DEBUG] 原始标签 {len(raw_tags)} 个,有效 {len(tags)} 个,跳过 {skipped} 个(缺少 tag 字段或格式无效)")
  143. return tags
  144. except json.JSONDecodeError as e:
  145. print(f"[AI筛选] 标签提取失败: JSON 解析错误: {e}")
  146. if self.debug:
  147. print(f"[AI筛选][DEBUG] 尝试解析的 JSON 内容: {self._extract_json(response) if response else '(空响应)'}")
  148. return []
  149. except Exception as e:
  150. print(f"[AI筛选] 标签提取失败: {type(e).__name__}: {e}")
  151. return []
  152. def update_tags(self, old_tags: List[Dict], interests_content: str) -> Optional[Dict]:
  153. """
  154. 阶段 A':AI 对比旧标签和新兴趣描述,给出更新方案
  155. Args:
  156. old_tags: [{"tag": str, "description": str, "id": int}, ...]
  157. interests_content: 新的兴趣描述文本
  158. Returns:
  159. {"keep": [{"tag": str, "description": str}],
  160. "add": [{"tag": str, "description": str}],
  161. "remove": [str],
  162. "change_ratio": float}
  163. 失败返回 None
  164. """
  165. if not self.update_tags_user:
  166. print("[AI筛选] 标签更新提示词模板为空,回退到重新提取")
  167. return None
  168. # 构造旧标签 JSON
  169. old_tags_json = json.dumps(
  170. [{"tag": t["tag"], "description": t.get("description", "")} for t in old_tags],
  171. ensure_ascii=False, indent=2
  172. )
  173. user_prompt = self.update_tags_user.replace(
  174. "{old_tags_json}", old_tags_json
  175. ).replace(
  176. "{interests_content}", interests_content
  177. )
  178. messages = []
  179. if self.update_tags_system:
  180. messages.append({"role": "system", "content": self.update_tags_system})
  181. messages.append({"role": "user", "content": user_prompt})
  182. if self.debug:
  183. print(f"\n[AI筛选][DEBUG] === 标签更新 Prompt ===")
  184. for m in messages:
  185. print(f"[{m['role']}]\n{m['content']}")
  186. print(f"[AI筛选][DEBUG] === Prompt 结束 ===")
  187. try:
  188. response = self.client.chat(messages)
  189. if self.debug:
  190. print(f"\n[AI筛选][DEBUG] === 标签更新 AI 原始响应 ===")
  191. self._print_formatted_json(response)
  192. print(f"[AI筛选][DEBUG] === 响应结束 ===")
  193. result = self._parse_update_tags_response(response)
  194. if result is None:
  195. return None
  196. keep_count = len(result.get("keep", []))
  197. add_count = len(result.get("add", []))
  198. remove_count = len(result.get("remove", []))
  199. ratio = result.get("change_ratio", 0)
  200. print(f"[AI筛选] AI 标签更新方案: 保留 {keep_count}, 新增 {add_count}, 移除 {remove_count}, change_ratio={ratio:.2f}")
  201. return result
  202. except Exception as e:
  203. print(f"[AI筛选] 标签更新失败: {type(e).__name__}: {e}")
  204. return None
  205. def _parse_update_tags_response(self, response: str) -> Optional[Dict]:
  206. """解析标签更新的 AI 响应"""
  207. json_str = self._extract_json(response)
  208. if not json_str:
  209. print("[AI筛选] 无法从标签更新响应中提取 JSON")
  210. return None
  211. data = json.loads(json_str)
  212. # 校验必需字段
  213. keep = data.get("keep", [])
  214. add = data.get("add", [])
  215. remove = data.get("remove", [])
  216. change_ratio = float(data.get("change_ratio", 0))
  217. # 校验 keep/add 格式
  218. validated_keep = []
  219. for t in keep:
  220. if isinstance(t, dict) and "tag" in t:
  221. validated_keep.append({
  222. "tag": str(t["tag"]).strip(),
  223. "description": str(t.get("description", "")).strip(),
  224. })
  225. validated_add = []
  226. for t in add:
  227. if isinstance(t, dict) and "tag" in t:
  228. validated_add.append({
  229. "tag": str(t["tag"]).strip(),
  230. "description": str(t.get("description", "")).strip(),
  231. })
  232. validated_remove = [str(r).strip() for r in remove if r]
  233. # change_ratio 限制在 0~1
  234. change_ratio = max(0.0, min(1.0, change_ratio))
  235. return {
  236. "keep": validated_keep,
  237. "add": validated_add,
  238. "remove": validated_remove,
  239. "change_ratio": change_ratio,
  240. }
  241. def _parse_tags_response(self, response: str) -> List[Dict]:
  242. """解析标签提取的 AI 响应"""
  243. json_str = self._extract_json(response)
  244. if not json_str:
  245. return []
  246. data = json.loads(json_str)
  247. tags_raw = data.get("tags", [])
  248. tags = []
  249. for t in tags_raw:
  250. if not isinstance(t, dict) or "tag" not in t:
  251. continue
  252. tags.append({
  253. "tag": str(t["tag"]).strip(),
  254. "description": str(t.get("description", "")).strip(),
  255. })
  256. return tags
  257. def classify_batch(
  258. self,
  259. titles: List[Dict],
  260. tags: List[Dict],
  261. interests_content: str = "",
  262. ) -> List[Dict]:
  263. """
  264. 阶段 B:对一批新闻标题做分类
  265. Args:
  266. titles: [{"id": news_item_id, "title": str, "source": str}]
  267. tags: [{"id": tag_id, "tag": str, "description": str}]
  268. interests_content: 用户的兴趣描述(含质量过滤要求)
  269. Returns:
  270. [{"news_item_id": int, "tag_id": int, "relevance_score": float}, ...]
  271. """
  272. if not titles or not tags:
  273. return []
  274. if not self.classify_user:
  275. print("[AI筛选] 分类提示词模板为空")
  276. return []
  277. # 构建标签列表文本
  278. tags_list = "\n".join(
  279. f"{t['id']}. {t['tag']}: {t.get('description', '')}"
  280. for t in tags
  281. )
  282. # 构建新闻列表文本
  283. news_list = "\n".join(
  284. f"{t['id']}. [{t.get('source', '')}] {t['title']}"
  285. for t in titles
  286. )
  287. # 填充模板
  288. user_prompt = self.classify_user
  289. user_prompt = user_prompt.replace("{interests_content}", interests_content)
  290. user_prompt = user_prompt.replace("{tags_list}", tags_list)
  291. user_prompt = user_prompt.replace("{news_count}", str(len(titles)))
  292. user_prompt = user_prompt.replace("{news_list}", news_list)
  293. messages = []
  294. if self.classify_system:
  295. messages.append({"role": "system", "content": self.classify_system})
  296. messages.append({"role": "user", "content": user_prompt})
  297. if self.debug:
  298. print(f"\n[AI筛选][DEBUG] === 分类 Prompt (标题数={len(titles)}, 标签={len(tags)}) ===")
  299. for m in messages:
  300. role = m['role']
  301. content = m['content']
  302. # 截断过长的新闻列表:只显示前5条和后5条
  303. lines = content.split('\n')
  304. # 找到新闻列表区域并截断
  305. if len(lines) > 30:
  306. # 显示前15行 + 省略提示 + 后10行
  307. head = lines[:15]
  308. tail = lines[-10:]
  309. omitted = len(lines) - 25
  310. truncated = '\n'.join(head) + f'\n... (省略 {omitted} 行) ...\n' + '\n'.join(tail)
  311. print(f"[{role}]\n{truncated}")
  312. else:
  313. print(f"[{role}]\n{content}")
  314. print(f"[AI筛选][DEBUG] === Prompt 结束 (长度: {sum(len(m['content']) for m in messages)} 字符) ===")
  315. try:
  316. response = self.client.chat(messages)
  317. return self._parse_classify_response(response, titles, tags)
  318. except Exception as e:
  319. print(f"[AI筛选] 分类请求失败: {type(e).__name__}: {e}")
  320. return []
  321. def _parse_classify_response(
  322. self,
  323. response: str,
  324. titles: List[Dict],
  325. tags: List[Dict],
  326. ) -> List[Dict]:
  327. """解析分类的 AI 响应
  328. 支持两种 JSON 格式:
  329. - 新格式(扁平): [{"id": 1, "tag_id": 1, "score": 0.9}, ...]
  330. - 旧格式(嵌套): [{"id": 1, "tags": [{"tag_id": 1, "score": 0.9}]}, ...]
  331. 每条新闻只保留一个最高分的 tag,杜绝同一条出现在多个标签下。
  332. """
  333. json_str = self._extract_json(response)
  334. if not json_str:
  335. if self.debug:
  336. print(f"[AI筛选][DEBUG] 无法从分类响应中提取 JSON,原始响应前 500 字符: {(response or '')[:500]}")
  337. return []
  338. try:
  339. data = json.loads(json_str)
  340. except json.JSONDecodeError as e:
  341. if self.debug:
  342. print(f"[AI筛选][DEBUG] 分类响应 JSON 解析失败: {e}")
  343. print(f"[AI筛选][DEBUG] 提取的 JSON 文本前 500 字符: {json_str[:500]}")
  344. return []
  345. if not isinstance(data, list):
  346. if self.debug:
  347. print(f"[AI筛选][DEBUG] 分类响应顶层不是数组,实际类型: {type(data).__name__}")
  348. return []
  349. # 构建 id 映射
  350. title_ids = {t["id"] for t in titles}
  351. title_map = {t["id"]: t["title"] for t in titles}
  352. tag_id_set = {t["id"] for t in tags}
  353. tag_name_map = {t["id"]: t["tag"] for t in tags}
  354. # 每条新闻只保留一个最高分的 tag
  355. best_per_news: Dict[int, Dict] = {} # news_id -> {"tag_id": ..., "score": ...}
  356. skipped_news_ids = 0
  357. skipped_tag_ids = 0
  358. skipped_empty = 0
  359. for item in data:
  360. if not isinstance(item, dict):
  361. continue
  362. news_id = item.get("id")
  363. if news_id not in title_ids:
  364. skipped_news_ids += 1
  365. continue
  366. # 收集此条新闻的所有候选 tag
  367. candidates = []
  368. if "tag_id" in item:
  369. # 新格式(扁平): {"id": 1, "tag_id": 1, "score": 0.9}
  370. candidates.append({"tag_id": item["tag_id"], "score": item.get("score", 0.5)})
  371. elif "tags" in item:
  372. # 旧格式(嵌套): {"id": 1, "tags": [{"tag_id": 1, "score": 0.9}]}
  373. matched_tags = item.get("tags", [])
  374. if isinstance(matched_tags, list):
  375. if not matched_tags:
  376. skipped_empty += 1
  377. continue
  378. candidates.extend(matched_tags)
  379. if not candidates:
  380. skipped_empty += 1
  381. continue
  382. # 取最高分的有效 tag
  383. best_tag_id = None
  384. best_score = -1.0
  385. for tag_match in candidates:
  386. if not isinstance(tag_match, dict):
  387. continue
  388. tag_id = tag_match.get("tag_id")
  389. if tag_id not in tag_id_set:
  390. skipped_tag_ids += 1
  391. continue
  392. score = tag_match.get("score", 0.5)
  393. try:
  394. score = float(score)
  395. score = max(0.0, min(1.0, score))
  396. except (ValueError, TypeError):
  397. score = 0.5
  398. if score > best_score:
  399. best_score = score
  400. best_tag_id = tag_id
  401. if best_tag_id is not None:
  402. # 如果同一条新闻被多次返回,只保留分数更高的
  403. existing = best_per_news.get(news_id)
  404. if existing is None or best_score > existing["relevance_score"]:
  405. best_per_news[news_id] = {
  406. "news_item_id": news_id,
  407. "tag_id": best_tag_id,
  408. "relevance_score": best_score,
  409. }
  410. results = list(best_per_news.values())
  411. if self.debug:
  412. ai_returned = len(data)
  413. print(f"[AI筛选][DEBUG] --- 分类解析结果 ---")
  414. print(f"[AI筛选][DEBUG] AI 返回 {ai_returned} 条, 有效 {len(results)} 条 (每条新闻仅保留最高分 tag)")
  415. if skipped_empty > 0:
  416. print(f"[AI筛选][DEBUG] 跳过空 tags: {skipped_empty} 条")
  417. if skipped_news_ids > 0:
  418. print(f"[AI筛选][DEBUG] !! 跳过无效 news_id: {skipped_news_ids} 条")
  419. if skipped_tag_ids > 0:
  420. print(f"[AI筛选][DEBUG] !! 跳过无效 tag_id: {skipped_tag_ids} 条")
  421. # 按标签汇总
  422. tag_summary: Dict[int, List[str]] = {}
  423. for r in results:
  424. tid = r["tag_id"]
  425. if tid not in tag_summary:
  426. tag_summary[tid] = []
  427. tag_summary[tid].append(
  428. f" [{r['news_item_id']}] {title_map.get(r['news_item_id'], '?')[:40]} (score={r['relevance_score']:.2f})"
  429. )
  430. for tid, items in tag_summary.items():
  431. tname = tag_name_map.get(tid, f"tag_{tid}")
  432. print(f"[AI筛选][DEBUG] 标签「{tname}」匹配 {len(items)} 条:")
  433. for line in items:
  434. print(line)
  435. return results
  436. def _extract_json(self, response: str) -> Optional[str]:
  437. """从 AI 响应中提取 JSON 字符串"""
  438. if not response or not response.strip():
  439. return None
  440. json_str = response.strip()
  441. if "```json" in json_str:
  442. parts = json_str.split("```json", 1)
  443. if len(parts) > 1:
  444. code_block = parts[1]
  445. end_idx = code_block.find("```")
  446. json_str = code_block[:end_idx] if end_idx != -1 else code_block
  447. elif "```" in json_str:
  448. parts = json_str.split("```", 2)
  449. if len(parts) >= 2:
  450. json_str = parts[1]
  451. json_str = json_str.strip()
  452. return json_str if json_str else None
  453. def _print_formatted_json(self, response: str) -> None:
  454. """格式化打印 AI 响应中的 JSON,便于 debug 阅读"""
  455. if not response:
  456. print("(空响应)")
  457. return
  458. json_str = self._extract_json(response)
  459. if json_str:
  460. try:
  461. data = json.loads(json_str)
  462. if isinstance(data, list):
  463. # 数组:每个元素压成一行
  464. lines = [json.dumps(item, ensure_ascii=False) for item in data]
  465. print("[\n " + ",\n ".join(lines) + "\n]")
  466. else:
  467. print(json.dumps(data, ensure_ascii=False, indent=2))
  468. return
  469. except json.JSONDecodeError:
  470. pass
  471. # JSON 解析失败,直接打印原始响应
  472. print(response)