data_query.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. """
  2. 数据查询工具
  3. 实现P0核心的数据查询工具。
  4. """
  5. from typing import Dict, List, Optional, Union
  6. from ..services.data_service import DataService
  7. from ..utils.validators import (
  8. validate_platforms,
  9. validate_limit,
  10. validate_keyword,
  11. validate_date_range,
  12. validate_top_n,
  13. validate_mode,
  14. validate_date_query,
  15. normalize_date_range
  16. )
  17. from ..utils.errors import MCPError
  18. class DataQueryTools:
  19. """数据查询工具类"""
  20. def __init__(self, project_root: str = None):
  21. """
  22. 初始化数据查询工具
  23. Args:
  24. project_root: 项目根目录
  25. """
  26. self.data_service = DataService(project_root)
  27. def get_latest_news(
  28. self,
  29. platforms: Optional[List[str]] = None,
  30. limit: Optional[int] = None,
  31. include_url: bool = False
  32. ) -> Dict:
  33. """
  34. 获取最新一批爬取的新闻数据
  35. Args:
  36. platforms: 平台ID列表,如 ['zhihu', 'weibo']
  37. limit: 返回条数限制,默认20
  38. include_url: 是否包含URL链接,默认False(节省token)
  39. Returns:
  40. 新闻列表字典
  41. Example:
  42. >>> tools = DataQueryTools()
  43. >>> result = tools.get_latest_news(platforms=['zhihu'], limit=10)
  44. >>> print(result['total'])
  45. 10
  46. """
  47. try:
  48. # 参数验证
  49. platforms = validate_platforms(platforms)
  50. limit = validate_limit(limit, default=50)
  51. # 获取数据
  52. news_list = self.data_service.get_latest_news(
  53. platforms=platforms,
  54. limit=limit,
  55. include_url=include_url
  56. )
  57. return {
  58. "news": news_list,
  59. "total": len(news_list),
  60. "platforms": platforms,
  61. "success": True
  62. }
  63. except MCPError as e:
  64. return {
  65. "success": False,
  66. "error": e.to_dict()
  67. }
  68. except Exception as e:
  69. return {
  70. "success": False,
  71. "error": {
  72. "code": "INTERNAL_ERROR",
  73. "message": str(e)
  74. }
  75. }
  76. def search_news_by_keyword(
  77. self,
  78. keyword: str,
  79. date_range: Optional[Union[Dict, str]] = None,
  80. platforms: Optional[List[str]] = None,
  81. limit: Optional[int] = None
  82. ) -> Dict:
  83. """
  84. 按关键词搜索历史新闻
  85. Args:
  86. keyword: 搜索关键词(必需)
  87. date_range: 日期范围,格式: {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
  88. platforms: 平台过滤列表
  89. limit: 返回条数限制(可选,默认返回所有)
  90. Returns:
  91. 搜索结果字典
  92. Example (假设今天是 2025-11-17):
  93. >>> tools = DataQueryTools()
  94. >>> result = tools.search_news_by_keyword(
  95. ... keyword="人工智能",
  96. ... date_range={"start": "2025-11-08", "end": "2025-11-17"},
  97. ... limit=50
  98. ... )
  99. >>> print(result['total'])
  100. """
  101. try:
  102. # 参数验证
  103. keyword = validate_keyword(keyword)
  104. date_range_tuple = validate_date_range(date_range)
  105. platforms = validate_platforms(platforms)
  106. if limit is not None:
  107. limit = validate_limit(limit, default=100)
  108. # 搜索数据
  109. search_result = self.data_service.search_news_by_keyword(
  110. keyword=keyword,
  111. date_range=date_range_tuple,
  112. platforms=platforms,
  113. limit=limit
  114. )
  115. return {
  116. **search_result,
  117. "success": True
  118. }
  119. except MCPError as e:
  120. return {
  121. "success": False,
  122. "error": e.to_dict()
  123. }
  124. except Exception as e:
  125. return {
  126. "success": False,
  127. "error": {
  128. "code": "INTERNAL_ERROR",
  129. "message": str(e)
  130. }
  131. }
  132. def get_trending_topics(
  133. self,
  134. top_n: Optional[int] = None,
  135. mode: Optional[str] = None,
  136. extract_mode: Optional[str] = None
  137. ) -> Dict:
  138. """
  139. 获取热点话题统计
  140. Args:
  141. top_n: 返回TOP N话题,默认10
  142. mode: 时间模式
  143. - "daily": 当日累计数据统计
  144. - "current": 最新一批数据统计(默认)
  145. extract_mode: 提取模式
  146. - "keywords": 统计预设关注词(基于 config/frequency_words.txt,默认)
  147. - "auto_extract": 自动从新闻标题提取高频词
  148. Returns:
  149. 话题频率统计字典
  150. Example:
  151. >>> tools = DataQueryTools()
  152. >>> # 使用预设关注词
  153. >>> result = tools.get_trending_topics(top_n=5, mode="current")
  154. >>> # 自动提取高频词
  155. >>> result = tools.get_trending_topics(top_n=10, extract_mode="auto_extract")
  156. """
  157. try:
  158. # 参数验证
  159. top_n = validate_top_n(top_n, default=10)
  160. valid_modes = ["daily", "current"]
  161. mode = validate_mode(mode, valid_modes, default="current")
  162. # 验证 extract_mode
  163. if extract_mode is None:
  164. extract_mode = "keywords"
  165. elif extract_mode not in ["keywords", "auto_extract"]:
  166. return {
  167. "success": False,
  168. "error": {
  169. "code": "INVALID_PARAMETER",
  170. "message": f"不支持的提取模式: {extract_mode}",
  171. "suggestion": "支持的模式: keywords, auto_extract"
  172. }
  173. }
  174. # 获取趋势话题
  175. trending_result = self.data_service.get_trending_topics(
  176. top_n=top_n,
  177. mode=mode,
  178. extract_mode=extract_mode
  179. )
  180. return {
  181. **trending_result,
  182. "success": True
  183. }
  184. except MCPError as e:
  185. return {
  186. "success": False,
  187. "error": e.to_dict()
  188. }
  189. except Exception as e:
  190. return {
  191. "success": False,
  192. "error": {
  193. "code": "INTERNAL_ERROR",
  194. "message": str(e)
  195. }
  196. }
  197. def get_news_by_date(
  198. self,
  199. date_range: Optional[Union[Dict[str, str], str]] = None,
  200. platforms: Optional[List[str]] = None,
  201. limit: Optional[int] = None,
  202. include_url: bool = False
  203. ) -> Dict:
  204. """
  205. 按日期查询新闻,支持自然语言日期
  206. Args:
  207. date_range: 日期范围(可选,默认"今天"),支持:
  208. - 范围对象:{"start": "2025-01-01", "end": "2025-01-07"}
  209. - 相对日期:今天、昨天、前天、3天前
  210. - 单日字符串:2025-10-10
  211. platforms: 平台ID列表,如 ['zhihu', 'weibo']
  212. limit: 返回条数限制,默认50
  213. include_url: 是否包含URL链接,默认False(节省token)
  214. Returns:
  215. 新闻列表字典
  216. Example:
  217. >>> tools = DataQueryTools()
  218. >>> # 不指定日期,默认查询今天
  219. >>> result = tools.get_news_by_date(platforms=['zhihu'], limit=20)
  220. >>> # 指定日期
  221. >>> result = tools.get_news_by_date(
  222. ... date_range="昨天",
  223. ... platforms=['zhihu'],
  224. ... limit=20
  225. ... )
  226. >>> print(result['total'])
  227. 20
  228. """
  229. try:
  230. # 参数验证 - 默认今天
  231. if date_range is None:
  232. date_range = "今天"
  233. # 规范化 date_range(处理 JSON 字符串序列化问题)
  234. date_range = normalize_date_range(date_range)
  235. # 处理 date_range:支持字符串或对象
  236. if isinstance(date_range, dict):
  237. # 范围对象,取 start 日期
  238. date_str = date_range.get('start', '今天')
  239. else:
  240. date_str = date_range
  241. target_date = validate_date_query(date_str)
  242. platforms = validate_platforms(platforms)
  243. limit = validate_limit(limit, default=50)
  244. # 获取数据
  245. news_list = self.data_service.get_news_by_date(
  246. target_date=target_date,
  247. platforms=platforms,
  248. limit=limit,
  249. include_url=include_url
  250. )
  251. return {
  252. "news": news_list,
  253. "total": len(news_list),
  254. "date": target_date.strftime("%Y-%m-%d"),
  255. "date_range": date_range,
  256. "platforms": platforms,
  257. "success": True
  258. }
  259. except MCPError as e:
  260. return {
  261. "success": False,
  262. "error": e.to_dict()
  263. }
  264. except Exception as e:
  265. return {
  266. "success": False,
  267. "error": {
  268. "code": "INTERNAL_ERROR",
  269. "message": str(e)
  270. }
  271. }
  272. # ========================================
  273. # RSS 数据查询方法
  274. # ========================================
  275. def get_latest_rss(
  276. self,
  277. feeds: Optional[List[str]] = None,
  278. limit: Optional[int] = None,
  279. include_summary: bool = False
  280. ) -> Dict:
  281. """
  282. 获取最新的 RSS 数据
  283. Args:
  284. feeds: RSS 源 ID 列表,如 ['hacker-news', '36kr']
  285. limit: 返回条数限制,默认50
  286. include_summary: 是否包含摘要,默认False(节省token)
  287. Returns:
  288. RSS 条目列表字典
  289. """
  290. try:
  291. limit = validate_limit(limit, default=50)
  292. rss_list = self.data_service.get_latest_rss(
  293. feeds=feeds,
  294. limit=limit,
  295. include_summary=include_summary
  296. )
  297. return {
  298. "rss": rss_list,
  299. "total": len(rss_list),
  300. "feeds": feeds,
  301. "success": True
  302. }
  303. except MCPError as e:
  304. return {
  305. "success": False,
  306. "error": e.to_dict()
  307. }
  308. except Exception as e:
  309. return {
  310. "success": False,
  311. "error": {
  312. "code": "INTERNAL_ERROR",
  313. "message": str(e)
  314. }
  315. }
  316. def search_rss(
  317. self,
  318. keyword: str,
  319. feeds: Optional[List[str]] = None,
  320. days: int = 7,
  321. limit: Optional[int] = None,
  322. include_summary: bool = False
  323. ) -> Dict:
  324. """
  325. 搜索 RSS 数据
  326. Args:
  327. keyword: 搜索关键词
  328. feeds: RSS 源 ID 列表
  329. days: 搜索最近 N 天的数据,默认 7 天
  330. limit: 返回条数限制,默认50
  331. include_summary: 是否包含摘要
  332. Returns:
  333. 匹配的 RSS 条目列表
  334. """
  335. try:
  336. keyword = validate_keyword(keyword)
  337. limit = validate_limit(limit, default=50)
  338. if days < 1 or days > 30:
  339. days = 7
  340. rss_list = self.data_service.search_rss(
  341. keyword=keyword,
  342. feeds=feeds,
  343. days=days,
  344. limit=limit,
  345. include_summary=include_summary
  346. )
  347. return {
  348. "rss": rss_list,
  349. "total": len(rss_list),
  350. "keyword": keyword,
  351. "feeds": feeds,
  352. "days": days,
  353. "success": True
  354. }
  355. except MCPError as e:
  356. return {
  357. "success": False,
  358. "error": e.to_dict()
  359. }
  360. except Exception as e:
  361. return {
  362. "success": False,
  363. "error": {
  364. "code": "INTERNAL_ERROR",
  365. "message": str(e)
  366. }
  367. }
  368. def get_rss_feeds_status(self) -> Dict:
  369. """
  370. 获取 RSS 源状态
  371. Returns:
  372. RSS 源状态信息
  373. """
  374. try:
  375. status = self.data_service.get_rss_feeds_status()
  376. return {
  377. **status,
  378. "success": True
  379. }
  380. except MCPError as e:
  381. return {
  382. "success": False,
  383. "error": e.to_dict()
  384. }
  385. except Exception as e:
  386. return {
  387. "success": False,
  388. "error": {
  389. "code": "INTERNAL_ERROR",
  390. "message": str(e)
  391. }
  392. }