PolypDetectionDataset.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from pathlib import Path
  2. from typing import Callable, Optional, List, Dict, Any
  3. from monai.data import Dataset
  class PolypDetectionDataset(Dataset):
      """Polyp detection dataset for training and validation.

      Expects ``root_dir`` to contain an ``images/`` directory, a ``masks/``
      directory and a split file named ``<flag>.txt`` (``train.txt`` or
      ``val.txt``) listing one image filename per line. Each mask is looked up
      under ``masks/`` with the same filename as its image.
      """

      def __init__(
          self, root_dir: str, flag: str = "train", transform: Optional[Callable] = None
      ):
          """Initialize dataset.

          Args:
              root_dir: Root directory containing images, masks, and split files
                  (anything ``pathlib.Path`` accepts, e.g. ``str`` or ``Path``).
              flag: Dataset split ('train' or 'val'); compared case-insensitively.
              transform: Optional callable applied to each
                  ``{"image": path, "label": path}`` sample dict.

          Raises:
              ValueError: If ``flag`` is not 'train' or 'val'.
              FileNotFoundError: If the ``images``/``masks`` directory or the
                  split file does not exist.
          """
          # Set up directory paths
          self.root_dir = Path(root_dir)
          self.images_dir = self.root_dir / "images"
          self.labels_dir = self.root_dir / "masks"
          self.flag = flag.lower()
          self.transform = transform

          # Validate split flag
          if self.flag not in ("train", "val"):
              raise ValueError(f"flag must be 'train' or 'val', got '{self.flag}'")

          # Fail early with a clear message instead of at first sample load.
          if not self.images_dir.exists():
              raise FileNotFoundError(f"Image directory does not exist: {self.images_dir}")
          if not self.labels_dir.exists():
              raise FileNotFoundError(f"Label directory does not exist: {self.labels_dir}")

          # Load image filenames from split file
          txt_file = self.root_dir / f"{self.flag}.txt"
          if not txt_file.exists():
              raise FileNotFoundError(
                  f"{self.flag}.txt file not found in {self.root_dir}\n"
                  "Please ensure train.txt and val.txt files exist in this directory"
              )
          # "utf-8-sig" decodes plain UTF-8 identically, but also drops a leading
          # byte-order mark (written by some Windows editors). With "utf-8" the
          # BOM would be glued onto the first filename, making it unloadable.
          with open(txt_file, "r", encoding="utf-8-sig") as f:
              # strip() also removes the "\r" of CRLF files; blank lines are skipped.
              self.images: List[str] = [line.strip() for line in f if line.strip()]

          # Labels have same filenames as images
          self.labels: List[str] = self.images.copy()

          # Image-label path pairs, handed to the MONAI base class.
          data = [self._make_item(i) for i in range(len(self.images))]
          super().__init__(data=data, transform=transform)

      def _make_item(self, idx: int) -> Dict[str, str]:
          """Return the untransformed ``{"image": path, "label": path}`` dict for ``idx``."""
          return {
              "image": str(self.images_dir / self.images[idx]),
              "label": str(self.labels_dir / self.labels[idx]),
          }

      def __len__(self) -> int:
          """Return number of samples in dataset."""
          return len(self.images)

      def __getitem__(self, idx: Any) -> Any:
          """Get a single sample by index.

          Args:
              idx: Integer sample index. A slice, list or tuple of indices is
                  forwarded to ``monai.data.Dataset.__getitem__``.

          Returns:
              For an integer index: the ``{"image": path, "label": path}`` dict
              (freshly built on every call), passed through ``self.transform``
              when one is set. For slice/sequence indices, whatever the MONAI
              base class returns for them.
          """
          # The path-building below only works for a single integer index
          # (``Path / list`` raises TypeError). The base class's __getitem__
          # handles slices and sequences, so defer to it for those.
          if isinstance(idx, (slice, list, tuple)):
              return super().__getitem__(idx)

          data = self._make_item(idx)
          # Apply transformations if specified
          if self.transform is not None:
              data = self.transform(data)
          return data

      def get_image_filename(self, idx: int) -> str:
          """Get image filename (as listed in the split file) by index."""
          return self.images[idx]

      def get_label_filename(self, idx: int) -> str:
          """Get label filename by index (identical to the image filename)."""
          return self.labels[idx]

      def get_dataset_info(self, dataset_name="CVC_300") -> Dict[str, Any]:
          """Get dataset information.

          Args:
              dataset_name: Name reported under the "dataset_name" key; it is
                  not validated against ``root_dir``.

          Returns:
              Dictionary with the dataset name, split, sample count, directory
              paths (as strings) and whether a transform is set.
          """
          return {
              "dataset_name": dataset_name,
              "split": self.flag,
              "num_samples": len(self),
              "root_dir": str(self.root_dir),
              "images_dir": str(self.images_dir),
              "labels_dir": str(self.labels_dir),
              "has_transform": self.transform is not None,
          }