| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import autograd
- import math
- from . import LOSS
# Public API of this module: the loss classes defined below.
__all__ = [
    'GANLoss',
    'GPLoss',
    'R1Loss',
    'PathLoss',
]
@LOSS.register_module
class GANLoss(nn.Module):
    """Adversarial loss for GAN discriminators and generators.

    Args:
        mode: Objective to use; one of ``'bce'``, ``'mse'``, ``'hinge'``,
            ``'wgan'``, ``'logistic_saturating'``,
            ``'logistic_nonsaturating'`` or ``'relativistic_gan'``.
        change_label_p: Per-element probability of flipping a target label
            (``t -> 1 - t``). Only the target-based modes (``'bce'``,
            ``'mse'``) use it.
        one_side_label_smooth: Upper bound of the uniform noise subtracted
            from positive targets; zero targets stay zero. Only used by
            ``'bce'`` and ``'mse'``.
        lam: Scalar weight applied to the loss returned by ``forward``.

    Raises:
        NotImplementedError: If ``mode`` is not supported.
    """

    _MODES = ('bce', 'mse', 'hinge', 'wgan', 'logistic_saturating',
              'logistic_nonsaturating', 'relativistic_gan')

    def __init__(self, mode='hinge',
                 change_label_p=0.0,
                 one_side_label_smooth=0.0,
                 lam=1):
        super(GANLoss, self).__init__()
        self.mode = mode
        self.change_label_p = change_label_p
        self.one_side_label_smooth = one_side_label_smooth
        self.lam = lam
        if self.mode not in self._MODES:
            raise NotImplementedError('gan loss {} is not implemented'.format(self.mode))

    def get_target_tensor(self, pred, tgt):
        """Build a (possibly noisy) target tensor for ``pred``.

        Args:
            pred: Prediction tensor; the target copies its shape, dtype and
                device.
            tgt: Clean label value, e.g. ``1.0`` for real and ``0.0`` for fake.

        Returns:
            Tensor with ``pred``'s shape, dtype and device.
        """
        target = torch.full(pred.shape, tgt, dtype=pred.dtype, device=pred.device)
        # Random label flipping: keep the label where the mask is 1,
        # otherwise replace it with 1 - label. Skipped entirely when disabled.
        if self.change_label_p > 0.0:
            keep = (torch.rand_like(target) > self.change_label_p).to(target.dtype)
            target = target * keep + (1 - target) * (1 - keep)
        # One-sided smoothing: a label of 1 becomes 1 - U(0, s); 0 stays 0
        # because of the trailing multiplication by the label itself.
        if self.one_side_label_smooth > 0.0:
            noise = torch.rand_like(target)
            target = (target - noise * self.one_side_label_smooth) * target
        return target

    def call_one(self, pred, should_be_classified_as_real):
        """Unweighted loss for one set of predictions with a fixed desired label.

        Only supported for ``mode='logistic_nonsaturating'``. Note that
        ``lam`` is NOT applied here.

        Args:
            pred: Discriminator logits.
            should_be_classified_as_real: Whether ``pred`` should be pushed
                towards "real" (True) or "fake" (False).

        Returns:
            Scalar loss tensor.

        Raises:
            NotImplementedError: For any other ``mode``.
        """
        if self.mode != 'logistic_nonsaturating':
            raise NotImplementedError('invalid loss mode: {}'.format(self.mode))
        if should_be_classified_as_real:
            return F.softplus(-pred).mean()
        return F.softplus(pred).mean()

    def forward(self, pred_fake=None, pred_real=None, isD=True):
        """Compute the weighted adversarial loss.

        Implemented as ``forward`` so ``nn.Module.__call__`` (and its hooks)
        dispatch here; ``loss(pred_fake, pred_real, isD=...)`` still works.

        Args:
            pred_fake: Discriminator outputs for generated samples.
            pred_real: Discriminator outputs for real samples. Required when
                ``isD`` is True, and by ``'relativistic_gan'`` in both cases.
            isD: True for the discriminator loss, False for the generator loss.

        Returns:
            Scalar loss tensor multiplied by ``lam``.

        Raises:
            ValueError: If a required prediction tensor is missing.
        """
        if pred_fake is None:
            raise ValueError('meaningless input for GAN loss')
        if pred_real is None and (isD or self.mode == 'relativistic_gan'):
            raise ValueError('meaningless input for GAN loss')
        if isD:
            loss = self._discriminator_loss(pred_fake, pred_real)
        else:
            loss = self._generator_loss(pred_fake, pred_real)
        return loss * self.lam

    def _discriminator_loss(self, pred_fake, pred_real):
        """Unweighted discriminator objective for ``self.mode``."""
        mode = self.mode
        if mode == 'bce':
            loss_real = F.binary_cross_entropy_with_logits(
                pred_real, self.get_target_tensor(pred_real, 1.0))
            loss_fake = F.binary_cross_entropy_with_logits(
                pred_fake, self.get_target_tensor(pred_fake, 0.0))
        elif mode == 'mse':
            loss_real = F.mse_loss(pred_real, self.get_target_tensor(pred_real, 1.0))
            loss_fake = F.mse_loss(pred_fake, self.get_target_tensor(pred_fake, 0.0))
        elif mode == 'hinge':
            loss_real = F.relu(1.0 - pred_real).mean()
            loss_fake = F.relu(1.0 + pred_fake).mean()
        elif mode == 'wgan':
            loss_real = -pred_real.mean()
            loss_fake = pred_fake.mean()
        elif mode in ('logistic_saturating', 'logistic_nonsaturating'):
            # softplus(x) = log(1 + exp(x)); both logistic variants share the D loss.
            loss_real = F.softplus(-pred_real).mean()
            loss_fake = F.softplus(pred_fake).mean()
        elif mode == 'relativistic_gan':
            # Each score is compared against the batch-mean score of the other set.
            real_vs_fake = pred_real - pred_fake.mean(0, keepdim=True)
            fake_vs_real = pred_fake - pred_real.mean(0, keepdim=True)
            # Targets are shaped like the BCE input (not like pred_real),
            # so real and fake batches may differ in size.
            loss_real = F.binary_cross_entropy_with_logits(
                real_vs_fake, torch.ones_like(real_vs_fake))
            loss_fake = F.binary_cross_entropy_with_logits(
                fake_vs_real, torch.zeros_like(fake_vs_real))
        else:
            raise NotImplementedError('gan loss {} is not implemented'.format(mode))
        return loss_real + loss_fake

    def _generator_loss(self, pred_fake, pred_real):
        """Unweighted generator objective for ``self.mode``."""
        mode = self.mode
        if mode == 'bce':
            return F.binary_cross_entropy_with_logits(
                pred_fake, self.get_target_tensor(pred_fake, 1.0))
        if mode == 'mse':
            return F.mse_loss(pred_fake, self.get_target_tensor(pred_fake, 1.0))
        if mode in ('hinge', 'wgan'):
            return -pred_fake.mean()
        if mode == 'logistic_saturating':
            # Minimizes log(1 - sigmoid(pred_fake)), which equals -softplus(pred_fake).
            return -F.softplus(pred_fake).mean()
        if mode == 'logistic_nonsaturating':
            return F.softplus(-pred_fake).mean()
        if mode == 'relativistic_gan':
            # NOTE(review): only the fake-vs-real term is used for G; confirm
            # that omitting the real-vs-fake term is intended.
            fake_vs_real = pred_fake - pred_real.mean(0, keepdim=True)
            return F.binary_cross_entropy_with_logits(
                fake_vs_real, torch.ones_like(fake_vs_real))
        raise NotImplementedError('gan loss {} is not implemented'.format(mode))
@LOSS.register_module
class GPLoss(nn.Module):
    """WGAN-GP style gradient penalty.

    Penalizes ``(||grad_x D(x_hat)||_2 - 1) ** 2`` on random per-sample
    interpolations ``x_hat`` between real and fake data, averaged over the
    batch and scaled by ``lam``.

    Args:
        lam: Scalar weight applied to the returned penalty.
    """

    def __init__(self,
                 lam=1):
        super(GPLoss, self).__init__()
        self.lam = lam

    def forward(self, netD, real_data, fake_data):
        """Compute the gradient penalty.

        Args:
            netD: Callable discriminator applied to the interpolated batch.
            real_data: Real batch; dimension 0 is the batch dimension.
            fake_data: Generated batch broadcastable to ``real_data``.

        Returns:
            Scalar penalty tensor multiplied by ``lam``. The graph is kept
            (``create_graph=True``) so the penalty can be backpropagated
            into ``netD``'s parameters.
        """
        batch_size = real_data.size(0)
        # One mixing coefficient per sample, broadcast over all other dims,
        # created on the data's device/dtype (works for any input rank).
        alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
        alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)
        # Detach so the penalty's graph starts at the interpolates (as the
        # legacy ``autograd.Variable`` wrapper did), then track grads w.r.t. them.
        interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach()
        interpolates.requires_grad_(True)
        disc_interpolates = netD(interpolates)
        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                  grad_outputs=torch.ones_like(disc_interpolates),
                                  create_graph=True, retain_graph=True)[0]
        # reshape (not view): autograd may return non-contiguous gradients.
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty * self.lam
@LOSS.register_module
class R1Loss(nn.Module):
    """R1 gradient penalty: squared norm of d(output)/d(images).

    Args:
        lam: Scalar weight applied to the returned penalty.
    """

    def __init__(self, lam=1):
        super(R1Loss, self).__init__()
        self.lam = lam

    def forward(self, images, output):
        """Compute the R1 penalty.

        Args:
            images: Input batch (typically real images) that ``output`` was
                computed from; it must require grad and be part of
                ``output``'s graph. Dimension 0 is the batch dimension.
            output: Discriminator scores computed from ``images``.

        Returns:
            Scalar tensor: per-sample squared gradient norm averaged over the
            batch, multiplied by ``lam``. The graph is kept
            (``create_graph=True``) so the penalty can be backpropagated.
        """
        gradients = autograd.grad(outputs=output, inputs=images,
                                  grad_outputs=torch.ones_like(output),
                                  create_graph=True, retain_graph=True)[0]
        gradients = gradients.reshape(images.size(0), -1)
        # Sum over features per sample, then average over the batch so the
        # penalty does not grow with batch size.
        r1_penalty = gradients.pow(2).sum(dim=1).mean()
        return r1_penalty * self.lam
-
-
@LOSS.register_module
class PathLoss(nn.Module):
    """Path length regularization with a running mean of path lengths.

    Args:
        lam: Scalar weight applied to the returned penalty.
    """

    def __init__(self, lam=1):
        super(PathLoss, self).__init__()
        self.lam = lam

    def forward(self, img_fake, latent, mean_path_length, decay=0.01):
        """Compute the path length penalty and the updated running mean.

        Args:
            img_fake: Generated images; dims 2 and 3 are used as the spatial
                size (assumes (B, C, H, W) — confirm).
            latent: Latent tensor ``img_fake`` was generated from; must be
                part of its graph. Its gradient is summed over dim 2 and
                averaged over dim 1 (presumably (B, n_latent, dim) — confirm).
            mean_path_length: Running mean of path lengths from earlier steps.
            decay: Rate at which the running mean moves toward this batch's
                mean path length.

        Returns:
            Tuple ``(penalty * lam, path_length, new_mean_path_length)`` where
            ``path_length`` holds one value per sample and the new running
            mean is detached.
        """
        spatial_size = img_fake.shape[2] * img_fake.shape[3]
        # Random image-space projection, scaled by 1 / sqrt(H * W).
        projection = torch.randn_like(img_fake) / math.sqrt(spatial_size)
        projected = torch.sum(img_fake * projection)
        (latent_grad,) = autograd.grad(outputs=projected, inputs=latent, create_graph=True)

        per_sample_sq = latent_grad.pow(2).sum(2).mean(1)
        path_length = per_sample_sq.sqrt()

        # Move the running mean a fraction `decay` of the way to the batch mean.
        batch_mean = path_length.mean()
        updated_mean = mean_path_length + decay * (batch_mean - mean_path_length)

        deviation = path_length - updated_mean
        path_penalty = deviation.pow(2).mean()
        return path_penalty * self.lam, path_length, updated_mean.detach()
|