"""Inspect a dataset and the first batch produced by its dataloader.

Builds a loader with ``lib.data.build_dataloader`` (``shuffle=False``,
``num_workers=0``), then prints the dataset length, the shapes of the first
batch's ``image`` and ``mask`` entries, and its ``sample_id`` / ``split`` /
``class_name`` entries.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Put the repository root on sys.path so ``lib`` is importable when this file
# is run directly as a script.
# NOTE(review): ``parents[3]`` assumes the repo root is the fourth ancestor of
# this file's resolved path — confirm if the script is relocated.
ROOT_DIR = Path(__file__).resolve().parents[3]
if str(ROOT_DIR) not in sys.path:
    sys.path.insert(0, str(ROOT_DIR))

from lib.data import build_dataloader  # noqa: E402  (requires the sys.path setup above)


def _print_shapes(label: str, value, limit: int) -> None:
    """Print the shape(s) of one batch field as ``<label>_shape(s): ...``.

    A value exposing ``.shape`` (e.g. an already-stacked array/tensor) is
    printed as a single shape tuple under ``<label>_shape:``. Otherwise the
    value is treated as a sliceable sequence of per-sample items and the
    shapes of at most *limit* of them are printed under ``<label>_shapes:``;
    ``None`` items are reported as ``None``.
    """
    if hasattr(value, "shape"):
        print(f"{label}_shape:", tuple(value.shape))
        return
    # Slicing already clamps to the sequence length, so no min() is needed.
    shapes = [None if item is None else tuple(item.shape) for item in value[:limit]]
    print(f"{label}_shapes:", shapes)


def main() -> None:
    """Parse CLI arguments, load the first batch and print a summary of it.

    Raises:
        SystemExit: on invalid arguments (``--batch-size`` or
            ``--num-samples`` below 1), or when the dataloader yields no
            batches at all.
    """
    parser = argparse.ArgumentParser(description="Inspect dataset and dataloader.")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--root", required=True, help="Dataset root")
    parser.add_argument("--split", default=None, help="Optional split name")
    parser.add_argument("--split-file", default=None, help="Optional official split file")
    parser.add_argument("--batch-size", type=int, default=2, help="Batch size (must be >= 1)")
    parser.add_argument(
        "--num-samples",
        type=int,
        default=3,
        help="Max number of per-sample shapes to print for unstacked fields (default: 3)",
    )
    args = parser.parse_args()

    # Reject nonsensical values up front instead of passing them to the loader.
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")
    if args.num_samples < 1:
        parser.error("--num-samples must be >= 1")

    loader = build_dataloader(
        dataset_name=args.dataset,
        root=args.root,
        split=args.split,
        split_file=args.split_file,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
    )

    try:
        batch = next(iter(loader))
    except StopIteration:
        # An empty dataset/split would otherwise surface as a bare
        # StopIteration traceback with no explanation.
        print("dataset_len:", len(loader.dataset))
        sys.exit("error: dataloader produced no batches (is the dataset/split empty?)")

    print("dataset_len:", len(loader.dataset))
    _print_shapes("image", batch["image"], args.num_samples)
    _print_shapes("mask", batch["mask"], args.num_samples)
    print("sample_ids:", batch["sample_id"])
    print("splits:", batch["split"])
    print("class_names:", batch["class_name"])


if __name__ == "__main__":
    main()