validators.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670
  1. """
  2. 参数验证工具
  3. 提供统一的参数验证功能。
  4. 支持 MCP 客户端将参数序列化为字符串的情况。
  5. """
  6. from datetime import datetime
  7. from typing import List, Optional, Union
  8. import os
  9. import json
  10. import yaml
  11. import ast
  12. from .errors import InvalidParameterError
  13. from .date_parser import DateParser
  14. # ==================== 辅助函数:处理字符串序列化 ====================
  15. def _parse_string_to_list(value: str) -> List[str]:
  16. """
  17. 将字符串解析为列表
  18. 支持格式:
  19. - JSON 数组: '["zhihu", "weibo"]'
  20. - Python 列表字符串: "['zhihu', 'weibo']"
  21. - 逗号分隔: "zhihu, weibo" 或 "zhihu,weibo"
  22. Args:
  23. value: 字符串值
  24. Returns:
  25. 解析后的列表
  26. Raises:
  27. InvalidParameterError: 解析失败
  28. """
  29. value = value.strip()
  30. if not value:
  31. return []
  32. # 尝试 JSON 解析: '["zhihu", "weibo"]'
  33. try:
  34. parsed = json.loads(value)
  35. if isinstance(parsed, list):
  36. return [str(item) for item in parsed]
  37. # 如果解析结果不是列表,继续尝试其他方式
  38. except json.JSONDecodeError:
  39. pass
  40. # 尝试 Python 字面量解析: "['zhihu', 'weibo']"
  41. try:
  42. parsed = ast.literal_eval(value)
  43. if isinstance(parsed, list):
  44. return [str(item) for item in parsed]
  45. if isinstance(parsed, str):
  46. # 单个字符串,包装成列表
  47. return [parsed]
  48. except (ValueError, SyntaxError):
  49. pass
  50. # 尝试逗号分隔: "zhihu, weibo" 或 "zhihu,weibo"
  51. if ',' in value:
  52. items = [item.strip() for item in value.split(',')]
  53. return [item for item in items if item]
  54. # 单个值
  55. return [value]
  56. def _parse_string_to_int(value: str, param_name: str = "参数") -> int:
  57. """
  58. 将字符串解析为整数
  59. Args:
  60. value: 字符串值
  61. param_name: 参数名(用于错误消息)
  62. Returns:
  63. 解析后的整数
  64. Raises:
  65. InvalidParameterError: 解析失败
  66. """
  67. value = value.strip()
  68. try:
  69. # 尝试直接转换
  70. return int(value)
  71. except ValueError:
  72. pass
  73. # 尝试解析浮点数后取整
  74. try:
  75. return int(float(value))
  76. except ValueError:
  77. raise InvalidParameterError(
  78. f"{param_name} 必须是整数,无法解析: {value}",
  79. suggestion=f"请提供有效的整数值,如: 10, 50, 100"
  80. )
  81. def _parse_string_to_float(value: str, param_name: str = "参数") -> float:
  82. """
  83. 将字符串解析为浮点数
  84. Args:
  85. value: 字符串值
  86. param_name: 参数名(用于错误消息)
  87. Returns:
  88. 解析后的浮点数
  89. Raises:
  90. InvalidParameterError: 解析失败
  91. """
  92. value = value.strip()
  93. try:
  94. return float(value)
  95. except ValueError:
  96. raise InvalidParameterError(
  97. f"{param_name} 必须是数字,无法解析: {value}",
  98. suggestion=f"请提供有效的数字值,如: 0.6, 3.0"
  99. )
  100. def _parse_string_to_bool(value: str) -> bool:
  101. """
  102. 将字符串解析为布尔值
  103. Args:
  104. value: 字符串值
  105. Returns:
  106. 解析后的布尔值
  107. """
  108. value = value.strip().lower()
  109. if value in ('true', '1', 'yes', 'on'):
  110. return True
  111. elif value in ('false', '0', 'no', 'off', ''):
  112. return False
  113. else:
  114. # 默认非空字符串为 True
  115. return bool(value)
  116. # 平台列表 mtime 缓存(避免每次 MCP 调用都重新读取 config.yaml)
  117. _platforms_cache: Optional[List[str]] = None
  118. _platforms_config_mtime: float = 0.0
  119. _platforms_config_path: Optional[str] = None
  120. def get_supported_platforms() -> List[str]:
  121. """
  122. 从 config.yaml 动态获取支持的平台列表(带 mtime 缓存)
  123. 仅当 config.yaml 被修改时才重新读取,避免每次 MCP 调用的重复 IO。
  124. Returns:
  125. 平台ID列表
  126. Note:
  127. - 读取失败时返回空列表,允许所有平台通过(降级策略)
  128. - 平台列表来自 config/config.yaml 中的 platforms 配置
  129. """
  130. global _platforms_cache, _platforms_config_mtime, _platforms_config_path
  131. try:
  132. if _platforms_config_path is None:
  133. current_dir = os.path.dirname(os.path.abspath(__file__))
  134. _platforms_config_path = os.path.normpath(
  135. os.path.join(current_dir, "..", "..", "config", "config.yaml")
  136. )
  137. current_mtime = os.path.getmtime(_platforms_config_path)
  138. if _platforms_cache is not None and current_mtime == _platforms_config_mtime:
  139. return _platforms_cache
  140. with open(_platforms_config_path, 'r', encoding='utf-8') as f:
  141. config = yaml.safe_load(f)
  142. platforms_config = config.get('platforms', {})
  143. sources = platforms_config.get('sources', [])
  144. _platforms_cache = [p['id'] for p in sources if 'id' in p]
  145. _platforms_config_mtime = current_mtime
  146. return _platforms_cache
  147. except Exception as e:
  148. print(f"警告:无法加载平台配置: {e}")
  149. return []
  150. def validate_platforms(platforms: Optional[Union[List[str], str]]) -> List[str]:
  151. """
  152. 验证平台列表
  153. Args:
  154. platforms: 平台ID列表或字符串,None表示使用 config.yaml 中配置的所有平台
  155. 支持多种格式:
  156. - None: 使用默认平台
  157. - ["zhihu", "weibo"]: JSON 数组
  158. - '["zhihu", "weibo"]': JSON 数组字符串
  159. - "['zhihu', 'weibo']": Python 列表字符串
  160. - "zhihu, weibo": 逗号分隔字符串
  161. - "zhihu": 单个平台字符串
  162. Returns:
  163. 验证后的平台列表
  164. Raises:
  165. InvalidParameterError: 平台不支持
  166. Note:
  167. - platforms=None 时,返回 config.yaml 中配置的平台列表
  168. - 会验证平台ID是否在 config.yaml 的 platforms 配置中
  169. - 配置加载失败时,允许所有平台通过(降级策略)
  170. """
  171. supported_platforms = get_supported_platforms()
  172. if platforms is None:
  173. # 返回配置文件中的平台列表(用户的默认配置)
  174. return supported_platforms if supported_platforms else []
  175. # 支持字符串形式的列表输入(某些 MCP 客户端会将 JSON 数组序列化为字符串)
  176. if isinstance(platforms, str):
  177. platforms = _parse_string_to_list(platforms)
  178. if not platforms:
  179. # 空字符串或解析后为空,使用默认平台
  180. return supported_platforms if supported_platforms else []
  181. if not isinstance(platforms, list):
  182. raise InvalidParameterError("platforms 参数必须是列表类型")
  183. if not platforms:
  184. # 空列表时,返回配置文件中的平台列表
  185. return supported_platforms if supported_platforms else []
  186. # 如果配置加载失败(supported_platforms为空),允许所有平台通过
  187. if not supported_platforms:
  188. print("警告:平台配置未加载,跳过平台验证")
  189. return platforms
  190. # 验证每个平台是否在配置中
  191. invalid_platforms = [p for p in platforms if p not in supported_platforms]
  192. if invalid_platforms:
  193. raise InvalidParameterError(
  194. f"不支持的平台: {', '.join(invalid_platforms)}",
  195. suggestion=f"支持的平台(来自config.yaml): {', '.join(supported_platforms)}"
  196. )
  197. return platforms
  198. def validate_limit(limit: Optional[Union[int, str]], default: int = 20, max_limit: int = 1000) -> int:
  199. """
  200. 验证数量限制参数
  201. Args:
  202. limit: 限制数量(整数或字符串)
  203. default: 默认值
  204. max_limit: 最大限制
  205. Returns:
  206. 验证后的限制值
  207. Raises:
  208. InvalidParameterError: 参数无效
  209. """
  210. if limit is None:
  211. return default
  212. # 支持字符串形式的整数(某些 MCP 客户端会将数字序列化为字符串)
  213. if isinstance(limit, str):
  214. limit = _parse_string_to_int(limit, "limit")
  215. if not isinstance(limit, int):
  216. raise InvalidParameterError("limit 参数必须是整数类型")
  217. if limit <= 0:
  218. raise InvalidParameterError("limit 必须大于0")
  219. if limit > max_limit:
  220. raise InvalidParameterError(
  221. f"limit 不能超过 {max_limit}",
  222. suggestion=f"请使用分页或降低limit值"
  223. )
  224. return limit
  225. def validate_date(date_str: str) -> datetime:
  226. """
  227. 验证日期格式
  228. Args:
  229. date_str: 日期字符串 (YYYY-MM-DD)
  230. Returns:
  231. datetime对象
  232. Raises:
  233. InvalidParameterError: 日期格式错误
  234. """
  235. try:
  236. return datetime.strptime(date_str, "%Y-%m-%d")
  237. except ValueError:
  238. raise InvalidParameterError(
  239. f"日期格式错误: {date_str}",
  240. suggestion="请使用 YYYY-MM-DD 格式,例如: 2025-10-11"
  241. )
  242. def normalize_date_range(date_range: Optional[Union[dict, str]]) -> Optional[Union[dict, str]]:
  243. """
  244. 规范化 date_range 参数
  245. 某些 MCP 客户端(特别是 HTTP 方式)会将 JSON 对象序列化为字符串传入。
  246. 此函数尝试将 JSON 字符串解析为 dict,如果不是 JSON 格式则保持原样。
  247. Args:
  248. date_range: 日期范围,可能是:
  249. - dict: {"start": "2025-01-01", "end": "2025-01-07"}
  250. - JSON 字符串: '{"start": "2025-01-01", "end": "2025-01-07"}'
  251. - 普通字符串: "今天", "昨天", "2025-01-01"
  252. - None
  253. Returns:
  254. 规范化后的 date_range(dict 或普通字符串)
  255. Examples:
  256. >>> normalize_date_range('{"start":"2025-01-01","end":"2025-01-07"}')
  257. {"start": "2025-01-01", "end": "2025-01-07"}
  258. >>> normalize_date_range("今天")
  259. "今天"
  260. >>> normalize_date_range({"start": "2025-01-01", "end": "2025-01-07"})
  261. {"start": "2025-01-01", "end": "2025-01-07"}
  262. """
  263. if date_range is None:
  264. return None
  265. # 如果已经是 dict,直接返回
  266. if isinstance(date_range, dict):
  267. return date_range
  268. # 如果是字符串,尝试解析为 JSON
  269. if isinstance(date_range, str):
  270. # 检查是否看起来像 JSON 对象
  271. stripped = date_range.strip()
  272. if stripped.startswith('{') and stripped.endswith('}'):
  273. try:
  274. parsed = json.loads(stripped)
  275. if isinstance(parsed, dict):
  276. return parsed
  277. except json.JSONDecodeError:
  278. pass # 解析失败,当作普通字符串处理
  279. return date_range
  280. def validate_date_range(date_range: Optional[Union[dict, str]]) -> Optional[tuple]:
  281. """
  282. 验证日期范围
  283. Args:
  284. date_range: 日期范围,支持多种格式:
  285. - dict: {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}
  286. - JSON 字符串: '{"start": "2025-01-01", "end": "2025-01-07"}'
  287. - 单日字符串: "2025-01-01"(自动转为同一天的范围)
  288. - 自然语言: "今天", "昨天", "本周", "最近7天" 等
  289. Returns:
  290. (start_date, end_date) 元组,或 None
  291. Raises:
  292. InvalidParameterError: 日期范围无效
  293. """
  294. if date_range is None:
  295. return None
  296. # 支持字符串形式的输入
  297. if isinstance(date_range, str):
  298. stripped = date_range.strip()
  299. # 1. 检查是否是 JSON 对象格式
  300. if stripped.startswith('{') and stripped.endswith('}'):
  301. try:
  302. date_range = json.loads(stripped)
  303. except json.JSONDecodeError as e:
  304. raise InvalidParameterError(
  305. f"date_range JSON 解析失败: {e}",
  306. suggestion='请使用正确的JSON格式: {"start": "YYYY-MM-DD", "end": "YYYY-MM-DD"}'
  307. )
  308. # 2. 检查是否是单日字符串格式 YYYY-MM-DD
  309. elif len(stripped) == 10 and stripped[4] == '-' and stripped[7] == '-':
  310. try:
  311. single_date = datetime.strptime(stripped, "%Y-%m-%d")
  312. return (single_date, single_date)
  313. except ValueError:
  314. raise InvalidParameterError(
  315. f"日期格式错误: {stripped}",
  316. suggestion="请使用 YYYY-MM-DD 格式,例如: 2025-10-11"
  317. )
  318. # 3. 尝试自然语言解析
  319. else:
  320. try:
  321. result = DateParser.resolve_date_range_expression(stripped)
  322. if result.get("success"):
  323. dr = result["date_range"]
  324. start_date = datetime.strptime(dr["start"], "%Y-%m-%d")
  325. end_date = datetime.strptime(dr["end"], "%Y-%m-%d")
  326. return (start_date, end_date)
  327. else:
  328. raise InvalidParameterError(
  329. f"无法识别的日期表达式: {stripped}",
  330. suggestion="支持格式: YYYY-MM-DD, {\"start\": \"...\", \"end\": \"...\"}, 或自然语言(今天、本周、最近7天等)"
  331. )
  332. except InvalidParameterError:
  333. raise
  334. except Exception:
  335. raise InvalidParameterError(
  336. f"日期解析失败: {stripped}",
  337. suggestion="支持格式: YYYY-MM-DD, {\"start\": \"...\", \"end\": \"...\"}, 或自然语言(今天、本周、最近7天等)"
  338. )
  339. if not isinstance(date_range, dict):
  340. raise InvalidParameterError(
  341. "date_range 必须是字典类型、日期字符串或有效的JSON字符串",
  342. suggestion='例如: {"start": "2025-10-01", "end": "2025-10-11"} 或 "2025-10-01"'
  343. )
  344. start_str = date_range.get("start")
  345. end_str = date_range.get("end")
  346. if not start_str or not end_str:
  347. raise InvalidParameterError(
  348. "date_range 必须包含 start 和 end 字段",
  349. suggestion='例如: {"start": "2025-10-01", "end": "2025-10-11"}'
  350. )
  351. start_date = validate_date(start_str)
  352. end_date = validate_date(end_str)
  353. if start_date > end_date:
  354. raise InvalidParameterError(
  355. "开始日期不能晚于结束日期",
  356. suggestion=f"start: {start_str}, end: {end_str}"
  357. )
  358. # 检查日期是否在未来
  359. today = datetime.now().date()
  360. if start_date.date() > today or end_date.date() > today:
  361. # 获取可用日期范围提示
  362. try:
  363. from ..services.data_service import DataService
  364. data_service = DataService()
  365. earliest, latest = data_service.get_available_date_range()
  366. if earliest and latest:
  367. available_range = f"{earliest.strftime('%Y-%m-%d')} 至 {latest.strftime('%Y-%m-%d')}"
  368. else:
  369. available_range = "无可用数据"
  370. except Exception:
  371. available_range = "未知(请检查 output 目录)"
  372. future_dates = []
  373. if start_date.date() > today:
  374. future_dates.append(start_str)
  375. if end_date.date() > today and end_str != start_str:
  376. future_dates.append(end_str)
  377. raise InvalidParameterError(
  378. f"不允许查询未来日期: {', '.join(future_dates)}(当前日期: {today.strftime('%Y-%m-%d')})",
  379. suggestion=f"当前可用数据范围: {available_range}"
  380. )
  381. return (start_date, end_date)
  382. def validate_keyword(keyword: str) -> str:
  383. """
  384. 验证关键词
  385. Args:
  386. keyword: 搜索关键词
  387. Returns:
  388. 处理后的关键词
  389. Raises:
  390. InvalidParameterError: 关键词无效
  391. """
  392. if not keyword:
  393. raise InvalidParameterError("keyword 不能为空")
  394. if not isinstance(keyword, str):
  395. raise InvalidParameterError("keyword 必须是字符串类型")
  396. keyword = keyword.strip()
  397. if not keyword:
  398. raise InvalidParameterError("keyword 不能为空白字符")
  399. if len(keyword) > 100:
  400. raise InvalidParameterError(
  401. "keyword 长度不能超过100个字符",
  402. suggestion="请使用更简洁的关键词"
  403. )
  404. return keyword
  405. def validate_top_n(top_n: Optional[Union[int, str]], default: int = 10) -> int:
  406. """
  407. 验证TOP N参数
  408. Args:
  409. top_n: TOP N数量(整数或字符串)
  410. default: 默认值
  411. Returns:
  412. 验证后的值
  413. Raises:
  414. InvalidParameterError: 参数无效
  415. """
  416. return validate_limit(top_n, default=default, max_limit=100)
  417. def validate_mode(mode: Optional[str], valid_modes: List[str], default: str) -> str:
  418. """
  419. 验证模式参数
  420. Args:
  421. mode: 模式字符串
  422. valid_modes: 有效模式列表
  423. default: 默认模式
  424. Returns:
  425. 验证后的模式
  426. Raises:
  427. InvalidParameterError: 模式无效
  428. """
  429. if mode is None:
  430. return default
  431. if not isinstance(mode, str):
  432. raise InvalidParameterError("mode 必须是字符串类型")
  433. if mode not in valid_modes:
  434. raise InvalidParameterError(
  435. f"无效的模式: {mode}",
  436. suggestion=f"支持的模式: {', '.join(valid_modes)}"
  437. )
  438. return mode
  439. def validate_config_section(section: Optional[str]) -> str:
  440. """
  441. 验证配置节参数
  442. Args:
  443. section: 配置节名称
  444. Returns:
  445. 验证后的配置节
  446. Raises:
  447. InvalidParameterError: 配置节无效
  448. """
  449. valid_sections = ["all", "crawler", "push", "keywords", "weights"]
  450. return validate_mode(section, valid_sections, "all")
  451. def validate_threshold(
  452. threshold: Optional[Union[float, int, str]],
  453. default: float = 0.6,
  454. min_value: float = 0.0,
  455. max_value: float = 1.0,
  456. param_name: str = "threshold"
  457. ) -> float:
  458. """
  459. 验证阈值参数(浮点数)
  460. Args:
  461. threshold: 阈值(浮点数、整数或字符串)
  462. default: 默认值
  463. min_value: 最小值
  464. max_value: 最大值
  465. param_name: 参数名(用于错误消息)
  466. Returns:
  467. 验证后的阈值
  468. Raises:
  469. InvalidParameterError: 参数无效
  470. """
  471. if threshold is None:
  472. return default
  473. # 支持字符串形式的数字(某些 MCP 客户端会将数字序列化为字符串)
  474. if isinstance(threshold, str):
  475. threshold = _parse_string_to_float(threshold, param_name)
  476. # 整数转浮点数
  477. if isinstance(threshold, int):
  478. threshold = float(threshold)
  479. if not isinstance(threshold, float):
  480. raise InvalidParameterError(
  481. f"{param_name} 必须是数字类型",
  482. suggestion=f"请提供 {min_value} 到 {max_value} 之间的数字"
  483. )
  484. if threshold < min_value or threshold > max_value:
  485. raise InvalidParameterError(
  486. f"{param_name} 必须在 {min_value} 到 {max_value} 之间,当前值: {threshold}",
  487. suggestion=f"推荐值: {default}"
  488. )
  489. return threshold
  490. def validate_date_query(
  491. date_query: str,
  492. allow_future: bool = False,
  493. max_days_ago: int = 365
  494. ) -> datetime:
  495. """
  496. 验证并解析日期查询字符串
  497. Args:
  498. date_query: 日期查询字符串
  499. allow_future: 是否允许未来日期
  500. max_days_ago: 允许查询的最大天数
  501. Returns:
  502. 解析后的datetime对象
  503. Raises:
  504. InvalidParameterError: 日期查询无效
  505. Examples:
  506. >>> validate_date_query("昨天")
  507. datetime(2025, 10, 10)
  508. >>> validate_date_query("2025-10-10")
  509. datetime(2025, 10, 10)
  510. """
  511. if not date_query:
  512. raise InvalidParameterError(
  513. "日期查询字符串不能为空",
  514. suggestion="请提供日期查询,如:今天、昨天、2025-10-10"
  515. )
  516. # 使用DateParser解析日期
  517. parsed_date = DateParser.parse_date_query(date_query)
  518. # 验证日期不在未来
  519. if not allow_future:
  520. DateParser.validate_date_not_future(parsed_date)
  521. # 验证日期不太久远
  522. DateParser.validate_date_not_too_old(parsed_date, max_days=max_days_ago)
  523. return parsed_date