| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- from __future__ import annotations
- from pathlib import Path
- from .records import SegSampleRecord
- from .utils import list_image_files, stem_without_mask_suffix
def build_paired_folder_records(
    dataset_name: str,
    image_dir: Path,
    mask_dir: Path,
    *,
    split: str | None = None,
    class_name: str | None = None,
) -> list[SegSampleRecord]:
    """Pair images and masks that share an identical file name.

    Both directories are listed with ``list_image_files``. An image is kept
    only when the mask listing contains a file with exactly the same ``name``
    (stem *and* extension); images without such a mask are skipped silently.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        image_dir: Directory listed for input images.
        mask_dir: Directory listed for masks.
        split: Optional split label copied onto every record.
        class_name: Optional class label copied onto every record.

    Returns:
        One ``SegSampleRecord`` per matched image, in the order images come
        back from ``list_image_files(image_dir)``. ``sample_id`` is the image
        stem.
    """
    image_paths = list_image_files(image_dir)
    # Index masks by full file name; a later duplicate name replaces an earlier one.
    masks_by_name = {mask_path.name: mask_path for mask_path in list_image_files(mask_dir)}

    matched_pairs = [
        (image_path, masks_by_name[image_path.name])
        for image_path in image_paths
        if image_path.name in masks_by_name
    ]
    return [
        SegSampleRecord(
            dataset_name=dataset_name,
            image_path=image_path,
            mask_path=mask_path,
            split=split,
            sample_id=image_path.stem,
            class_name=class_name,
        )
        for image_path, mask_path in matched_pairs
    ]
def build_prefixed_paired_records(
    dataset_name: str,
    image_dir: Path,
    mask_dir: Path,
    *,
    image_prefix_to_strip: str = "",
    mask_prefix_to_strip: str = "",
    split: str | None = None,
    class_name: str | None = None,
) -> list[SegSampleRecord]:
    """Pair images and masks whose file names agree after removing a prefix.

    Each file name (including its extension) has the relevant prefix removed
    from its front, but only when that prefix is non-empty and actually
    present. An image is paired with the mask whose reduced name equals the
    image's reduced name; unmatched images are skipped.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        image_dir: Directory listed for input images.
        mask_dir: Directory listed for masks.
        image_prefix_to_strip: Prefix removed from image file names before matching.
        mask_prefix_to_strip: Prefix removed from mask file names before matching.
        split: Optional split label copied onto every record.
        class_name: Optional class label copied onto every record.

    Returns:
        One ``SegSampleRecord`` per matched image, in image-listing order.
        ``sample_id`` is the original (unstripped) image stem.
    """
    image_paths = list_image_files(image_dir)
    mask_paths = list_image_files(mask_dir)

    def _match_key(path: Path, prefix: str) -> str:
        # Leave the name untouched unless a non-empty prefix is really there.
        filename = path.name
        if not prefix or not filename.startswith(prefix):
            return filename
        return filename[len(prefix):]

    mask_lookup: dict[str, Path] = {}
    for mask_path in mask_paths:
        # Masks reducing to the same key: the later one wins.
        mask_lookup[_match_key(mask_path, mask_prefix_to_strip)] = mask_path

    matched_pairs = []
    for image_path in image_paths:
        mask_path = mask_lookup.get(_match_key(image_path, image_prefix_to_strip))
        if mask_path is not None:
            matched_pairs.append((image_path, mask_path))

    return [
        SegSampleRecord(
            dataset_name=dataset_name,
            image_path=image_path,
            mask_path=mask_path,
            split=split,
            sample_id=image_path.stem,
            class_name=class_name,
        )
        for image_path, mask_path in matched_pairs
    ]
def build_stem_paired_records(
    dataset_name: str,
    image_dir: Path,
    mask_dir: Path,
    *,
    split: str | None = None,
    class_name: str | None = None,
    prefer_plain_mask: bool = True,
) -> list[SegSampleRecord]:
    """Pair each image with the masks whose normalized stem equals the image stem.

    Mask file names are reduced with ``stem_without_mask_suffix`` (presumably
    stripping a mask marker from the name -- confirm against ``utils``) and
    grouped by the result. For each image, the candidate masks for its stem
    are sorted and the first is chosen. When ``prefer_plain_mask`` is true,
    the first sorted candidate whose stem does not contain ``"_binary"``
    (case-insensitive) is chosen instead, if there is one. Images without any
    candidate are skipped.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        image_dir: Directory listed for input images.
        mask_dir: Directory listed for masks.
        split: Optional split label copied onto every record.
        class_name: Optional class label copied onto every record.
        prefer_plain_mask: Prefer candidates without ``"_binary"`` in their stem.

    Returns:
        One ``SegSampleRecord`` per matched image, in image-listing order.
        ``sample_id`` is the image stem and ``meta["mask_candidates"]`` holds
        the candidate count as a string.
    """
    image_paths = list_image_files(image_dir)
    mask_paths = list_image_files(mask_dir)

    masks_by_stem: dict[str, list[Path]] = {}
    for mask_path in mask_paths:
        stem_key = stem_without_mask_suffix(mask_path.name)
        if stem_key in masks_by_stem:
            masks_by_stem[stem_key].append(mask_path)
        else:
            masks_by_stem[stem_key] = [mask_path]

    records: list[SegSampleRecord] = []
    for image_path in image_paths:
        stem_key = image_path.stem
        candidates = sorted(masks_by_stem.get(stem_key, []))
        if not candidates:
            continue

        chosen = candidates[0]
        if prefer_plain_mask:
            # Fall back to the first sorted candidate when every one is "_binary".
            chosen = next(
                (option for option in candidates if "_binary" not in option.stem.lower()),
                chosen,
            )

        records.append(
            SegSampleRecord(
                dataset_name=dataset_name,
                image_path=image_path,
                mask_path=chosen,
                split=split,
                sample_id=stem_key,
                class_name=class_name,
                meta={"mask_candidates": str(len(candidates))},
            )
        )
    return records
def build_filename_matched_records(
    dataset_name: str,
    folder: Path,
    *,
    split: str | None = None,
    class_name: str | None = None,
) -> list[SegSampleRecord]:
    """Pair images and masks that live side by side in a single folder.

    Every listed file is keyed with ``stem_without_mask_suffix`` (presumably
    removing a trailing mask marker -- confirm against ``utils``). Files whose
    stem contains ``"_mask"`` (case-sensitive) are collected as masks under
    that key. All other files are images, and when several images share a key
    the last one listed is kept. Images with no masks are skipped.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        folder: Directory listed for both images and masks.
        split: Optional split label copied onto every record.
        class_name: Optional class label copied onto every record.

    Returns:
        Records sorted by key. The mask is the first of that key's masks in
        sorted order, ``sample_id`` is the key, and ``meta["mask_count"]``
        holds the number of masks as a string.
    """
    images_by_key: dict[str, Path] = {}
    masks_by_key: dict[str, list[Path]] = {}
    for file_path in list_image_files(folder):
        match_key = stem_without_mask_suffix(file_path.name)
        if "_mask" not in file_path.stem:
            images_by_key[match_key] = file_path
            continue
        masks_by_key.setdefault(match_key, []).append(file_path)

    records: list[SegSampleRecord] = []
    for match_key in sorted(images_by_key):
        mask_paths = sorted(masks_by_key.get(match_key, ()))
        if mask_paths:
            records.append(
                SegSampleRecord(
                    dataset_name=dataset_name,
                    image_path=images_by_key[match_key],
                    mask_path=mask_paths[0],
                    split=split,
                    sample_id=match_key,
                    class_name=class_name,
                    meta={"mask_count": str(len(mask_paths))},
                )
            )
    return records
def build_pre_split_records(
    dataset_name: str,
    train_image_dir: Path,
    train_mask_dir: Path,
    test_image_dir: Path,
    test_mask_dir: Path,
    *,
    class_name: str | None = None,
    train_split: str = "trainval",
    test_split: str = "test",
) -> list[SegSampleRecord]:
    """Collect records for a dataset that ships separate train and test folders.

    Each (image dir, mask dir) pair is matched with
    ``build_paired_folder_records``, i.e. by identical file names. Train
    records come first, followed by test records.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        train_image_dir: Directory of training images.
        train_mask_dir: Directory of training masks.
        test_image_dir: Directory of test images.
        test_mask_dir: Directory of test masks.
        class_name: Optional class label copied onto every record, as every
            other builder in this module allows (defaults to ``None``).
        train_split: Split label for the train-folder records (default
            ``"trainval"``, the previously hard-coded value).
        test_split: Split label for the test-folder records (default
            ``"test"``, the previously hard-coded value).

    Returns:
        The concatenated train and test records.
    """
    records: list[SegSampleRecord] = []
    for image_dir, mask_dir, split in (
        (train_image_dir, train_mask_dir, train_split),
        (test_image_dir, test_mask_dir, test_split),
    ):
        records.extend(
            build_paired_folder_records(
                dataset_name=dataset_name,
                image_dir=image_dir,
                mask_dir=mask_dir,
                split=split,
                class_name=class_name,
            )
        )
    return records
def build_xml_annotation_records(
    dataset_name: str,
    root: Path,
    *,
    split: str | None = None,
    class_name: str | None = None,
    image_suffixes: tuple[str, ...] = (".jpg",),
) -> list[SegSampleRecord]:
    """Pair images in ``root`` with XML annotation files in the same folder.

    Only regular files directly inside ``root`` are considered; there is no
    recursion. Extensions are compared case-insensitively, so ``.JPG`` and
    ``.XML`` files are found on every platform. A plain ``glob("*.jpg")`` is
    case-sensitive on POSIX but not on Windows.

    An image is matched to ``<image stem>.xml`` when that file exists.
    Otherwise it falls back to the XML named after the part of the stem before
    the first ``"_"``, so ``0001_2.jpg`` uses ``0001.xml``. Images with no
    matching XML are skipped.

    Args:
        dataset_name: Value stored in every record's ``dataset_name``.
        root: Folder holding both the images and the XML files.
        split: Optional split label copied onto every record.
        class_name: Optional class label copied onto every record.
        image_suffixes: Image extensions to accept, with a leading dot, in any
            case. Defaults to ``(".jpg",)``.

    Returns:
        Records in sorted image-path order. Each has ``mask_path=None``, the
        XML file as ``annotation_path``, the full image stem as ``sample_id``
        and ``meta={"annotation_type": "xml"}``.
    """
    wanted_suffixes = {suffix.lower() for suffix in image_suffixes}
    # glob("*") -- like the previous glob("*.jpg") -- yields nothing for a
    # missing root rather than raising; directories are filtered out here.
    entries = sorted(path for path in root.glob("*") if path.is_file())
    xml_map = {path.stem: path for path in entries if path.suffix.lower() == ".xml"}
    image_files = [path for path in entries if path.suffix.lower() in wanted_suffixes]

    records: list[SegSampleRecord] = []
    for image in image_files:
        # Prefer an XML named exactly after the image: splitting on "_" alone
        # would never match e.g. ``img_001.jpg`` with ``img_001.xml``.
        annotation = xml_map.get(image.stem)
        if annotation is None:
            annotation = xml_map.get(image.stem.split("_")[0])
        if annotation is None:
            continue
        records.append(
            SegSampleRecord(
                dataset_name=dataset_name,
                image_path=image,
                mask_path=None,
                annotation_path=annotation,
                split=split,
                sample_id=image.stem,
                class_name=class_name,
                meta={"annotation_type": "xml"},
            )
        )
    return records
# Public record builders exported by ``from <this module> import *``.
__all__ = [
    "build_paired_folder_records",
    "build_prefixed_paired_records",
    "build_stem_paired_records",
    "build_filename_matched_records",
    "build_pre_split_records",
    "build_xml_annotation_records",
]
|