from __future__ import annotations from collections.abc import Sequence from typing import Any import torch def record_collate_fn(batch: Sequence[dict[str, Any]]) -> dict[str, Any]: if not batch: raise ValueError("Empty batch is not allowed.") collated: dict[str, Any] = {} keys = batch[0].keys() for key in keys: values = [sample[key] for sample in batch] first = values[0] if torch.is_tensor(first): shapes = [tuple(value.shape) for value in values] if all(shape == shapes[0] for shape in shapes): collated[key] = torch.stack(values, dim=0) else: collated[key] = values continue if first is None: collated[key] = values continue if isinstance(first, (str, int, float, dict)): collated[key] = values continue collated[key] = values return collated __all__ = ["record_collate_fn"]