| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586 |
- # coding=utf-8
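"""
AI smart-filter module.

Classifies news by tag with the help of an AI model:
1. Phase A: extract structured tags from the user's interest description
2. Phase B: classify news titles against those tags in batches
"""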
- import hashlib
- import json
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Any, Callable, Dict, List, Optional
- from trendradar.ai.client import AIClient
@dataclass
class AIFilterResult:
    """Result of an AI filtering run, handed to the report and notification modules.

    Expected shape of ``tags`` (as described by the original schema comment):
        [{"tag": str, "description": str, "count": int, "items": [
            {"title": str, "source_id": str, "source_name": str,
             "url": str, "mobile_url": str, "rank": int, "ranks": [...],
             "first_time": str, "last_time": str, "count": int,
             "relevance_score": float, "source_type": str}
        ]}]
    """

    # One entry per tag, each carrying its matched news items (see class docstring).
    tags: List[Dict] = field(default_factory=list)
    # Total number of news items that matched a tag.
    total_matched: int = 0
    # Total number of news items that were processed.
    total_processed: int = 0
    # Whether the run completed successfully.
    success: bool = False
    # Error description; empty when there is none.
    error: str = ""
class AIFilter:
    """AI-driven news filter.

    Phase A turns a free-text interest description into structured tags
    (``extract_tags`` / ``update_tags``); phase B assigns news titles to those
    tags in batches (``classify_batch``). Every model call goes through
    ``self.client.chat(messages)``.
    """

    def __init__(
        self,
        ai_config: Dict[str, Any],
        filter_config: Dict[str, Any],
        get_time_func: Callable,
        debug: bool = False,
    ):
        """Create the AI client and load the three prompt templates.

        Args:
            ai_config: Passed unchanged to ``AIClient``.
            filter_config: Provides ``BATCH_SIZE`` (default 200) and the optional
                template file names ``PROMPT_FILE``, ``EXTRACT_PROMPT_FILE`` and
                ``UPDATE_TAGS_PROMPT_FILE``.
            get_time_func: Stored on the instance; not called inside this class.
            debug: When true, prompts, raw responses and parse statistics are
                printed.
        """
        self.client = AIClient(ai_config)
        self.filter_config = filter_config
        self.batch_size = filter_config.get("BATCH_SIZE", 200)
        self.get_time_func = get_time_func
        self.debug = debug

        # Load the prompt templates.
        self.classify_system, self.classify_user = self._load_prompt(
            filter_config.get("PROMPT_FILE", "ai_filter_prompt.txt")
        )
        self.extract_system, self.extract_user = self._load_prompt(
            filter_config.get("EXTRACT_PROMPT_FILE", "ai_filter_extract_prompt.txt")
        )
        self.update_tags_system, self.update_tags_user = self._load_prompt(
            filter_config.get("UPDATE_TAGS_PROMPT_FILE", "update_tags_prompt.txt")
        )

    # ------------------------------------------------------------------
    # Prompt loading and small shared helpers
    # ------------------------------------------------------------------

    def _load_prompt(self, filename: str) -> tuple:
        """Load a prompt file and return ``(system_prompt, user_prompt_template)``.

        The file is looked up in ``<three levels above this file>/config/ai_filter``.
        A missing file is reported and yields ``("", "")``.
        """
        config_dir = Path(__file__).parent.parent.parent / "config" / "ai_filter"
        prompt_path = config_dir / filename
        if not prompt_path.exists():
            print(f"[AI筛选] 提示词文件不存在: {prompt_path}")
            return "", ""
        return self._split_prompt(prompt_path.read_text(encoding="utf-8"))

    @staticmethod
    def _split_prompt(content: str) -> tuple:
        """Split template text into ``(system_prompt, user_prompt)``.

        When both ``[system]`` and ``[user]`` markers are present, the text
        between them is the system prompt and everything after the first
        ``[user]`` is the user prompt (both stripped). Otherwise the whole,
        unstripped text is the user prompt.
        """
        if "[system]" not in content or "[user]" not in content:
            return "", content

        # Split only on the first marker so a literal "[user]" that appears later
        # inside the user template is not silently cut off.
        system_part, _, user_part = content.partition("[user]")
        system_prompt = ""
        if "[system]" in system_part:
            system_prompt = system_part.split("[system]", 1)[1].strip()
        return system_prompt, user_part.strip()

    @staticmethod
    def _fill_template(template: str, values: Dict[str, str]) -> str:
        """Replace every ``{name}`` placeholder in one left-to-right pass.

        Chained ``str.replace`` calls would also substitute placeholder text
        that happens to occur inside an already inserted value (for example an
        interest description containing ``{news_list}``). A single pass never
        rescans inserted text.
        """
        pieces = []
        pos = 0
        while True:
            next_idx = -1
            next_key = None
            for key in values:
                idx = template.find("{" + key + "}", pos)
                if idx != -1 and (next_idx == -1 or idx < next_idx):
                    next_idx, next_key = idx, key
            if next_key is None:
                pieces.append(template[pos:])
                break
            pieces.append(template[pos:next_idx])
            pieces.append(values[next_key])
            pos = next_idx + len(next_key) + 2  # skip "{", the name and "}"
        return "".join(pieces)

    @staticmethod
    def _build_messages(system_prompt: str, user_prompt: str) -> List[Dict]:
        """Build the chat message list; the system message is omitted when empty."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})
        return messages

    @staticmethod
    def _debug_print_prompt(label: str, messages: List[Dict]) -> None:
        """Print the full prompt of a request (debug output)."""
        print(f"\n[AI筛选][DEBUG] === {label} Prompt ===")
        for m in messages:
            print(f"[{m['role']}]\n{m['content']}")
        print("[AI筛选][DEBUG] === Prompt 结束 ===")

    def _debug_print_response(self, label: str, response: str) -> None:
        """Print the raw AI response, pretty-printing JSON when possible (debug output)."""
        print(f"\n[AI筛选][DEBUG] === {label} AI 原始响应 ===")
        self._print_formatted_json(response)
        print("[AI筛选][DEBUG] === 响应结束 ===")

    @staticmethod
    def _validate_tag_list(raw_tags: Any) -> List[Dict]:
        """Normalise a raw tag list into ``[{"tag": str, "description": str}, ...]``.

        Entries that are not dicts, or whose ``tag`` is missing, ``null`` or blank,
        are dropped. A missing or ``null`` description becomes ``""``. Anything
        that is not a list yields ``[]``.
        """
        if not isinstance(raw_tags, list):
            return []
        cleaned = []
        for entry in raw_tags:
            if not isinstance(entry, dict):
                continue
            name = entry.get("tag")
            if name is None:
                continue
            name = str(name).strip()
            if not name:
                continue
            description = entry.get("description")
            cleaned.append({
                "tag": name,
                "description": "" if description is None else str(description).strip(),
            })
        return cleaned

    # ------------------------------------------------------------------
    # Interest description
    # ------------------------------------------------------------------

    def compute_interests_hash(self, interests_content: str, filename: str = "ai_interests.txt") -> str:
        """Return ``"<filename>:<md5>"`` for the interest description.

        Blank lines, ``#`` comment lines and surrounding whitespace are removed
        before hashing, so only changes to the effective content change the hash.
        """
        stripped = (line.strip() for line in interests_content.strip().splitlines())
        normalized = "\n".join(
            line for line in stripped if line and not line.startswith("#")
        )
        content_hash = hashlib.md5(normalized.encode("utf-8")).hexdigest()
        return f"{filename}:{content_hash}"

    def load_interests_content(self, interests_file: Optional[str] = None) -> Optional[str]:
        """Load the interest description text.

        Resolution:
            - ``interests_file`` falsy: use the default ``config/ai_interests.txt``.
            - ``interests_file`` given: look only in ``config/custom/ai/<filename>``.

        ``filter_config`` is deliberately not consulted here; per the original
        note, the caller is expected to have resolved which file to use.

        Returns:
            The stripped file content, or ``None`` when the file is missing or empty.
        """
        config_dir = Path(__file__).parent.parent.parent / "config"
        if interests_file:
            # Custom interests file: only the custom/ai directory is searched.
            filename = interests_file
            interests_path = config_dir / "custom" / "ai" / filename
            missing_label = "自定义兴趣描述文件不存在"
        else:
            # Default interests file: always config/ai_interests.txt.
            filename = "ai_interests.txt"
            interests_path = config_dir / filename
            missing_label = "默认兴趣描述文件不存在"

        if not interests_path.exists():
            print(f"[AI筛选] {missing_label}: {filename}")
            print(f"[AI筛选] 已查找: {interests_path}")
            return None

        content = interests_path.read_text(encoding="utf-8").strip()
        if not content:
            print("[AI筛选] 兴趣描述文件为空")
            return None
        return content

    # ------------------------------------------------------------------
    # Phase A: tag extraction / update
    # ------------------------------------------------------------------

    def extract_tags(self, interests_content: str) -> List[Dict]:
        """Phase A: extract structured tags from the interest description.

        Args:
            interests_content: The user's interest description text.

        Returns:
            ``[{"tag": str, "description": str}, ...]``; an empty list when the
            template is missing, the request fails or the response cannot be parsed.
        """
        if not self.extract_user:
            print("[AI筛选] 标签提取提示词模板为空")
            return []

        user_prompt = self._fill_template(
            self.extract_user, {"interests_content": interests_content}
        )
        messages = self._build_messages(self.extract_system, user_prompt)

        if self.debug:
            self._debug_print_prompt("标签提取", messages)

        # Bound before the try so the JSON error handler can always read it,
        # even if the chat call itself is what raised.
        response = None
        try:
            response = self.client.chat(messages)

            if self.debug:
                self._debug_print_response("标签提取", response)

            tags = self._parse_tags_response(response)
            print(f"[AI筛选] 提取到 {len(tags)} 个标签")
            for t in tags:
                print(f" {t['tag']}: {t.get('description', '')}")

            if self.debug:
                json_str = self._extract_json(response)
                if not json_str:
                    print("[AI筛选][DEBUG] 无法从响应中提取 JSON")
                else:
                    raw_data = json.loads(json_str)
                    raw_tags = raw_data.get("tags", []) if isinstance(raw_data, dict) else []
                    if not isinstance(raw_tags, list):
                        raw_tags = []
                    skipped = len(raw_tags) - len(tags)
                    if skipped > 0:
                        print(f"[AI筛选][DEBUG] 原始标签 {len(raw_tags)} 个,有效 {len(tags)} 个,跳过 {skipped} 个(缺少 tag 字段或格式无效)")
            return tags
        except json.JSONDecodeError as e:
            print(f"[AI筛选] 标签提取失败: JSON 解析错误: {e}")
            if self.debug:
                print(f"[AI筛选][DEBUG] 尝试解析的 JSON 内容: {self._extract_json(response) if response else '(空响应)'}")
            return []
        except Exception as e:
            print(f"[AI筛选] 标签提取失败: {type(e).__name__}: {e}")
            return []

    def update_tags(self, old_tags: List[Dict], interests_content: str) -> Optional[Dict]:
        """Phase A': let the AI compare old tags with a new interest description.

        Args:
            old_tags: ``[{"tag": str, "description": str, "id": int}, ...]``;
                only ``tag`` and ``description`` are sent to the model.
            interests_content: The new interest description text.

        Returns:
            ``{"keep": [...], "add": [...], "remove": [str], "change_ratio": float}``
            or ``None`` on failure (the printed message for a missing template
            tells the caller to fall back to a fresh extraction).
        """
        if not self.update_tags_user:
            print("[AI筛选] 标签更新提示词模板为空,回退到重新提取")
            return None

        # Serialise the old tags for the prompt.
        old_tags_json = json.dumps(
            [{"tag": t["tag"], "description": t.get("description", "")} for t in old_tags],
            ensure_ascii=False, indent=2
        )
        user_prompt = self._fill_template(
            self.update_tags_user,
            {"old_tags_json": old_tags_json, "interests_content": interests_content},
        )
        messages = self._build_messages(self.update_tags_system, user_prompt)

        if self.debug:
            self._debug_print_prompt("标签更新", messages)

        try:
            response = self.client.chat(messages)

            if self.debug:
                self._debug_print_response("标签更新", response)

            result = self._parse_update_tags_response(response)
            if result is None:
                return None

            keep_count = len(result.get("keep", []))
            add_count = len(result.get("add", []))
            remove_count = len(result.get("remove", []))
            ratio = result.get("change_ratio", 0)
            print(f"[AI筛选] AI 标签更新方案: 保留 {keep_count}, 新增 {add_count}, 移除 {remove_count}, change_ratio={ratio:.2f}")
            return result
        except Exception as e:
            print(f"[AI筛选] 标签更新失败: {type(e).__name__}: {e}")
            return None

    def _parse_update_tags_response(self, response: str) -> Optional[Dict]:
        """Parse the AI response of a tag update.

        Returns:
            The validated plan, or ``None`` when no JSON object can be found.

        Raises:
            json.JSONDecodeError: The extracted text is not valid JSON.
            ValueError, TypeError: ``change_ratio`` is not convertible to float.
        """
        json_str = self._extract_json(response)
        if not json_str:
            print("[AI筛选] 无法从标签更新响应中提取 JSON")
            return None

        data = json.loads(json_str)
        if not isinstance(data, dict):
            print("[AI筛选] 无法从标签更新响应中提取 JSON")
            return None

        change_ratio = float(data.get("change_ratio", 0))

        raw_remove = data.get("remove", [])
        if not isinstance(raw_remove, list):
            raw_remove = []
        validated_remove = [name for name in (str(r).strip() for r in raw_remove if r) if name]

        return {
            "keep": self._validate_tag_list(data.get("keep", [])),
            "add": self._validate_tag_list(data.get("add", [])),
            "remove": validated_remove,
            # Clamp change_ratio into 0..1.
            "change_ratio": max(0.0, min(1.0, change_ratio)),
        }

    def _parse_tags_response(self, response: str) -> List[Dict]:
        """Parse the AI response of a tag extraction.

        Returns ``[]`` when no JSON is found or the top level is not an object.

        Raises:
            json.JSONDecodeError: The extracted text is not valid JSON.
        """
        json_str = self._extract_json(response)
        if not json_str:
            return []
        data = json.loads(json_str)
        if not isinstance(data, dict):
            return []
        return self._validate_tag_list(data.get("tags", []))

    # ------------------------------------------------------------------
    # Phase B: classification
    # ------------------------------------------------------------------

    def classify_batch(
        self,
        titles: List[Dict],
        tags: List[Dict],
        interests_content: str = "",
    ) -> List[Dict]:
        """Phase B: classify one batch of news titles.

        Args:
            titles: ``[{"id": news_item_id, "title": str, "source": str}]``
            tags: ``[{"id": tag_id, "tag": str, "description": str}]``
            interests_content: The user's interest description (including any
                quality-filter requirements it contains).

        Returns:
            ``[{"news_item_id": int, "tag_id": int, "relevance_score": float}, ...]``;
            an empty list on empty input, missing template or request failure.
        """
        if not titles or not tags:
            return []
        if not self.classify_user:
            print("[AI筛选] 分类提示词模板为空")
            return []

        # Render the tag list and the news list as numbered text.
        tags_list = "\n".join(
            f"{t['id']}. {t['tag']}: {t.get('description', '')}" for t in tags
        )
        news_list = "\n".join(
            f"{t['id']}. [{t.get('source', '')}] {t['title']}" for t in titles
        )

        user_prompt = self._fill_template(
            self.classify_user,
            {
                "interests_content": interests_content,
                "tags_list": tags_list,
                "news_count": str(len(titles)),
                "news_list": news_list,
            },
        )
        messages = self._build_messages(self.classify_system, user_prompt)

        if self.debug:
            print(f"\n[AI筛选][DEBUG] === 分类 Prompt (标题数={len(titles)}, 标签={len(tags)}) ===")
            for m in messages:
                role = m['role']
                content = m['content']
                lines = content.split('\n')
                if len(lines) > 30:
                    # Long prompt: show the first 15 and last 10 lines only.
                    head = lines[:15]
                    tail = lines[-10:]
                    omitted = len(lines) - 25
                    truncated = '\n'.join(head) + f'\n... (省略 {omitted} 行) ...\n' + '\n'.join(tail)
                    print(f"[{role}]\n{truncated}")
                else:
                    print(f"[{role}]\n{content}")
            print(f"[AI筛选][DEBUG] === Prompt 结束 (长度: {sum(len(m['content']) for m in messages)} 字符) ===")

        try:
            response = self.client.chat(messages)
            return self._parse_classify_response(response, titles, tags)
        except Exception as e:
            print(f"[AI筛选] 分类请求失败: {type(e).__name__}: {e}")
            return []

    @staticmethod
    def _is_known_id(value: Any, known_ids: set) -> bool:
        """Membership test that treats unhashable values (lists, dicts) as unknown."""
        try:
            return value in known_ids
        except TypeError:
            return False

    @staticmethod
    def _normalize_score(raw: Any) -> float:
        """Convert a raw score to a float in 0..1; invalid or NaN values become 0.5."""
        try:
            score = float(raw)
        except (ValueError, TypeError, OverflowError):
            return 0.5
        if score != score:
            # NaN would otherwise clamp to 1.0 and win every comparison.
            return 0.5
        return max(0.0, min(1.0, score))

    def _parse_classify_response(
        self,
        response: str,
        titles: List[Dict],
        tags: List[Dict],
    ) -> List[Dict]:
        """Parse the AI response of a classification request.

        Two JSON layouts are accepted:
            - flat:   ``[{"id": 1, "tag_id": 1, "score": 0.9}, ...]``
            - nested: ``[{"id": 1, "tags": [{"tag_id": 1, "score": 0.9}]}, ...]``

        Only the highest-scoring valid tag is kept per news item, so a news item
        never appears under more than one tag. Entries with unknown (or
        unhashable) news ids or tag ids are skipped rather than failing the batch.
        """
        json_str = self._extract_json(response)
        if not json_str:
            if self.debug:
                print(f"[AI筛选][DEBUG] 无法从分类响应中提取 JSON,原始响应前 500 字符: {(response or '')[:500]}")
            return []

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            if self.debug:
                print(f"[AI筛选][DEBUG] 分类响应 JSON 解析失败: {e}")
                print(f"[AI筛选][DEBUG] 提取的 JSON 文本前 500 字符: {json_str[:500]}")
            return []

        if not isinstance(data, list):
            if self.debug:
                print(f"[AI筛选][DEBUG] 分类响应顶层不是数组,实际类型: {type(data).__name__}")
            return []

        # Lookup tables for id validation and debug output.
        title_ids = {t["id"] for t in titles}
        title_map = {t["id"]: t["title"] for t in titles}
        tag_id_set = {t["id"] for t in tags}
        tag_name_map = {t["id"]: t["tag"] for t in tags}

        best_per_news: Dict[int, Dict] = {}  # news_id -> best result so far
        skipped_news_ids = 0
        skipped_tag_ids = 0
        skipped_empty = 0

        for item in data:
            if not isinstance(item, dict):
                continue

            news_id = item.get("id")
            if not self._is_known_id(news_id, title_ids):
                skipped_news_ids += 1
                continue

            # Collect every candidate tag for this news item.
            candidates = []
            if "tag_id" in item:
                # Flat layout.
                candidates.append({"tag_id": item["tag_id"], "score": item.get("score", 0.5)})
            elif "tags" in item:
                # Nested layout.
                matched_tags = item.get("tags", [])
                if isinstance(matched_tags, list):
                    candidates.extend(matched_tags)
            if not candidates:
                skipped_empty += 1
                continue

            # Pick the highest-scoring valid tag.
            best_tag_id = None
            best_score = -1.0
            for tag_match in candidates:
                if not isinstance(tag_match, dict):
                    continue
                tag_id = tag_match.get("tag_id")
                if not self._is_known_id(tag_id, tag_id_set):
                    skipped_tag_ids += 1
                    continue
                score = self._normalize_score(tag_match.get("score", 0.5))
                if score > best_score:
                    best_score = score
                    best_tag_id = tag_id

            if best_tag_id is None:
                continue

            # The same news item may be returned more than once: keep the better score.
            existing = best_per_news.get(news_id)
            if existing is None or best_score > existing["relevance_score"]:
                best_per_news[news_id] = {
                    "news_item_id": news_id,
                    "tag_id": best_tag_id,
                    "relevance_score": best_score,
                }

        results = list(best_per_news.values())

        if self.debug:
            ai_returned = len(data)
            print("[AI筛选][DEBUG] --- 分类解析结果 ---")
            print(f"[AI筛选][DEBUG] AI 返回 {ai_returned} 条, 有效 {len(results)} 条 (每条新闻仅保留最高分 tag)")
            if skipped_empty > 0:
                print(f"[AI筛选][DEBUG] 跳过空 tags: {skipped_empty} 条")
            if skipped_news_ids > 0:
                print(f"[AI筛选][DEBUG] !! 跳过无效 news_id: {skipped_news_ids} 条")
            if skipped_tag_ids > 0:
                print(f"[AI筛选][DEBUG] !! 跳过无效 tag_id: {skipped_tag_ids} 条")

            # Summary grouped by tag.
            tag_summary: Dict[int, List[str]] = {}
            for r in results:
                tag_summary.setdefault(r["tag_id"], []).append(
                    f" [{r['news_item_id']}] {title_map.get(r['news_item_id'], '?')[:40]} (score={r['relevance_score']:.2f})"
                )
            for tid, items in tag_summary.items():
                tname = tag_name_map.get(tid, f"tag_{tid}")
                print(f"[AI筛选][DEBUG] 标签「{tname}」匹配 {len(items)} 条:")
                for line in items:
                    print(line)

        return results

    # ------------------------------------------------------------------
    # JSON helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _trim_to_json_span(text: str) -> str:
        """Cut *text* down to the span from its first ``{``/``[`` to its last ``}``/``]``.

        Text that is already a bare JSON object or array is returned unchanged;
        text without a usable bracket pair is returned as is.
        """
        starts = [i for i in (text.find("{"), text.find("[")) if i != -1]
        ends = [i for i in (text.rfind("}"), text.rfind("]")) if i != -1]
        if not starts or not ends:
            return text
        start, end = min(starts), max(ends)
        if end <= start:
            return text
        return text[start:end + 1]

    def _extract_json(self, response: str) -> Optional[str]:
        """Extract the JSON text from an AI response.

        Preference order: a code fence tagged ``json`` (any letter case), then the
        first plain code fence, then the whole response. Prose surrounding the
        JSON object or array is trimmed away. Returns ``None`` for empty input.
        """
        if not response or not response.strip():
            return None

        text = response.strip()
        fence = "```"

        # Look for a fence tagged "json" without lower-casing the whole text
        # (lower-casing can change string length for some characters).
        tagged_at = text.find(fence)
        while tagged_at != -1 and text[tagged_at + 3:tagged_at + 7].lower() != "json":
            tagged_at = text.find(fence, tagged_at + 3)

        if tagged_at != -1:
            code_block = text[tagged_at + 7:]
            end_idx = code_block.find(fence)
            text = code_block[:end_idx] if end_idx != -1 else code_block
        elif fence in text:
            text = text.split(fence, 2)[1]

        text = self._trim_to_json_span(text.strip()).strip()
        return text if text else None

    def _print_formatted_json(self, response: str) -> None:
        """Pretty-print the JSON contained in an AI response for debug reading.

        Arrays are printed one element per line, other values with indentation;
        when parsing fails the raw response is printed instead.
        """
        if not response:
            print("(空响应)")
            return

        json_str = self._extract_json(response)
        if json_str:
            try:
                data = json.loads(json_str)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(data, list):
                    # Array: one compact element per line.
                    lines = [json.dumps(item, ensure_ascii=False) for item in data]
                    print("[\n " + ",\n ".join(lines) + "\n]")
                else:
                    print(json.dumps(data, ensure_ascii=False, indent=2))
                return

        # Not parseable as JSON: print the raw response.
        print(response)
|