|
@@ -6,53 +6,46 @@ from typing import Any
|
|
|
import torch
|
|
import torch
|
|
|
from torch.utils.data import DataLoader
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
-from lib.modules import SegmentationNet2d
|
|
|
|
|
-from lib.tools import (
|
|
|
|
|
- BinaryBoundaryLoss,
|
|
|
|
|
- MaskBoundaryConsistencyLoss,
|
|
|
|
|
- build_optimizer,
|
|
|
|
|
- build_scheduler,
|
|
|
|
|
- mask_to_boundary_map,
|
|
|
|
|
-)
|
|
|
|
|
|
|
+from lib.modules import SegmentationModel2d
|
|
|
|
|
+from lib.tools import build_loss, build_optimizer, build_scheduler
|
|
|
from .base import BaseTrainer
|
|
from .base import BaseTrainer
|
|
|
|
|
|
|
|
|
|
|
|
|
class SupervisedSegmentationTrainer(BaseTrainer):
|
|
class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
|
|
def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
|
|
|
super().__init__(cfg=cfg, args=args)
|
|
super().__init__(cfg=cfg, args=args)
|
|
|
- self.model: SegmentationNet2d | None = None
|
|
|
|
|
|
|
+ self.model: SegmentationModel2d | None = None
|
|
|
self.optimizer = None
|
|
self.optimizer = None
|
|
|
self.scheduler = None
|
|
self.scheduler = None
|
|
|
self.loader: DataLoader | None = None
|
|
self.loader: DataLoader | None = None
|
|
|
self.val_loader: DataLoader | None = None
|
|
self.val_loader: DataLoader | None = None
|
|
|
self.seg_loss = None
|
|
self.seg_loss = None
|
|
|
- self.boundary_loss = BinaryBoundaryLoss()
|
|
|
|
|
- self.consistency_loss = MaskBoundaryConsistencyLoss()
|
|
|
|
|
|
|
|
|
|
def build(self) -> None:
|
|
def build(self) -> None:
|
|
|
dataset_cfg = self.cfg["dataset"]
|
|
dataset_cfg = self.cfg["dataset"]
|
|
|
model_cfg = self.cfg["model"]
|
|
model_cfg = self.cfg["model"]
|
|
|
train_cfg = self.cfg["train"]
|
|
train_cfg = self.cfg["train"]
|
|
|
|
|
|
|
|
- self.model = SegmentationNet2d(
|
|
|
|
|
|
|
+ self.model = SegmentationModel2d(
|
|
|
num_classes=dataset_cfg["num_classes"],
|
|
num_classes=dataset_cfg["num_classes"],
|
|
|
model_name=model_cfg["model_name"],
|
|
model_name=model_cfg["model_name"],
|
|
|
load_weights=model_cfg.get("load_weights", False),
|
|
load_weights=model_cfg.get("load_weights", False),
|
|
|
decoder_channels=model_cfg.get("decoder_channels"),
|
|
decoder_channels=model_cfg.get("decoder_channels"),
|
|
|
- fwta_wavelet=model_cfg.get("fwta_wavelet", "haar"),
|
|
|
|
|
- fwta_level=model_cfg.get("fwta_level", 1),
|
|
|
|
|
- fwta_sigma_ratio=model_cfg.get("fwta_sigma_ratio", 0.35),
|
|
|
|
|
- fwta_tau_fourier=model_cfg.get("fwta_tau_fourier", 0.15),
|
|
|
|
|
- fwta_gate_temperature=model_cfg.get("fwta_gate_temperature", 1.0),
|
|
|
|
|
|
|
+ use_multiscale_features=model_cfg.get("use_multiscale_features", True),
|
|
|
|
|
+ include_patch_embed=model_cfg.get("include_patch_embed", True),
|
|
|
).to(self.device)
|
|
).to(self.device)
|
|
|
|
|
|
|
|
self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
|
|
self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
|
|
|
self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
|
|
self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
|
|
|
|
|
+ loss_cfg = self.cfg.get("loss")
|
|
|
|
|
+ if loss_cfg is not None:
|
|
|
|
|
+ self.seg_loss = build_loss(loss_cfg)
|
|
|
self.loader = self._build_segmentation_loader(
|
|
self.loader = self._build_segmentation_loader(
|
|
|
split=str(dataset_cfg.get("split", "train")),
|
|
split=str(dataset_cfg.get("split", "train")),
|
|
|
split_file=dataset_cfg.get("split_file"),
|
|
split_file=dataset_cfg.get("split_file"),
|
|
|
batch_size=self._resolve_batch_size("batch_size", 4),
|
|
batch_size=self._resolve_batch_size("batch_size", 4),
|
|
|
shuffle=bool(train_cfg.get("shuffle", True)),
|
|
shuffle=bool(train_cfg.get("shuffle", True)),
|
|
|
|
|
+ augmentation_config=self.cfg.get("augmentation", {}).get("train"),
|
|
|
)
|
|
)
|
|
|
self.val_loader = self._build_val_loader(
|
|
self.val_loader = self._build_val_loader(
|
|
|
batch_size=self._resolve_batch_size(
|
|
batch_size=self._resolve_batch_size(
|
|
@@ -78,19 +71,17 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
|
|
with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
|
|
|
outputs = self.model(image)
|
|
outputs = self.model(image)
|
|
|
seg_logits = outputs["seg_logits"]
|
|
seg_logits = outputs["seg_logits"]
|
|
|
- boundary_logits = outputs["boundary_logits"]
|
|
|
|
|
|
|
|
|
|
- seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
|
|
|
|
|
- boundary_target = mask_to_boundary_map(mask)
|
|
|
|
|
- boundary_loss = self.boundary_loss(boundary_logits, boundary_target)
|
|
|
|
|
- consistency_loss = self.consistency_loss(seg_logits, boundary_logits)
|
|
|
|
|
- total_loss = seg_loss + boundary_loss + 0.1 * consistency_loss
|
|
|
|
|
|
|
+ if self.seg_loss is None:
|
|
|
|
|
+ seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
|
|
|
|
|
+ else:
|
|
|
|
|
+ seg_loss = self.seg_loss(seg_logits, mask)
|
|
|
|
|
+
|
|
|
|
|
+ total_loss = seg_loss
|
|
|
|
|
|
|
|
losses = {
|
|
losses = {
|
|
|
"total": total_loss,
|
|
"total": total_loss,
|
|
|
"seg": seg_loss,
|
|
"seg": seg_loss,
|
|
|
- "boundary": boundary_loss,
|
|
|
|
|
- "consistency": consistency_loss,
|
|
|
|
|
}
|
|
}
|
|
|
return outputs, losses
|
|
return outputs, losses
|
|
|
|
|
|
|
@@ -106,8 +97,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
metrics = self._build_validation_metrics()
|
|
metrics = self._build_validation_metrics()
|
|
|
total = 0.0
|
|
total = 0.0
|
|
|
seg = 0.0
|
|
seg = 0.0
|
|
|
- boundary = 0.0
|
|
|
|
|
- consistency = 0.0
|
|
|
|
|
steps = 0
|
|
steps = 0
|
|
|
with torch.no_grad():
|
|
with torch.no_grad():
|
|
|
for batch in self.val_loader:
|
|
for batch in self.val_loader:
|
|
@@ -116,8 +105,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
outputs, losses = self._compute_losses(image, mask)
|
|
outputs, losses = self._compute_losses(image, mask)
|
|
|
total += float(losses["total"].detach().cpu())
|
|
total += float(losses["total"].detach().cpu())
|
|
|
seg += float(losses["seg"].detach().cpu())
|
|
seg += float(losses["seg"].detach().cpu())
|
|
|
- boundary += float(losses["boundary"].detach().cpu())
|
|
|
|
|
- consistency += float(losses["consistency"].detach().cpu())
|
|
|
|
|
self._update_validation_metrics(
|
|
self._update_validation_metrics(
|
|
|
metrics,
|
|
metrics,
|
|
|
logits=outputs["seg_logits"],
|
|
logits=outputs["seg_logits"],
|
|
@@ -130,8 +117,6 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
val_metrics = {
|
|
val_metrics = {
|
|
|
"total": total / steps,
|
|
"total": total / steps,
|
|
|
"seg": seg / steps,
|
|
"seg": seg / steps,
|
|
|
- "boundary": boundary / steps,
|
|
|
|
|
- "consistency": consistency / steps,
|
|
|
|
|
}
|
|
}
|
|
|
val_metrics.update(self._compute_validation_metric_values(metrics))
|
|
val_metrics.update(self._compute_validation_metric_values(metrics))
|
|
|
return val_metrics
|
|
return val_metrics
|
|
@@ -141,6 +126,7 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
raise RuntimeError("Trainer.build() must be called before train().")
|
|
raise RuntimeError("Trainer.build() must be called before train().")
|
|
|
|
|
|
|
|
epochs = int(self.cfg["train"].get("epochs", 1))
|
|
epochs = int(self.cfg["train"].get("epochs", 1))
|
|
|
|
|
+ accum_steps = self._accum_steps()
|
|
|
try:
|
|
try:
|
|
|
self._print_training_setup(
|
|
self._print_training_setup(
|
|
|
model_map={"model": self.model},
|
|
model_map={"model": self.model},
|
|
@@ -150,11 +136,10 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
)
|
|
)
|
|
|
for epoch in range(self.start_epoch, epochs):
|
|
for epoch in range(self.start_epoch, epochs):
|
|
|
self.model.train()
|
|
self.model.train()
|
|
|
|
|
+ self.optimizer.zero_grad()
|
|
|
train_metric_sums = {
|
|
train_metric_sums = {
|
|
|
"total": 0.0,
|
|
"total": 0.0,
|
|
|
"seg": 0.0,
|
|
"seg": 0.0,
|
|
|
- "boundary": 0.0,
|
|
|
|
|
- "consistency": 0.0,
|
|
|
|
|
}
|
|
}
|
|
|
train_metrics: dict[str, float] | None = None
|
|
train_metrics: dict[str, float] | None = None
|
|
|
end_time = time.perf_counter()
|
|
end_time = time.perf_counter()
|
|
@@ -165,14 +150,17 @@ class SupervisedSegmentationTrainer(BaseTrainer):
|
|
|
image = batch["image"].to(self.device)
|
|
image = batch["image"].to(self.device)
|
|
|
mask = batch["mask"].to(self.device)
|
|
mask = batch["mask"].to(self.device)
|
|
|
_, losses = self._compute_losses(image, mask)
|
|
_, losses = self._compute_losses(image, mask)
|
|
|
- self.optimizer.zero_grad()
|
|
|
|
|
- self.grad_scaler.scale(losses["total"]).backward()
|
|
|
|
|
|
|
+ scaled_total_loss = losses["total"] / accum_steps
|
|
|
|
|
+ self.grad_scaler.scale(scaled_total_loss).backward()
|
|
|
grad_norm = None
|
|
grad_norm = None
|
|
|
- if self._grad_clip_enabled():
|
|
|
|
|
- self.grad_scaler.unscale_(self.optimizer)
|
|
|
|
|
- grad_norm = self._clip_gradients(self.model)
|
|
|
|
|
- self.grad_scaler.step(self.optimizer)
|
|
|
|
|
- self.grad_scaler.update()
|
|
|
|
|
|
|
+ should_step = (step % accum_steps == 0) or (step == num_steps)
|
|
|
|
|
+ if should_step:
|
|
|
|
|
+ if self._grad_clip_enabled():
|
|
|
|
|
+ self.grad_scaler.unscale_(self.optimizer)
|
|
|
|
|
+ grad_norm = self._clip_gradients(self.model)
|
|
|
|
|
+ self.grad_scaler.step(self.optimizer)
|
|
|
|
|
+ self.grad_scaler.update()
|
|
|
|
|
+ self.optimizer.zero_grad()
|
|
|
train_metrics = self._detach_metrics(losses)
|
|
train_metrics = self._detach_metrics(losses)
|
|
|
if grad_norm is not None:
|
|
if grad_norm is not None:
|
|
|
train_metrics["grad_norm"] = grad_norm
|
|
train_metrics["grad_norm"] = grad_norm
|