builder.py 777 B

123456789101112131415161718192021222324252627
  1. from __future__ import annotations
  2. from typing import Any
  3. from .base import BaseTrainer
  4. from .supervised import SupervisedSegmentationTrainer
  5. TRAINER_REGISTRY = {
  6. "supervised_segmentation": SupervisedSegmentationTrainer,
  7. }
  8. def build_trainer(cfg: dict[str, Any], args: Any | None = None) -> BaseTrainer:
  9. trainer_cfg = cfg.get("trainer", {})
  10. trainer_name = trainer_cfg.get("name", "supervised_segmentation")
  11. trainer_cls = TRAINER_REGISTRY.get(trainer_name)
  12. if trainer_cls is None:
  13. raise ValueError(
  14. f"Unsupported trainer '{trainer_name}'. Expected one of: {', '.join(TRAINER_REGISTRY)}."
  15. )
  16. trainer = trainer_cls(cfg=cfg, args=args)
  17. trainer.build()
  18. return trainer
  19. __all__ = ["TRAINER_REGISTRY", "build_trainer"]