# ddti.py - parse DDTI XML annotations and rasterize them into mask images.
  1. from __future__ import annotations
  2. import json
  3. from pathlib import Path
  4. import xml.etree.ElementTree as ET
  5. from PIL import Image, ImageDraw
  6. def parse_ddti_xml(annotation_path: str | Path) -> dict[int, list[list[tuple[int, int]]]]:
  7. """
  8. 解析 DDTI 的 xml 标注。
  9. Returns:
  10. {image_index: [polygon1, polygon2, ...]}
  11. """
  12. annotation_path = Path(annotation_path)
  13. root = ET.parse(annotation_path).getroot()
  14. image_to_polygons: dict[int, list[list[tuple[int, int]]]] = {}
  15. for mark in root.findall("mark"):
  16. image_text = mark.findtext("image")
  17. svg_text = mark.findtext("svg")
  18. if not image_text or not svg_text:
  19. continue
  20. image_index = int(image_text)
  21. try:
  22. shapes = json.loads(svg_text)
  23. except json.JSONDecodeError:
  24. continue
  25. polygons: list[list[tuple[int, int]]] = []
  26. for shape in shapes:
  27. points = shape.get("points", [])
  28. polygon = []
  29. for point in points:
  30. x = int(round(point["x"]))
  31. y = int(round(point["y"]))
  32. polygon.append((x, y))
  33. if len(polygon) >= 3:
  34. polygons.append(polygon)
  35. if polygons:
  36. image_to_polygons[image_index] = polygons
  37. return image_to_polygons
  38. def build_ddti_mask(
  39. image_path: str | Path,
  40. annotation_path: str | Path,
  41. image_index: int | None = None,
  42. fill_value: int = 255,
  43. ) -> Image.Image:
  44. """
  45. 根据 DDTI 的 xml 为指定图像生成二值掩膜。
  46. """
  47. image_path = Path(image_path)
  48. annotation_path = Path(annotation_path)
  49. image = Image.open(image_path)
  50. width, height = image.size
  51. if image_index is None:
  52. stem = image_path.stem
  53. if "_" not in stem:
  54. raise ValueError(f"Cannot infer image index from file name: {image_path.name}")
  55. _, image_idx_str = stem.split("_", 1)
  56. image_index = int(image_idx_str)
  57. polygons_map = parse_ddti_xml(annotation_path)
  58. polygons = polygons_map.get(int(image_index), [])
  59. mask = Image.new("L", (width, height), 0)
  60. draw = ImageDraw.Draw(mask)
  61. for polygon in polygons:
  62. draw.polygon(polygon, outline=fill_value, fill=fill_value)
  63. return mask
  64. __all__ = ["parse_ddti_xml", "build_ddti_mask"]