indexers.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. from __future__ import annotations
  2. from pathlib import Path
  3. from .records import SegSampleRecord
  4. from .utils import list_image_files, stem_without_mask_suffix
  5. def build_paired_folder_records(
  6. dataset_name: str,
  7. image_dir: Path,
  8. mask_dir: Path,
  9. *,
  10. split: str | None = None,
  11. class_name: str | None = None,
  12. ) -> list[SegSampleRecord]:
  13. images = list_image_files(image_dir)
  14. masks = list_image_files(mask_dir)
  15. mask_map = {mask.name: mask for mask in masks}
  16. records: list[SegSampleRecord] = []
  17. for image in images:
  18. mask = mask_map.get(image.name)
  19. if mask is None:
  20. continue
  21. records.append(
  22. SegSampleRecord(
  23. dataset_name=dataset_name,
  24. image_path=image,
  25. mask_path=mask,
  26. split=split,
  27. sample_id=image.stem,
  28. class_name=class_name,
  29. )
  30. )
  31. return records
  32. def build_prefixed_paired_records(
  33. dataset_name: str,
  34. image_dir: Path,
  35. mask_dir: Path,
  36. *,
  37. image_prefix_to_strip: str = "",
  38. mask_prefix_to_strip: str = "",
  39. split: str | None = None,
  40. class_name: str | None = None,
  41. ) -> list[SegSampleRecord]:
  42. images = list_image_files(image_dir)
  43. masks = list_image_files(mask_dir)
  44. def _normalize(path: Path, prefix: str) -> str:
  45. name = path.name
  46. if prefix and name.startswith(prefix):
  47. name = name[len(prefix):]
  48. return name
  49. mask_map = {_normalize(mask, mask_prefix_to_strip): mask for mask in masks}
  50. records: list[SegSampleRecord] = []
  51. for image in images:
  52. key = _normalize(image, image_prefix_to_strip)
  53. mask = mask_map.get(key)
  54. if mask is None:
  55. continue
  56. records.append(
  57. SegSampleRecord(
  58. dataset_name=dataset_name,
  59. image_path=image,
  60. mask_path=mask,
  61. split=split,
  62. sample_id=image.stem,
  63. class_name=class_name,
  64. )
  65. )
  66. return records
  67. def build_stem_paired_records(
  68. dataset_name: str,
  69. image_dir: Path,
  70. mask_dir: Path,
  71. *,
  72. split: str | None = None,
  73. class_name: str | None = None,
  74. prefer_plain_mask: bool = True,
  75. ) -> list[SegSampleRecord]:
  76. images = list_image_files(image_dir)
  77. masks = list_image_files(mask_dir)
  78. grouped_masks: dict[str, list[Path]] = {}
  79. for mask in masks:
  80. key = stem_without_mask_suffix(mask.name)
  81. grouped_masks.setdefault(key, []).append(mask)
  82. records: list[SegSampleRecord] = []
  83. for image in images:
  84. key = image.stem
  85. candidates = sorted(grouped_masks.get(key, []))
  86. if not candidates:
  87. continue
  88. mask = candidates[0]
  89. if prefer_plain_mask:
  90. plain = [candidate for candidate in candidates if "_binary" not in candidate.stem.lower()]
  91. if plain:
  92. mask = plain[0]
  93. records.append(
  94. SegSampleRecord(
  95. dataset_name=dataset_name,
  96. image_path=image,
  97. mask_path=mask,
  98. split=split,
  99. sample_id=key,
  100. class_name=class_name,
  101. meta={"mask_candidates": str(len(candidates))},
  102. )
  103. )
  104. return records
  105. def build_filename_matched_records(
  106. dataset_name: str,
  107. folder: Path,
  108. *,
  109. split: str | None = None,
  110. class_name: str | None = None,
  111. ) -> list[SegSampleRecord]:
  112. files = list_image_files(folder)
  113. image_map: dict[str, Path] = {}
  114. mask_map: dict[str, list[Path]] = {}
  115. for path in files:
  116. key = stem_without_mask_suffix(path.name)
  117. if "_mask" in path.stem:
  118. mask_map.setdefault(key, []).append(path)
  119. else:
  120. image_map[key] = path
  121. records: list[SegSampleRecord] = []
  122. for key, image in sorted(image_map.items()):
  123. masks = sorted(mask_map.get(key, []))
  124. if not masks:
  125. continue
  126. records.append(
  127. SegSampleRecord(
  128. dataset_name=dataset_name,
  129. image_path=image,
  130. mask_path=masks[0],
  131. split=split,
  132. sample_id=key,
  133. class_name=class_name,
  134. meta={"mask_count": str(len(masks))},
  135. )
  136. )
  137. return records
  138. def build_pre_split_records(
  139. dataset_name: str,
  140. train_image_dir: Path,
  141. train_mask_dir: Path,
  142. test_image_dir: Path,
  143. test_mask_dir: Path,
  144. ) -> list[SegSampleRecord]:
  145. records = []
  146. records.extend(
  147. build_paired_folder_records(
  148. dataset_name=dataset_name,
  149. image_dir=train_image_dir,
  150. mask_dir=train_mask_dir,
  151. split="trainval",
  152. )
  153. )
  154. records.extend(
  155. build_paired_folder_records(
  156. dataset_name=dataset_name,
  157. image_dir=test_image_dir,
  158. mask_dir=test_mask_dir,
  159. split="test",
  160. )
  161. )
  162. return records
  163. def build_xml_annotation_records(
  164. dataset_name: str,
  165. root: Path,
  166. *,
  167. split: str | None = None,
  168. class_name: str | None = None,
  169. ) -> list[SegSampleRecord]:
  170. xml_map = {path.stem: path for path in sorted(root.glob("*.xml"))}
  171. image_files = sorted(root.glob("*.jpg"))
  172. records: list[SegSampleRecord] = []
  173. for image in image_files:
  174. sample_key = image.stem.split("_")[0]
  175. annotation = xml_map.get(sample_key)
  176. if annotation is None:
  177. continue
  178. records.append(
  179. SegSampleRecord(
  180. dataset_name=dataset_name,
  181. image_path=image,
  182. mask_path=None,
  183. annotation_path=annotation,
  184. split=split,
  185. sample_id=image.stem,
  186. class_name=class_name,
  187. meta={"annotation_type": "xml"},
  188. )
  189. )
  190. return records
  191. __all__ = [
  192. "build_paired_folder_records",
  193. "build_prefixed_paired_records",
  194. "build_stem_paired_records",
  195. "build_filename_matched_records",
  196. "build_pre_split_records",
  197. "build_xml_annotation_records",
  198. ]