| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- from pathlib import Path
- from typing import Callable, Optional, List, Dict, Any
- from monai.data import Dataset
class PolypDetectionDataset(Dataset):
    """Polyp detection dataset for training and validation.

    Expected layout under ``root_dir``::

        root_dir/
            images/     # input images
            masks/      # label masks, using the same filenames as images/
            train.txt   # one image filename per line (blank lines ignored)
            val.txt

    Each sample is a dict ``{"image": <image path>, "label": <mask path>}``
    (paths as ``str``), passed through ``transform`` when one is given.
    """

    def __init__(
        self, root_dir: str, flag: str = "train", transform: Optional[Callable] = None
    ):
        """Initialize dataset.

        Args:
            root_dir: Root directory containing ``images/``, ``masks/`` and the
                ``train.txt`` / ``val.txt`` split files.
            flag: Dataset split, ``'train'`` or ``'val'`` (case-insensitive).
            transform: Optional callable applied to each sample dict.

        Raises:
            ValueError: If ``flag`` is not ``'train'`` or ``'val'``.
            FileNotFoundError: If the image directory, the mask directory, or
                the split file for ``flag`` does not exist.
        """
        # Set up directory paths
        self.root_dir = Path(root_dir)
        self.images_dir = self.root_dir / "images"
        self.labels_dir = self.root_dir / "masks"
        self.flag = flag.lower()
        self.transform = transform

        # Validate split flag
        if self.flag not in ("train", "val"):
            raise ValueError(f"flag must be 'train' or 'val', got '{self.flag}'")

        # Check if directories exist
        if not self.images_dir.exists():
            raise FileNotFoundError(f"Image directory does not exist: {self.images_dir}")
        if not self.labels_dir.exists():
            raise FileNotFoundError(f"Label directory does not exist: {self.labels_dir}")

        # Load image filenames from split file
        txt_file = self.root_dir / f"{self.flag}.txt"
        if not txt_file.exists():
            raise FileNotFoundError(
                f"{self.flag}.txt file not found in {self.root_dir}\n"
                f"Please ensure train.txt and val.txt files exist in this directory"
            )
        # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM, which
        # editors on Windows often write; otherwise the first filename would
        # carry an invisible "\ufeff" prefix and fail to resolve on disk.
        with open(txt_file, "r", encoding="utf-8-sig") as f:
            self.images: List[str] = [line.strip() for line in f if line.strip()]

        # Labels (masks) share the image filenames; only the directory differs.
        self.labels: List[str] = list(self.images)

        # Materialize image-label pairs for MONAI's Dataset base class.
        data = [
            self._build_item(img, lbl) for img, lbl in zip(self.images, self.labels)
        ]
        super().__init__(data=data, transform=transform)

    def _build_item(self, image_name: str, label_name: str) -> Dict[str, str]:
        """Return a new ``{"image", "label"}`` path dict for one sample."""
        return {
            "image": str(self.images_dir / image_name),
            "label": str(self.labels_dir / label_name),
        }

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.images)

    def __getitem__(self, idx: Any) -> Any:
        """Get a single sample by index.

        Args:
            idx: Integer sample index. Slices and sequences of indices are
                delegated to MONAI's ``Dataset.__getitem__``, which returns a
                ``Subset`` view over this dataset.

        Returns:
            For an integer index: the sample dict with ``'image'`` and
            ``'label'`` entries, after ``transform`` if one is set.
        """
        # The base class implements slice/sequence indexing; handling only ints
        # here previously made ``ds[0:5]`` fail on ``images_dir / <list>``.
        if isinstance(idx, (slice, list, tuple, range)):
            return super().__getitem__(idx)

        # Build a fresh dict per call so transforms that modify their input in
        # place cannot corrupt the stored sample list.
        data = self._build_item(self.images[idx], self.labels[idx])

        # Apply transformations if specified
        if self.transform is not None:
            data = self.transform(data)
        return data

    def get_image_filename(self, idx: int) -> str:
        """Return the image filename (relative to ``images_dir``) at ``idx``."""
        return self.images[idx]

    def get_label_filename(self, idx: int) -> str:
        """Return the mask filename (relative to ``labels_dir``) at ``idx``."""
        return self.labels[idx]

    def get_dataset_info(self, dataset_name: str = "CVC_300") -> Dict[str, Any]:
        """Get dataset information.

        Args:
            dataset_name: Name reported for the dataset; it is not derived
                from ``root_dir``.

        Returns:
            Dictionary with the dataset name, split, sample count, directory
            paths (as ``str``) and whether a transform is configured.
        """
        return {
            "dataset_name": dataset_name,
            "split": self.flag,
            "num_samples": len(self),
            "root_dir": str(self.root_dir),
            "images_dir": str(self.images_dir),
            "labels_dir": str(self.labels_dir),
            "has_transform": self.transform is not None,
        }
|