gan_loss.py 6.3 KB

# (Line-number gutter of the source viewer, 1-172, dropped during extraction.)
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch import autograd
  5. import math
  6. from . import LOSS
  7. __all__ = ['GANLoss', 'GPLoss', 'R1Loss', 'PathLoss']
  8. @LOSS.register_module
  9. class GANLoss(nn.Module):
  10. def __init__(self, mode='hinge',
  11. change_label_p=0.0,
  12. one_side_label_smooth=0.0,
  13. lam=1):
  14. super(GANLoss, self).__init__()
  15. self.mode = mode
  16. self.change_label_p = change_label_p
  17. self.one_side_label_smooth = one_side_label_smooth
  18. self.lam = lam
  19. if self.mode not in ['bce', 'mse', 'hinge', 'wgan', 'logistic_saturating', 'logistic_nonsaturating', 'relativistic_gan']:
  20. raise NotImplementedError('gan loss {} is not implemented'.format(self.mode))
  21. def get_target_tensor(self, pred, tgt):
  22. shape = pred.shape
  23. # tgt = torch.full((B,), tgt, dtype=pred.dtype)
  24. tgt = torch.full((shape), tgt, dtype=pred.dtype)
  25. # random change label
  26. if self.change_label_p >= 0.0:
  27. is_not_change = (torch.rand(shape) > self.change_label_p)
  28. is_not_change = is_not_change.float()
  29. tgt = tgt * is_not_change + (1 - tgt) * (1 - is_not_change) # xnor
  30. # one side label smooth
  31. if self.one_side_label_smooth >= 0.0:
  32. # tgt_tensor = (tgt * 1 - torch.rand(shape) * self.one_side_label_smooth).abs() # [0~0.1, 0.9~1]
  33. tgt_tensor = (tgt * 1 - torch.rand(shape) * self.one_side_label_smooth) * tgt # [0, 0.9~1]
  34. # tgt_tensor = torch.max(tgt * 1 - torch.rand(B) * self.label_smooth) # to be modify: only applying to real image.
  35. else:
  36. tgt_tensor = tgt * 1
  37. # return tgt_tensor.cuda(pred.device)
  38. return tgt_tensor.expand_as(pred).cuda(pred.device)
  39. def call_one(self, pred, should_be_classified_as_real):
  40. if self.mode == 'logistic_nonsaturating':
  41. loss = F.softplus(-pred).mean() if should_be_classified_as_real else F.softplus(pred).mean()
  42. else:
  43. raise 'invalid loss mode: {}'.format(self.loss_mode)
  44. return loss
  45. def __call__(self, pred_fake=None, pred_real=None, isD=True):
  46. if pred_fake is None:
  47. raise ValueError('meaningless input for GAN loss')
  48. loss = 0
  49. if self.mode == 'bce':
  50. if isD:
  51. loss_real = nn.BCEWithLogitsLoss()(pred_real, self.get_target_tensor(pred_real, 1.0))
  52. loss_fake = nn.BCEWithLogitsLoss()(pred_fake, self.get_target_tensor(pred_fake, 0.0))
  53. loss = loss_real + loss_fake
  54. else:
  55. loss_fake = nn.BCEWithLogitsLoss()(pred_fake, self.get_target_tensor(pred_fake, 1.0))
  56. loss = loss_fake
  57. elif self.mode == 'mse':
  58. if isD:
  59. loss_real = nn.MSELoss()(pred_real, self.get_target_tensor(pred_real, 1.0))
  60. loss_fake = nn.MSELoss()(pred_fake, self.get_target_tensor(pred_fake, 0.0))
  61. loss = loss_real + loss_fake
  62. else:
  63. loss_fake = nn.MSELoss()(pred_fake, self.get_target_tensor(pred_fake, 1.0))
  64. loss = loss_fake
  65. elif self.mode == 'hinge':
  66. if isD:
  67. loss_real = nn.ReLU()(1.0 - pred_real).mean()
  68. loss_fake = nn.ReLU()(1.0 + pred_fake).mean()
  69. loss = loss_real + loss_fake
  70. else:
  71. loss_fake = -pred_fake.mean()
  72. loss = loss_fake
  73. elif self.mode == 'wgan':
  74. if isD:
  75. loss_real = -pred_real.mean()
  76. loss_fake = pred_fake.mean()
  77. loss = loss_real + loss_fake
  78. else:
  79. loss_fake = -pred_fake.mean()
  80. loss = loss_fake
  81. elif self.mode == 'logistic_saturating':
  82. if isD:
  83. loss_real = F.softplus(-pred_real).mean() # log(1+exp(x))
  84. loss_fake = F.softplus(pred_fake).mean()
  85. loss = loss_real + loss_fake
  86. else:
  87. loss_fake = -F.softplus(pred_fake).mean()
  88. loss = loss_fake
  89. elif self.mode == 'logistic_nonsaturating':
  90. if isD:
  91. loss_real = F.softplus(-pred_real).mean() # log(1+exp(x))
  92. loss_fake = F.softplus(pred_fake).mean()
  93. loss = loss_real + loss_fake
  94. else:
  95. loss_fake = F.softplus(-pred_fake).mean()
  96. loss = loss_fake
  97. elif self.mode == 'relativistic_gan':
  98. if isD:
  99. loss_real = nn.BCEWithLogitsLoss()(pred_real - pred_fake.mean(0, keepdim=True), torch.ones_like(pred_real))
  100. loss_fake = nn.BCEWithLogitsLoss()(pred_fake - pred_real.mean(0, keepdim=True), torch.zeros_like(pred_real))
  101. loss = loss_real + loss_fake
  102. else:
  103. loss_fake = nn.BCEWithLogitsLoss()(pred_fake - pred_real.mean(0, keepdim=True), torch.ones_like(pred_real))
  104. loss = loss_fake
  105. return loss * self.lam
  106. @LOSS.register_module
  107. class GPLoss(nn.Module):
  108. def __init__(self,
  109. lam=1):
  110. super(GPLoss, self).__init__()
  111. self.lam = lam
  112. def forward(self, netD, real_data, fake_data):
  113. batch_size = real_data.size()[0]
  114. LAMBDA = 1
  115. alpha = torch.rand(batch_size, 1, 1, 1)
  116. alpha = alpha.expand_as(real_data).cuda()
  117. interpolates = alpha * real_data + (1 - alpha) * fake_data
  118. interpolates = interpolates.cuda()
  119. interpolates = autograd.Variable(interpolates, requires_grad=True)
  120. disc_interpolates = netD(interpolates)
  121. gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
  122. grad_outputs=torch.ones(disc_interpolates.size()).cuda(),
  123. create_graph=True, retain_graph=True, only_inputs=True)[0]
  124. gradients = gradients.view(gradients.size(0), -1)
  125. gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
  126. return gradient_penalty * self.lam
  127. @LOSS.register_module
  128. class R1Loss(nn.Module):
  129. def __init__(self, lam=1):
  130. super(R1Loss, self).__init__()
  131. self.lam = lam
  132. def forward(self, images, output):
  133. gradients = autograd.grad(outputs=output, inputs=images, grad_outputs=torch.ones(output.size(), device=images.device),
  134. create_graph=True, retain_graph=True, only_inputs=True)[0].view(images.size(0), -1)
  135. r1_penalty = torch.sum(gradients.pow(2)).mean()
  136. # with no_weight_gradients():
  137. # grad_real, = autograd.grad(
  138. # outputs=output.sum(), inputs=images, create_graph=True
  139. # )
  140. # r1_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
  141. return r1_penalty * self.lam
  142. @LOSS.register_module
  143. class PathLoss(nn.Module):
  144. def __init__(self,
  145. lam=1):
  146. super(PathLoss, self).__init__()
  147. self.lam = lam
  148. def forward(self, img_fake, latent, mean_path_length, decay=0.01):
  149. noise = torch.randn_like(img_fake) / math.sqrt(img_fake.shape[2] * img_fake.shape[3])
  150. grad, = autograd.grad(outputs=(img_fake * noise).sum(), inputs=latent, create_graph=True)
  151. path_length = torch.sqrt(grad.pow(2).sum(2).mean(1))
  152. mean_path_length_out = mean_path_length + decay * (path_length.mean() - mean_path_length)
  153. path_penalty = (path_length - mean_path_length_out).pow(2).mean()
  154. return path_penalty * self.lam, path_length, mean_path_length_out.detach()