cls_loss.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import torchvision.transforms as T
  5. import torchvision.transforms.functional as F_tv
  6. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  7. from . import LOSS
  8. from model import get_model
  9. __all__ = ['CE', 'LabelSmoothingCE', 'SoftTargetCE', 'CLSKDLoss']
  10. @LOSS.register_module
  11. class CE(nn.CrossEntropyLoss):
  12. def __init__(self, lam=1):
  13. super(CE, self).__init__()
  14. self.lam = lam
  15. def forward(self, input, target):
  16. return super(CE, self).forward(input, target) * self.lam
  17. @LOSS.register_module
  18. class LabelSmoothingCE(nn.Module):
  19. """
  20. NLL loss with label smoothing.
  21. """
  22. def __init__(self, smoothing=0.1, lam=1):
  23. """
  24. Constructor for the LabelSmoothing module.
  25. :param smoothing: label smoothing factor
  26. """
  27. super(LabelSmoothingCE, self).__init__()
  28. assert smoothing < 1.0
  29. self.smoothing = smoothing
  30. self.lam = lam
  31. self.confidence = 1. - smoothing
  32. def forward(self, x, target):
  33. logprobs = F.log_softmax(x, dim=-1)
  34. nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
  35. nll_loss = nll_loss.squeeze(1)
  36. smooth_loss = -logprobs.mean(dim=-1)
  37. loss = self.confidence * nll_loss + self.smoothing * smooth_loss
  38. return loss.mean() * self.lam
  39. @LOSS.register_module
  40. class SoftTargetCE(nn.Module):
  41. def __init__(self, lam=1, fp32=False):
  42. super(SoftTargetCE, self).__init__()
  43. self.lam = lam
  44. self.fp32 = fp32
  45. def forward(self, x, target):
  46. if self.fp32:
  47. loss = torch.sum(-target * F.log_softmax(x.float(), dim=-1), dim=-1)
  48. else:
  49. loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
  50. return loss.mean() * self.lam
  51. @LOSS.register_module
  52. class CLSKDLoss(torch.nn.Module):
  53. def __init__(self, cfg, kd_type='soft', size=224, mean_t=IMAGENET_DEFAULT_MEAN, std_t=IMAGENET_DEFAULT_STD,
  54. mean_s=IMAGENET_DEFAULT_MEAN, std_s=IMAGENET_DEFAULT_STD, tau=1.0, lam=1):
  55. super().__init__()
  56. self.teacher_model = get_model(cfg)
  57. self.teacher_model.cuda()
  58. self.teacher_model.eval()
  59. assert kd_type in ['soft', 'hard']
  60. self.kd_type = kd_type
  61. self.size = size
  62. self.mean_t, self.std_t = mean_t, std_t
  63. self.mean_s, self.std_s = mean_s, std_s
  64. self.tau = tau
  65. self.lam = lam
  66. def forward(self, outputs_kd, inputs):
  67. with torch.no_grad():
  68. if self.mean_t != self.mean_s or self.std_t != self.std_s:
  69. # std = [std_t / std_s for std_t, std_s in zip(self.std_t, self.std_s)]
  70. # transform_std = T.Normalize(self.mean_t, std=std)
  71. # mean = [mean_t / mean_s for mean_t, mean_s in zip(self.mean_t, self.mean_s)]
  72. # transform_mean = T.Normalize(mean=mean, std=self.std_t)
  73. # inputs = transform_mean(transform_std(inputs))
  74. mean_t = torch.as_tensor(self.mean_t, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
  75. std_t = torch.as_tensor(self.std_t, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
  76. mean_s = torch.as_tensor(self.mean_s, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
  77. std_s = torch.as_tensor(self.std_s, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
  78. inputs = inputs.clone()
  79. inputs.mul_(std_s).add_(mean_s).sub_(mean_t).div_(std_t)
  80. B, C, H, W = inputs.shape
  81. if H != self.size:
  82. inputs = F_tv.resize(inputs, self.size, F_tv.InterpolationMode.BICUBIC)
  83. teacher_outputs = self.teacher_model(inputs)
  84. if self.kd_type == 'soft':
  85. distillation_loss = F.kl_div(F.log_softmax(outputs_kd / self.tau, dim=1),
  86. F.log_softmax(teacher_outputs / self.tau, dim=1),
  87. reduction='sum', log_target=True) * (self.tau * self.tau) / outputs_kd.shape[0]
  88. elif self.kd_type == 'hard':
  89. distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
  90. else:
  91. raise ValueError(f'invalid distillation type: {self.kd_type}')
  92. return distillation_loss * self.lam