translator.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
# coding=utf-8
"""
AI translator module.

Translates push content into other languages,
using the shared AI model configuration.
"""
  7. import json
  8. import os
  9. from dataclasses import dataclass, field
  10. from pathlib import Path
  11. from typing import Any, Dict, List, Optional
  12. @dataclass
  13. class TranslationResult:
  14. """翻译结果"""
  15. translated_text: str = "" # 翻译后的文本
  16. original_text: str = "" # 原始文本
  17. success: bool = False # 是否成功
  18. error: str = "" # 错误信息
  19. @dataclass
  20. class BatchTranslationResult:
  21. """批量翻译结果"""
  22. results: List[TranslationResult] = field(default_factory=list)
  23. success_count: int = 0
  24. fail_count: int = 0
  25. total_count: int = 0
  26. class AITranslator:
  27. """AI 翻译器"""
  28. def __init__(self, translation_config: Dict[str, Any], ai_config: Dict[str, Any]):
  29. """
  30. 初始化 AI 翻译器
  31. Args:
  32. translation_config: AI 翻译配置 (AI_TRANSLATION)
  33. ai_config: AI 模型共享配置 (AI)
  34. """
  35. self.translation_config = translation_config
  36. self.ai_config = ai_config
  37. # 翻译配置
  38. self.enabled = translation_config.get("ENABLED", False)
  39. self.target_language = translation_config.get("LANGUAGE", "English")
  40. # 从共享配置获取模型参数
  41. self.api_key = ai_config.get("API_KEY") or os.environ.get("AI_API_KEY", "")
  42. self.provider = ai_config.get("PROVIDER", "deepseek")
  43. self.model = ai_config.get("MODEL", "deepseek-chat")
  44. self.base_url = ai_config.get("BASE_URL", "")
  45. self.timeout = ai_config.get("TIMEOUT", 90)
  46. # AI 参数配置
  47. self.temperature = ai_config.get("TEMPERATURE", 1.0)
  48. self.max_tokens = ai_config.get("MAX_TOKENS", 5000)
  49. # 额外参数
  50. self.extra_params = ai_config.get("EXTRA_PARAMS", {})
  51. if isinstance(self.extra_params, str) and self.extra_params.strip():
  52. try:
  53. self.extra_params = json.loads(self.extra_params)
  54. except json.JSONDecodeError:
  55. print(f"[翻译] 解析 extra_params 失败,将忽略: {self.extra_params}")
  56. self.extra_params = {}
  57. if not isinstance(self.extra_params, dict):
  58. self.extra_params = {}
  59. # 加载提示词模板
  60. self.system_prompt, self.user_prompt_template = self._load_prompt_template(
  61. translation_config.get("PROMPT_FILE", "ai_translation_prompt.txt")
  62. )
  63. def _load_prompt_template(self, prompt_file: str) -> tuple:
  64. """加载提示词模板"""
  65. config_dir = Path(__file__).parent.parent.parent / "config"
  66. prompt_path = config_dir / prompt_file
  67. if not prompt_path.exists():
  68. print(f"[翻译] 提示词文件不存在: {prompt_path}")
  69. return "", ""
  70. content = prompt_path.read_text(encoding="utf-8")
  71. # 解析 [system] 和 [user] 部分
  72. system_prompt = ""
  73. user_prompt = ""
  74. if "[system]" in content and "[user]" in content:
  75. parts = content.split("[user]")
  76. system_part = parts[0]
  77. user_part = parts[1] if len(parts) > 1 else ""
  78. if "[system]" in system_part:
  79. system_prompt = system_part.split("[system]")[1].strip()
  80. user_prompt = user_part.strip()
  81. else:
  82. user_prompt = content
  83. return system_prompt, user_prompt
  84. def translate(self, text: str) -> TranslationResult:
  85. """
  86. 翻译单条文本
  87. Args:
  88. text: 要翻译的文本
  89. Returns:
  90. TranslationResult: 翻译结果
  91. """
  92. result = TranslationResult(original_text=text)
  93. if not self.enabled:
  94. result.error = "翻译功能未启用"
  95. return result
  96. if not self.api_key:
  97. result.error = "未配置 AI API Key"
  98. return result
  99. if not text or not text.strip():
  100. result.translated_text = text
  101. result.success = True
  102. return result
  103. try:
  104. # 构建提示词
  105. user_prompt = self.user_prompt_template
  106. user_prompt = user_prompt.replace("{target_language}", self.target_language)
  107. user_prompt = user_prompt.replace("{content}", text)
  108. # 调用 AI API
  109. response = self._call_ai_api(user_prompt)
  110. result.translated_text = response.strip()
  111. result.success = True
  112. except Exception as e:
  113. import requests
  114. error_type = type(e).__name__
  115. error_msg = str(e)
  116. if isinstance(e, requests.exceptions.Timeout):
  117. result.error = f"翻译请求超时({self.timeout}秒)"
  118. elif isinstance(e, requests.exceptions.ConnectionError):
  119. result.error = f"无法连接到 AI API"
  120. elif isinstance(e, requests.exceptions.HTTPError):
  121. status_code = e.response.status_code if hasattr(e, 'response') and e.response else "未知"
  122. if status_code == 401:
  123. result.error = "API 认证失败"
  124. elif status_code == 429:
  125. result.error = "API 请求频率过高"
  126. else:
  127. result.error = f"API 错误 (HTTP {status_code})"
  128. else:
  129. if len(error_msg) > 100:
  130. error_msg = error_msg[:100] + "..."
  131. result.error = f"翻译失败 ({error_type}): {error_msg}"
  132. return result
  133. def translate_batch(self, texts: List[str]) -> BatchTranslationResult:
  134. """
  135. 批量翻译文本(单次 API 调用)
  136. Args:
  137. texts: 要翻译的文本列表
  138. Returns:
  139. BatchTranslationResult: 批量翻译结果
  140. """
  141. batch_result = BatchTranslationResult(total_count=len(texts))
  142. if not self.enabled:
  143. for text in texts:
  144. batch_result.results.append(TranslationResult(
  145. original_text=text,
  146. error="翻译功能未启用"
  147. ))
  148. batch_result.fail_count = len(texts)
  149. return batch_result
  150. if not self.api_key:
  151. for text in texts:
  152. batch_result.results.append(TranslationResult(
  153. original_text=text,
  154. error="未配置 AI API Key"
  155. ))
  156. batch_result.fail_count = len(texts)
  157. return batch_result
  158. if not texts:
  159. return batch_result
  160. # 过滤空文本
  161. non_empty_indices = []
  162. non_empty_texts = []
  163. for i, text in enumerate(texts):
  164. if text and text.strip():
  165. non_empty_indices.append(i)
  166. non_empty_texts.append(text)
  167. # 初始化结果列表
  168. for text in texts:
  169. batch_result.results.append(TranslationResult(original_text=text))
  170. # 空文本直接标记成功
  171. for i, text in enumerate(texts):
  172. if not text or not text.strip():
  173. batch_result.results[i].translated_text = text
  174. batch_result.results[i].success = True
  175. batch_result.success_count += 1
  176. if not non_empty_texts:
  177. return batch_result
  178. try:
  179. # 构建批量翻译内容(使用编号格式)
  180. batch_content = self._format_batch_content(non_empty_texts)
  181. # 构建提示词
  182. user_prompt = self.user_prompt_template
  183. user_prompt = user_prompt.replace("{target_language}", self.target_language)
  184. user_prompt = user_prompt.replace("{content}", batch_content)
  185. # 调用 AI API
  186. response = self._call_ai_api(user_prompt)
  187. # 解析批量翻译结果
  188. translated_texts = self._parse_batch_response(response, len(non_empty_texts))
  189. # 填充结果
  190. for idx, translated in zip(non_empty_indices, translated_texts):
  191. batch_result.results[idx].translated_text = translated
  192. batch_result.results[idx].success = True
  193. batch_result.success_count += 1
  194. except Exception as e:
  195. error_msg = f"批量翻译失败: {type(e).__name__}: {str(e)[:100]}"
  196. for idx in non_empty_indices:
  197. batch_result.results[idx].error = error_msg
  198. batch_result.fail_count = len(non_empty_indices)
  199. return batch_result
  200. def _format_batch_content(self, texts: List[str]) -> str:
  201. """格式化批量翻译内容"""
  202. lines = []
  203. for i, text in enumerate(texts, 1):
  204. lines.append(f"[{i}] {text}")
  205. return "\n".join(lines)
  206. def _parse_batch_response(self, response: str, expected_count: int) -> List[str]:
  207. """
  208. 解析批量翻译响应
  209. Args:
  210. response: AI 响应文本
  211. expected_count: 期望的翻译数量
  212. Returns:
  213. List[str]: 翻译结果列表
  214. """
  215. results = []
  216. lines = response.strip().split("\n")
  217. current_idx = None
  218. current_text = []
  219. for line in lines:
  220. # 尝试匹配 [数字] 格式
  221. stripped = line.strip()
  222. if stripped.startswith("[") and "]" in stripped:
  223. bracket_end = stripped.index("]")
  224. try:
  225. idx = int(stripped[1:bracket_end])
  226. # 保存之前的内容
  227. if current_idx is not None:
  228. results.append((current_idx, "\n".join(current_text).strip()))
  229. current_idx = idx
  230. current_text = [stripped[bracket_end + 1:].strip()]
  231. except ValueError:
  232. if current_idx is not None:
  233. current_text.append(line)
  234. else:
  235. if current_idx is not None:
  236. current_text.append(line)
  237. # 保存最后一条
  238. if current_idx is not None:
  239. results.append((current_idx, "\n".join(current_text).strip()))
  240. # 按索引排序并提取文本
  241. results.sort(key=lambda x: x[0])
  242. translated = [text for _, text in results]
  243. # 如果解析结果数量不匹配,尝试简单按行分割
  244. if len(translated) != expected_count:
  245. # 回退:按行分割(去除编号)
  246. translated = []
  247. for line in lines:
  248. stripped = line.strip()
  249. if stripped.startswith("[") and "]" in stripped:
  250. bracket_end = stripped.index("]")
  251. translated.append(stripped[bracket_end + 1:].strip())
  252. elif stripped:
  253. translated.append(stripped)
  254. # 确保返回正确数量
  255. while len(translated) < expected_count:
  256. translated.append("")
  257. return translated[:expected_count]
  258. def _call_ai_api(self, user_prompt: str) -> str:
  259. """调用 AI API"""
  260. if self.provider == "gemini":
  261. return self._call_gemini(user_prompt)
  262. return self._call_openai_compatible(user_prompt)
  263. def _get_api_url(self) -> str:
  264. """获取完整 API URL"""
  265. if self.base_url:
  266. return self.base_url
  267. urls = {
  268. "deepseek": "https://api.deepseek.com/v1/chat/completions",
  269. "openai": "https://api.openai.com/v1/chat/completions",
  270. }
  271. url = urls.get(self.provider)
  272. if not url:
  273. raise ValueError(f"{self.provider} 需要配置 base_url")
  274. return url
  275. def _call_openai_compatible(self, user_prompt: str) -> str:
  276. """调用 OpenAI 兼容接口"""
  277. import requests
  278. url = self._get_api_url()
  279. headers = {
  280. "Authorization": f"Bearer {self.api_key}",
  281. "Content-Type": "application/json",
  282. }
  283. messages = []
  284. if self.system_prompt:
  285. messages.append({"role": "system", "content": self.system_prompt})
  286. messages.append({"role": "user", "content": user_prompt})
  287. payload = {
  288. "model": self.model,
  289. "messages": messages,
  290. "temperature": self.temperature,
  291. }
  292. if self.max_tokens:
  293. payload["max_tokens"] = self.max_tokens
  294. if self.extra_params:
  295. payload.update(self.extra_params)
  296. response = requests.post(
  297. url,
  298. headers=headers,
  299. json=payload,
  300. timeout=self.timeout,
  301. )
  302. response.raise_for_status()
  303. data = response.json()
  304. return data["choices"][0]["message"]["content"]
  305. def _call_gemini(self, user_prompt: str) -> str:
  306. """调用 Google Gemini API"""
  307. import requests
  308. model = self.model or "gemini-1.5-flash"
  309. url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self.api_key}"
  310. headers = {
  311. "Content-Type": "application/json",
  312. }
  313. payload = {
  314. "contents": [{
  315. "role": "user",
  316. "parts": [{"text": user_prompt}]
  317. }],
  318. "generationConfig": {
  319. "temperature": self.temperature,
  320. },
  321. "safetySettings": [
  322. {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
  323. {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
  324. {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
  325. {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
  326. ]
  327. }
  328. if self.system_prompt:
  329. payload["system_instruction"] = {
  330. "parts": [{"text": self.system_prompt}]
  331. }
  332. if self.max_tokens:
  333. payload["generationConfig"]["maxOutputTokens"] = self.max_tokens
  334. if self.extra_params:
  335. payload["generationConfig"].update(self.extra_params)
  336. response = requests.post(
  337. url,
  338. headers=headers,
  339. json=payload,
  340. timeout=self.timeout,
  341. )
  342. response.raise_for_status()
  343. data = response.json()
  344. return data["candidates"][0]["content"]["parts"][0]["text"]