import collections.abc
from pathlib import Path
from typing import Callable, Optional, List, Dict, Any, Sequence, Union

from monai.data import Dataset


class PolypDetectionDataset(Dataset):
    """Polyp detection dataset for training and validation.

    Expected layout under ``root_dir``::

        root_dir/
            images/     # image files
            masks/      # label files, same filenames as in images/
            train.txt   # one image filename per line
            val.txt     # one image filename per line

    Each integer-indexed sample is a dict
    ``{"image": <image path>, "label": <mask path>}`` that is passed through
    ``transform`` (when given) before being returned.
    """

    def __init__(
        self,
        root_dir: str,
        flag: str = "train",
        transform: Optional[Callable] = None,
    ):
        """Initialize dataset.

        Args:
            root_dir: Root directory containing images, masks, and split files
            flag: Dataset split ('train' or 'val'), case-insensitive
            transform: Optional transformations to apply to each sample dict

        Raises:
            ValueError: If ``flag`` is not 'train' or 'val'.
            FileNotFoundError: If the images/masks directories or the split
                file ``<flag>.txt`` are missing.
        """
        # Set up directory paths
        self.root_dir = Path(root_dir)
        self.images_dir = self.root_dir / "images"
        self.labels_dir = self.root_dir / "masks"
        self.flag = flag.lower()
        self.transform = transform

        # Validate split flag
        if self.flag not in ["train", "val"]:
            raise ValueError(f"flag must be 'train' or 'val', got '{self.flag}'")

        # Check if directories exist
        if not self.images_dir.exists():
            raise FileNotFoundError(f"Image directory does not exist: {self.images_dir}")
        if not self.labels_dir.exists():
            raise FileNotFoundError(f"Label directory does not exist: {self.labels_dir}")

        # Load image filenames from split file
        txt_file = self.root_dir / f"{self.flag}.txt"
        if not txt_file.exists():
            raise FileNotFoundError(
                f"{self.flag}.txt file not found in {self.root_dir}\n"
                f"Please ensure train.txt and val.txt files exist in this directory"
            )

        with open(txt_file, "r", encoding="utf-8") as f:
            # Strip whitespace/newlines and drop blank lines in one pass.
            self.images: List[str] = [
                name for name in (line.strip() for line in f) if name
            ]

        # Labels have same filenames as images
        self.labels: List[str] = self.images.copy()

        # Create data list with image-label pairs; the MONAI base class uses
        # this list for slice / sequence indexing.
        data = [
            {"image": str(self.images_dir / img), "label": str(self.labels_dir / lbl)}
            for img, lbl in zip(self.images, self.labels)
        ]

        super().__init__(data=data, transform=transform)

    def __len__(self) -> int:
        """Return number of samples in dataset."""
        return len(self.images)

    def __getitem__(self, idx: Union[int, slice, Sequence[int]]) -> Any:
        """Get a single sample by index, or a subset for slice/sequence indices.

        Args:
            idx: Sample index. A ``slice`` or a sequence of indices is
                delegated to the MONAI base class, which returns a subset
                of this dataset rather than a single sample.

        Returns:
            For an integer index: a dictionary with 'image' and 'label'
            entries (file paths, or whatever ``transform`` produces from them).
            For a slice/sequence: the object returned by the base class.
        """
        if isinstance(idx, (slice, collections.abc.Sequence)):
            # Joining a Path with a list of filenames would raise TypeError,
            # so hand multi-index access to the base class.
            return super().__getitem__(idx)

        # Build a fresh dict per call so a transform that mutates its input
        # cannot alter the stored sample list between epochs.
        data = {
            "image": str(self.images_dir / self.images[idx]),
            "label": str(self.labels_dir / self.labels[idx]),
        }

        # Apply transformations if specified
        if self.transform is not None:
            data = self.transform(data)

        return data

    def get_image_filename(self, idx: int) -> str:
        """Get image filename (as listed in the split file) by index."""
        return self.images[idx]

    def get_label_filename(self, idx: int) -> str:
        """Get label filename by index (identical to the image filename)."""
        return self.labels[idx]

    def get_dataset_info(self, dataset_name: str = "CVC_300") -> Dict[str, Any]:
        """Get dataset information.

        Args:
            dataset_name: Name of the dataset, echoed back in the result

        Returns:
            Dictionary containing dataset metadata
        """
        return {
            "dataset_name": dataset_name,
            "split": self.flag,
            "num_samples": len(self),
            "root_dir": str(self.root_dir),
            "images_dir": str(self.images_dir),
            "labels_dir": str(self.labels_dir),
            "has_transform": self.transform is not None,
        }