| 123456789101112131415161718192021222324252627 |
- from __future__ import annotations
- from typing import Any
- from .base import BaseTrainer
- from .supervised import SupervisedSegmentationTrainer
# Lookup table from the ``trainer.name`` config value to the trainer class that
# implements it. ``build_trainer`` resolves names through this mapping, so a new
# trainer only becomes selectable from config once it is registered here.
TRAINER_REGISTRY: dict[str, type[BaseTrainer]] = {
    "supervised_segmentation": SupervisedSegmentationTrainer,
}
def build_trainer(
    cfg: dict[str, Any],
    args: Any | None = None,
    *,
    registry: dict[str, type[BaseTrainer]] | None = None,
) -> BaseTrainer:
    """Instantiate and build the trainer selected by ``cfg["trainer"]["name"]``.

    Args:
        cfg: Full run configuration. Only the optional ``"trainer"`` section is
            read here; the whole mapping is forwarded to the trainer constructor
            as ``cfg=``.
        args: Opaque extra arguments forwarded unchanged to the trainer
            constructor as ``args=`` (presumably parsed CLI args — confirm).
        registry: Name -> trainer-class lookup table. Defaults to the module's
            ``TRAINER_REGISTRY``; pass a custom mapping to plug in additional
            trainers (or test stubs) without mutating module-level state.

    Returns:
        The trainer instance, after its ``build()`` method has been called.

    Raises:
        ValueError: If the requested trainer name is not present in the registry.
    """
    if registry is None:
        registry = TRAINER_REGISTRY

    # A config section written as ``trainer:`` with no body (e.g. in YAML) loads
    # as None rather than {}; treat it, and a null ``name``, as "use the default"
    # instead of crashing with AttributeError on ``None.get``.
    trainer_cfg = cfg.get("trainer")
    if trainer_cfg is None:
        trainer_cfg = {}
    trainer_name = trainer_cfg.get("name")
    if trainer_name is None:
        trainer_name = "supervised_segmentation"

    trainer_cls = registry.get(trainer_name)
    if trainer_cls is None:
        raise ValueError(
            f"Unsupported trainer '{trainer_name}'. Expected one of: {', '.join(registry)}."
        )

    trainer = trainer_cls(cfg=cfg, args=args)
    # Construction and build() are separate steps on the trainer; run both so
    # callers receive a ready-to-use instance.
    trainer.build()
    return trainer
# Public surface of this module: the name -> class registry and the factory.
__all__ = [
    "TRAINER_REGISTRY",
    "build_trainer",
]
|