| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- from __future__ import annotations
- import argparse
- import sys
- from pathlib import Path
# Put the repository root (three directories above this script's folder) at the
# front of sys.path so the project-local `lib` package imported below resolves
# when this file is run directly as a script.
ROOT_DIR = Path(__file__).resolve().parents[3]
_root_entry = str(ROOT_DIR)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)
del _root_entry  # keep the module namespace limited to ROOT_DIR
- from lib.data import build_dataloader
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the inspection script.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed namespace.

    Exits via ``parser.error`` (status 2) when ``--batch-size`` is not a
    positive integer, rather than forwarding an invalid value to the loader.
    """
    parser = argparse.ArgumentParser(description="Inspect dataset and dataloader.")
    parser.add_argument("--dataset", required=True, help="Dataset name")
    parser.add_argument("--root", required=True, help="Dataset root")
    parser.add_argument("--split", default=None, help="Optional split name")
    parser.add_argument("--split-file", default=None, help="Optional official split file")
    parser.add_argument(
        "--batch-size", type=int, default=2, help="Samples per batch (default: 2)"
    )
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    return args


def _print_shapes(label: str, value, limit: int = 3) -> None:
    """Print the shape of one batch entry, or the shapes of its first items.

    If *value* has a ``shape`` attribute (a single stacked array/tensor), print
    ``"<label>_shape:"`` followed by that shape as a tuple. Otherwise *value* is
    treated as a sequence of per-sample items, and ``"<label>_shapes:"`` is
    printed with the shapes of at most *limit* of them. ``None`` items are
    reported as ``None`` instead of raising.
    """
    if hasattr(value, "shape"):
        print(f"{label}_shape:", tuple(value.shape))
        return
    # Slicing already clamps to the sequence length, so no min() is needed.
    shapes = [None if item is None else tuple(item.shape) for item in value[:limit]]
    print(f"{label}_shapes:", shapes)


def main(argv: list[str] | None = None) -> None:
    """Build a dataloader from CLI options and summarize its first batch.

    Prints the dataset length, the shape(s) of the batch's ``image`` and
    ``mask`` entries, and its ``sample_id``, ``split`` and ``class_name``
    entries.

    Args:
        argv: Arguments to parse; ``None`` (the default) reads ``sys.argv``.

    Exits with a non-zero status and an error message if the loader yields
    no batches.
    """
    args = _parse_args(argv)
    loader = build_dataloader(
        dataset_name=args.dataset,
        root=args.root,
        split=args.split,
        split_file=args.split_file,
        batch_size=args.batch_size,
        shuffle=False,  # keep dataset order so the inspected batch is reproducible
        num_workers=0,  # presumably: load in-process so errors surface directly
    )

    # Report the size before fetching, so an empty dataset is still visible.
    print("dataset_len:", len(loader.dataset))
    try:
        batch = next(iter(loader))
    except StopIteration:
        sys.exit("error: dataloader yielded no batches; is the dataset/split empty?")

    _print_shapes("image", batch["image"])
    _print_shapes("mask", batch["mask"])
    print("sample_ids:", batch["sample_id"])
    print("splits:", batch["split"])
    print("class_names:", batch["class_name"])


if __name__ == "__main__":
    main()
|