| 12345678910111213141516171819202122232425262728293031323334353637383940 |
- from __future__ import annotations
- from collections.abc import Sequence
- from typing import Any
- import torch
def record_collate_fn(batch: Sequence[dict[str, Any]]) -> dict[str, Any]:
    """Collate a batch of per-sample record dicts into a single dict of fields.

    The keys of the result are the keys of ``batch[0]``. For each key:

    * if every sample's value is a tensor and all their shapes are equal, the
      values are combined with ``torch.stack`` along a new dimension 0;
    * otherwise (ragged tensor shapes, a mix of tensors and non-tensors,
      ``None``, strings, numbers, dicts, or any other object) the values are
      returned unchanged as a plain list in batch order.

    Keys that appear only in later samples are ignored. Errors raised by
    ``torch.stack`` itself (e.g. for incompatible tensors) propagate unchanged.

    Args:
        batch: Non-empty sequence of sample dicts.

    Returns:
        Dict mapping each key of ``batch[0]`` to either a stacked tensor or a
        list of the per-sample values.

    Raises:
        ValueError: If ``batch`` is empty.
        KeyError: If a sample lacks a key that is present in ``batch[0]``.
    """
    if not batch:
        raise ValueError("Empty batch is not allowed.")

    collated: dict[str, Any] = {}
    for key in batch[0]:
        try:
            values = [sample[key] for sample in batch]
        except KeyError as err:
            raise KeyError(
                f"Key {key!r} is present in the first sample but missing "
                "from at least one other sample in the batch."
            ) from err

        # Require *every* value to be a tensor before touching `.shape`;
        # checking only the first value would crash on a later non-tensor.
        stackable = all(torch.is_tensor(value) for value in values) and (
            len({tuple(value.shape) for value in values}) == 1
        )
        # All non-stackable cases (ragged tensors, None, scalars, strings,
        # dicts, arbitrary objects) are kept as a per-sample list.
        collated[key] = torch.stack(values, dim=0) if stackable else values
    return collated
# Public API of this module: only the collate function is exported.
__all__ = ["record_collate_fn"]
|