batch.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # coding=utf-8
  2. """
  3. 批次处理模块
  4. 提供消息分批发送的辅助函数
  5. """
  6. from typing import List
  7. def get_batch_header(format_type: str, batch_num: int, total_batches: int) -> str:
  8. """根据 format_type 生成对应格式的批次头部
  9. Args:
  10. format_type: 推送类型(telegram, slack, wework_text, bark, feishu, dingtalk, ntfy, wework)
  11. batch_num: 当前批次编号
  12. total_batches: 总批次数
  13. Returns:
  14. 格式化的批次头部字符串
  15. """
  16. if format_type == "telegram":
  17. return f"<b>[第 {batch_num}/{total_batches} 批次]</b>\n\n"
  18. elif format_type == "slack":
  19. return f"*[第 {batch_num}/{total_batches} 批次]*\n\n"
  20. elif format_type in ("wework_text", "bark"):
  21. # 企业微信文本模式和 Bark 使用纯文本格式
  22. return f"[第 {batch_num}/{total_batches} 批次]\n\n"
  23. else:
  24. # 飞书、钉钉、ntfy、企业微信 markdown 模式
  25. return f"**[第 {batch_num}/{total_batches} 批次]**\n\n"
  26. def get_max_batch_header_size(format_type: str) -> int:
  27. """估算批次头部的最大字节数(假设最多 99 批次)
  28. 用于在分批时预留空间,避免事后截断破坏内容完整性。
  29. Args:
  30. format_type: 推送类型
  31. Returns:
  32. 最大头部字节数
  33. """
  34. # 生成最坏情况的头部(99/99 批次)
  35. max_header = get_batch_header(format_type, 99, 99)
  36. return len(max_header.encode("utf-8"))
  37. def truncate_to_bytes(text: str, max_bytes: int) -> str:
  38. """安全截断字符串到指定字节数,避免截断多字节字符
  39. Args:
  40. text: 要截断的文本
  41. max_bytes: 最大字节数
  42. Returns:
  43. 截断后的文本
  44. """
  45. text_bytes = text.encode("utf-8")
  46. if len(text_bytes) <= max_bytes:
  47. return text
  48. # 截断到指定字节数
  49. truncated = text_bytes[:max_bytes]
  50. # 处理可能的不完整 UTF-8 字符
  51. for i in range(min(4, len(truncated))):
  52. try:
  53. return truncated[: len(truncated) - i].decode("utf-8")
  54. except UnicodeDecodeError:
  55. continue
  56. # 极端情况:返回空字符串
  57. return ""
  58. def add_batch_headers(
  59. batches: List[str], format_type: str, max_bytes: int
  60. ) -> List[str]:
  61. """为批次添加头部,动态计算确保总大小不超过限制
  62. Args:
  63. batches: 原始批次列表
  64. format_type: 推送类型(bark, telegram, feishu 等)
  65. max_bytes: 该推送类型的最大字节限制
  66. Returns:
  67. 添加头部后的批次列表
  68. """
  69. if len(batches) <= 1:
  70. return batches
  71. total = len(batches)
  72. result = []
  73. for i, content in enumerate(batches, 1):
  74. # 生成批次头部
  75. header = get_batch_header(format_type, i, total)
  76. header_size = len(header.encode("utf-8"))
  77. # 动态计算允许的最大内容大小
  78. max_content_size = max_bytes - header_size
  79. content_size = len(content.encode("utf-8"))
  80. # 如果超出,截断到安全大小
  81. if content_size > max_content_size:
  82. print(
  83. f"警告:{format_type} 第 {i}/{total} 批次内容({content_size}字节) + 头部({header_size}字节) 超出限制({max_bytes}字节),截断到 {max_content_size} 字节"
  84. )
  85. content = truncate_to_bytes(content, max_content_size)
  86. result.append(header + content)
  87. return result