collate.py 997 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. from typing import Any
  4. import torch
  5. def record_collate_fn(batch: Sequence[dict[str, Any]]) -> dict[str, Any]:
  6. if not batch:
  7. raise ValueError("Empty batch is not allowed.")
  8. collated: dict[str, Any] = {}
  9. keys = batch[0].keys()
  10. for key in keys:
  11. values = [sample[key] for sample in batch]
  12. first = values[0]
  13. if torch.is_tensor(first):
  14. shapes = [tuple(value.shape) for value in values]
  15. if all(shape == shapes[0] for shape in shapes):
  16. collated[key] = torch.stack(values, dim=0)
  17. else:
  18. collated[key] = values
  19. continue
  20. if first is None:
  21. collated[key] = values
  22. continue
  23. if isinstance(first, (str, int, float, dict)):
  24. collated[key] = values
  25. continue
  26. collated[key] = values
  27. return collated
  28. __all__ = ["record_collate_fn"]