storage_sync.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. # coding=utf-8
  2. """
  3. 存储同步工具
  4. 实现从远程存储拉取数据到本地、获取存储状态、列出可用日期等功能。
  5. """
  6. import os
  7. import re
  8. from pathlib import Path
  9. from datetime import datetime, timedelta
  10. from typing import Dict, List, Optional
  11. import yaml
  12. from ..utils.errors import MCPError
  13. class StorageSyncTools:
  14. """存储同步工具类"""
  15. def __init__(self, project_root: str = None):
  16. """
  17. 初始化存储同步工具
  18. Args:
  19. project_root: 项目根目录
  20. """
  21. if project_root:
  22. self.project_root = Path(project_root)
  23. else:
  24. current_file = Path(__file__)
  25. self.project_root = current_file.parent.parent.parent
  26. self._config = None
  27. self._remote_backend = None
  28. def _load_config(self) -> dict:
  29. """加载配置文件"""
  30. if self._config is None:
  31. config_path = self.project_root / "config" / "config.yaml"
  32. if config_path.exists():
  33. with open(config_path, "r", encoding="utf-8") as f:
  34. self._config = yaml.safe_load(f)
  35. else:
  36. self._config = {}
  37. return self._config
  38. def _get_storage_config(self) -> dict:
  39. """获取存储配置"""
  40. config = self._load_config()
  41. return config.get("storage", {})
  42. def _get_remote_config(self) -> dict:
  43. """
  44. 获取远程存储配置(合并配置文件和环境变量)
  45. """
  46. storage_config = self._get_storage_config()
  47. remote_config = storage_config.get("remote", {})
  48. return {
  49. "endpoint_url": remote_config.get("endpoint_url") or os.environ.get("S3_ENDPOINT_URL", ""),
  50. "bucket_name": remote_config.get("bucket_name") or os.environ.get("S3_BUCKET_NAME", ""),
  51. "access_key_id": remote_config.get("access_key_id") or os.environ.get("S3_ACCESS_KEY_ID", ""),
  52. "secret_access_key": remote_config.get("secret_access_key") or os.environ.get("S3_SECRET_ACCESS_KEY", ""),
  53. "region": remote_config.get("region") or os.environ.get("S3_REGION", ""),
  54. }
  55. def _has_remote_config(self) -> bool:
  56. """检查是否有有效的远程存储配置"""
  57. config = self._get_remote_config()
  58. return bool(
  59. config.get("bucket_name") and
  60. config.get("access_key_id") and
  61. config.get("secret_access_key") and
  62. config.get("endpoint_url")
  63. )
  64. def _get_remote_backend(self):
  65. """获取远程存储后端实例"""
  66. if self._remote_backend is not None:
  67. return self._remote_backend
  68. if not self._has_remote_config():
  69. return None
  70. try:
  71. from trendradar.storage.remote import RemoteStorageBackend
  72. remote_config = self._get_remote_config()
  73. config = self._load_config()
  74. timezone = config.get("app", {}).get("timezone", "Asia/Shanghai")
  75. self._remote_backend = RemoteStorageBackend(
  76. bucket_name=remote_config["bucket_name"],
  77. access_key_id=remote_config["access_key_id"],
  78. secret_access_key=remote_config["secret_access_key"],
  79. endpoint_url=remote_config["endpoint_url"],
  80. region=remote_config.get("region", ""),
  81. timezone=timezone,
  82. )
  83. return self._remote_backend
  84. except ImportError:
  85. print("[存储同步] 远程存储后端需要安装 boto3: pip install boto3")
  86. return None
  87. except Exception as e:
  88. print(f"[存储同步] 创建远程后端失败: {e}")
  89. return None
  90. def _get_local_data_dir(self) -> Path:
  91. """获取本地数据目录"""
  92. storage_config = self._get_storage_config()
  93. local_config = storage_config.get("local", {})
  94. data_dir = local_config.get("data_dir", "output")
  95. return self.project_root / data_dir
  96. def _parse_date_folder_name(self, folder_name: str) -> Optional[datetime]:
  97. """
  98. 解析日期文件夹名称(兼容中文和 ISO 格式)
  99. 支持两种格式:
  100. - 中文格式:YYYY年MM月DD日
  101. - ISO 格式:YYYY-MM-DD
  102. """
  103. # 尝试 ISO 格式
  104. iso_match = re.match(r'(\d{4})-(\d{2})-(\d{2})', folder_name)
  105. if iso_match:
  106. try:
  107. return datetime(
  108. int(iso_match.group(1)),
  109. int(iso_match.group(2)),
  110. int(iso_match.group(3))
  111. )
  112. except ValueError:
  113. pass
  114. # 尝试中文格式
  115. chinese_match = re.match(r'(\d{4})年(\d{2})月(\d{2})日', folder_name)
  116. if chinese_match:
  117. try:
  118. return datetime(
  119. int(chinese_match.group(1)),
  120. int(chinese_match.group(2)),
  121. int(chinese_match.group(3))
  122. )
  123. except ValueError:
  124. pass
  125. return None
  126. def _get_local_dates(self, db_type: str = "news") -> List[str]:
  127. """
  128. 获取本地可用的日期列表
  129. 存储结构: output/{db_type}/{date}.db
  130. 例如: output/news/2025-12-30.db, output/rss/2025-12-30.db
  131. Args:
  132. db_type: 数据库类型 ("news" 或 "rss"),默认 "news"
  133. Returns:
  134. 日期列表(按时间倒序)
  135. """
  136. local_dir = self._get_local_data_dir()
  137. dates = set()
  138. if not local_dir.exists():
  139. return []
  140. # 扫描 output/{db_type}/{date}.db 文件
  141. type_dir = local_dir / db_type
  142. if type_dir.exists():
  143. for item in type_dir.iterdir():
  144. if item.is_file() and item.suffix == ".db":
  145. # 从文件名解析日期 (2025-12-30.db -> 2025-12-30)
  146. date_str = item.stem # 去除 .db 后缀
  147. folder_date = self._parse_date_folder_name(date_str)
  148. if folder_date:
  149. dates.add(folder_date.strftime("%Y-%m-%d"))
  150. return sorted(list(dates), reverse=True)
  151. def _get_all_local_dates(self) -> Dict[str, List[str]]:
  152. """
  153. 获取所有本地可用的日期列表(包括 news 和 rss)
  154. Returns:
  155. {
  156. "news": ["2025-12-30", ...],
  157. "rss": ["2025-12-30", ...],
  158. "all": ["2025-12-30", ...] # 合并去重
  159. }
  160. """
  161. news_dates = set(self._get_local_dates("news"))
  162. rss_dates = set(self._get_local_dates("rss"))
  163. all_dates = news_dates | rss_dates
  164. return {
  165. "news": sorted(list(news_dates), reverse=True),
  166. "rss": sorted(list(rss_dates), reverse=True),
  167. "all": sorted(list(all_dates), reverse=True)
  168. }
  169. def _calculate_dir_size(self, path: Path) -> int:
  170. """计算目录大小(字节)"""
  171. total_size = 0
  172. if path.exists():
  173. for item in path.rglob("*"):
  174. if item.is_file():
  175. total_size += item.stat().st_size
  176. return total_size
  177. def sync_from_remote(self, days: int = 7) -> Dict:
  178. """
  179. 从远程存储拉取数据到本地
  180. Args:
  181. days: 拉取最近 N 天的数据,默认 7 天
  182. Returns:
  183. 同步结果字典
  184. """
  185. try:
  186. # 检查远程配置
  187. if not self._has_remote_config():
  188. return {
  189. "success": False,
  190. "error": {
  191. "code": "REMOTE_NOT_CONFIGURED",
  192. "message": "未配置远程存储",
  193. "suggestion": "请在 config/config.yaml 中配置 storage.remote 或设置环境变量"
  194. }
  195. }
  196. # 获取远程后端
  197. remote_backend = self._get_remote_backend()
  198. if remote_backend is None:
  199. return {
  200. "success": False,
  201. "error": {
  202. "code": "REMOTE_BACKEND_FAILED",
  203. "message": "无法创建远程存储后端",
  204. "suggestion": "请检查远程存储配置和 boto3 是否已安装"
  205. }
  206. }
  207. # 获取本地数据目录
  208. local_dir = self._get_local_data_dir()
  209. local_dir.mkdir(parents=True, exist_ok=True)
  210. # 获取远程可用日期
  211. remote_dates = remote_backend.list_remote_dates()
  212. # 获取本地已有日期
  213. local_dates = set(self._get_local_dates())
  214. # 计算需要拉取的日期(最近 N 天)
  215. from trendradar.utils.time import get_configured_time
  216. config = self._load_config()
  217. timezone = config.get("app", {}).get("timezone", "Asia/Shanghai")
  218. now = get_configured_time(timezone)
  219. target_dates = []
  220. for i in range(days):
  221. date = now - timedelta(days=i)
  222. date_str = date.strftime("%Y-%m-%d")
  223. if date_str in remote_dates:
  224. target_dates.append(date_str)
  225. # 执行拉取
  226. synced_dates = []
  227. skipped_dates = []
  228. failed_dates = []
  229. for date_str in target_dates:
  230. # 检查本地是否已存在
  231. if date_str in local_dates:
  232. skipped_dates.append(date_str)
  233. continue
  234. # 拉取单个日期
  235. try:
  236. local_date_dir = local_dir / date_str
  237. local_db_path = local_date_dir / "news.db"
  238. remote_key = f"news/{date_str}.db"
  239. local_date_dir.mkdir(parents=True, exist_ok=True)
  240. remote_backend.s3_client.download_file(
  241. remote_backend.bucket_name,
  242. remote_key,
  243. str(local_db_path)
  244. )
  245. synced_dates.append(date_str)
  246. print(f"[存储同步] 已拉取: {date_str}")
  247. except Exception as e:
  248. failed_dates.append({"date": date_str, "error": str(e)})
  249. print(f"[存储同步] 拉取失败 ({date_str}): {e}")
  250. return {
  251. "success": True,
  252. "summary": {
  253. "description": "远程存储同步结果",
  254. "synced_files": len(synced_dates),
  255. "skipped_count": len(skipped_dates),
  256. "failed_count": len(failed_dates)
  257. },
  258. "data": {
  259. "synced_dates": synced_dates,
  260. "skipped_dates": skipped_dates,
  261. "failed_dates": failed_dates
  262. },
  263. "message": f"成功同步 {len(synced_dates)} 天数据" + (
  264. f",跳过 {len(skipped_dates)} 天(本地已存在)" if skipped_dates else ""
  265. ) + (
  266. f",失败 {len(failed_dates)} 天" if failed_dates else ""
  267. )
  268. }
  269. except MCPError as e:
  270. return {
  271. "success": False,
  272. "error": e.to_dict()
  273. }
  274. except Exception as e:
  275. return {
  276. "success": False,
  277. "error": {
  278. "code": "INTERNAL_ERROR",
  279. "message": str(e)
  280. }
  281. }
  282. def get_storage_status(self) -> Dict:
  283. """
  284. 获取存储配置和状态
  285. Returns:
  286. 存储状态字典
  287. """
  288. try:
  289. storage_config = self._get_storage_config()
  290. config = self._load_config()
  291. # 本地存储状态
  292. local_config = storage_config.get("local", {})
  293. local_dir = self._get_local_data_dir()
  294. local_size = self._calculate_dir_size(local_dir)
  295. # 获取分类的日期列表
  296. all_dates = self._get_all_local_dates()
  297. news_dates = all_dates["news"]
  298. rss_dates = all_dates["rss"]
  299. combined_dates = all_dates["all"]
  300. local_status = {
  301. "data_dir": local_config.get("data_dir", "output"),
  302. "retention_days": local_config.get("retention_days", 0),
  303. "total_size": f"{local_size / 1024 / 1024:.2f} MB",
  304. "total_size_bytes": local_size,
  305. "date_count": len(combined_dates),
  306. "earliest_date": combined_dates[-1] if combined_dates else None,
  307. "latest_date": combined_dates[0] if combined_dates else None,
  308. "news": {
  309. "date_count": len(news_dates),
  310. "dates": news_dates[:10], # 最近 10 天
  311. },
  312. "rss": {
  313. "date_count": len(rss_dates),
  314. "dates": rss_dates[:10], # 最近 10 天
  315. },
  316. }
  317. # 远程存储状态
  318. remote_config = storage_config.get("remote", {})
  319. has_remote = self._has_remote_config()
  320. remote_status = {
  321. "configured": has_remote,
  322. "retention_days": remote_config.get("retention_days", 0),
  323. }
  324. if has_remote:
  325. merged_config = self._get_remote_config()
  326. # 脱敏显示
  327. endpoint = merged_config.get("endpoint_url", "")
  328. bucket = merged_config.get("bucket_name", "")
  329. remote_status["endpoint_url"] = endpoint
  330. remote_status["bucket_name"] = bucket
  331. # 尝试获取远程日期列表
  332. remote_backend = self._get_remote_backend()
  333. if remote_backend:
  334. try:
  335. remote_dates = remote_backend.list_remote_dates()
  336. remote_status["date_count"] = len(remote_dates)
  337. remote_status["earliest_date"] = remote_dates[-1] if remote_dates else None
  338. remote_status["latest_date"] = remote_dates[0] if remote_dates else None
  339. except Exception as e:
  340. remote_status["error"] = str(e)
  341. # 拉取配置状态
  342. pull_config = storage_config.get("pull", {})
  343. pull_status = {
  344. "enabled": pull_config.get("enabled", False),
  345. "days": pull_config.get("days", 7),
  346. }
  347. return {
  348. "success": True,
  349. "summary": {
  350. "description": "存储配置和状态信息",
  351. "backend": storage_config.get("backend", "auto")
  352. },
  353. "data": {
  354. "local": local_status,
  355. "remote": remote_status,
  356. "pull": pull_status
  357. }
  358. }
  359. except MCPError as e:
  360. return {
  361. "success": False,
  362. "error": e.to_dict()
  363. }
  364. except Exception as e:
  365. return {
  366. "success": False,
  367. "error": {
  368. "code": "INTERNAL_ERROR",
  369. "message": str(e)
  370. }
  371. }
  372. def list_available_dates(self, source: str = "both") -> Dict:
  373. """
  374. 列出可用的日期范围
  375. Args:
  376. source: 数据来源
  377. - "local": 仅本地
  378. - "remote": 仅远程
  379. - "both": 两者都列出(默认)
  380. Returns:
  381. 日期列表字典
  382. """
  383. try:
  384. data_result = {}
  385. summary_info = {
  386. "description": "可用日期列表",
  387. "source": source
  388. }
  389. # 本地日期
  390. if source in ("local", "both"):
  391. all_dates = self._get_all_local_dates()
  392. news_dates = all_dates["news"]
  393. rss_dates = all_dates["rss"]
  394. combined_dates = all_dates["all"]
  395. data_result["local"] = {
  396. "dates": combined_dates,
  397. "count": len(combined_dates),
  398. "earliest": combined_dates[-1] if combined_dates else None,
  399. "latest": combined_dates[0] if combined_dates else None,
  400. "news": {
  401. "dates": news_dates,
  402. "count": len(news_dates),
  403. },
  404. "rss": {
  405. "dates": rss_dates,
  406. "count": len(rss_dates),
  407. },
  408. }
  409. # 远程日期
  410. if source in ("remote", "both"):
  411. if not self._has_remote_config():
  412. data_result["remote"] = {
  413. "configured": False,
  414. "dates": [],
  415. "count": 0,
  416. "earliest": None,
  417. "latest": None,
  418. "error": "未配置远程存储"
  419. }
  420. else:
  421. remote_backend = self._get_remote_backend()
  422. if remote_backend:
  423. try:
  424. remote_dates = remote_backend.list_remote_dates()
  425. data_result["remote"] = {
  426. "configured": True,
  427. "dates": remote_dates,
  428. "count": len(remote_dates),
  429. "earliest": remote_dates[-1] if remote_dates else None,
  430. "latest": remote_dates[0] if remote_dates else None,
  431. }
  432. except Exception as e:
  433. data_result["remote"] = {
  434. "configured": True,
  435. "dates": [],
  436. "count": 0,
  437. "earliest": None,
  438. "latest": None,
  439. "error": str(e)
  440. }
  441. else:
  442. data_result["remote"] = {
  443. "configured": True,
  444. "dates": [],
  445. "count": 0,
  446. "earliest": None,
  447. "latest": None,
  448. "error": "无法创建远程存储后端"
  449. }
  450. # 如果同时查询两者,计算差异
  451. if source == "both" and "local" in data_result and "remote" in data_result:
  452. local_set = set(data_result["local"]["dates"])
  453. remote_set = set(data_result["remote"].get("dates", []))
  454. data_result["comparison"] = {
  455. "only_local": sorted(list(local_set - remote_set), reverse=True),
  456. "only_remote": sorted(list(remote_set - local_set), reverse=True),
  457. "both": sorted(list(local_set & remote_set), reverse=True),
  458. }
  459. return {
  460. "success": True,
  461. "summary": summary_info,
  462. "data": data_result
  463. }
  464. except MCPError as e:
  465. return {
  466. "success": False,
  467. "error": e.to_dict()
  468. }
  469. except Exception as e:
  470. return {
  471. "success": False,
  472. "error": {
  473. "code": "INTERNAL_ERROR",
  474. "message": str(e)
  475. }
  476. }