base_loss.py 578 B

123456789101112131415161718192021222324
  1. import torch.nn as nn
  2. from . import LOSS
  3. __all__ = ['L1Loss', 'MSELoss']
  4. @LOSS.register_module
  5. class L1Loss(nn.L1Loss):
  6. def __init__(self, lam=1):
  7. super(L1Loss, self).__init__()
  8. self.lam = lam
  9. def forward(self, input, target):
  10. return super(L1Loss, self).forward(input, target) * self.lam
  11. @LOSS.register_module
  12. class MSELoss(nn.MSELoss):
  13. def __init__(self, lam=1):
  14. super(MSELoss, self).__init__()
  15. self.lam = lam
  16. def forward(self, input, target):
  17. return super(MSELoss, self).forward(input, target) * self.lam