# inspect_dataloader.py
  1. from __future__ import annotations
  2. import argparse
  3. import sys
  4. from pathlib import Path
  5. ROOT_DIR = Path(__file__).resolve().parents[3]
  6. if str(ROOT_DIR) not in sys.path:
  7. sys.path.insert(0, str(ROOT_DIR))
  8. from lib.data import build_dataloader
  9. def main() -> None:
  10. parser = argparse.ArgumentParser(description="Inspect dataset and dataloader.")
  11. parser.add_argument("--dataset", required=True, help="Dataset name")
  12. parser.add_argument("--root", required=True, help="Dataset root")
  13. parser.add_argument("--split", default=None, help="Optional split name")
  14. parser.add_argument("--split-file", default=None, help="Optional official split file")
  15. parser.add_argument("--batch-size", type=int, default=2)
  16. args = parser.parse_args()
  17. loader = build_dataloader(
  18. dataset_name=args.dataset,
  19. root=args.root,
  20. split=args.split,
  21. split_file=args.split_file,
  22. batch_size=args.batch_size,
  23. shuffle=False,
  24. num_workers=0,
  25. )
  26. batch = next(iter(loader))
  27. print("dataset_len:", len(loader.dataset))
  28. image_value = batch["image"]
  29. mask_value = batch["mask"]
  30. if hasattr(image_value, "shape"):
  31. print("image_shape:", tuple(image_value.shape))
  32. else:
  33. print("image_shapes:", [tuple(item.shape) for item in image_value[: min(3, len(image_value))]])
  34. if hasattr(mask_value, "shape"):
  35. print("mask_shape:", tuple(mask_value.shape))
  36. else:
  37. sample_masks = []
  38. for item in mask_value[: min(3, len(mask_value))]:
  39. sample_masks.append(None if item is None else tuple(item.shape))
  40. print("mask_shapes:", sample_masks)
  41. print("sample_ids:", batch["sample_id"])
  42. print("splits:", batch["split"])
  43. print("class_names:", batch["class_name"])
  44. if __name__ == "__main__":
  45. main()