visualization_cam.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. """Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
  3. requirement: pip install grad-cam
  4. """
  5. from argparse import ArgumentParser
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. from mmengine import Config
  10. from mmengine.model import revert_sync_batchnorm
  11. from PIL import Image
  12. from pytorch_grad_cam import GradCAM
  13. from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
  14. from mmseg.apis import inference_model, init_model, show_result_pyplot
  15. from mmseg.utils import register_all_modules
  16. class SemanticSegmentationTarget:
  17. """wrap the model.
  18. requirement: pip install grad-cam
  19. Args:
  20. category (int): Visualization class.
  21. mask (ndarray): Mask of class.
  22. size (tuple): Image size.
  23. """
  24. def __init__(self, category, mask, size):
  25. self.category = category
  26. self.mask = torch.from_numpy(mask)
  27. self.size = size
  28. if torch.cuda.is_available():
  29. self.mask = self.mask.cuda()
  30. def __call__(self, model_output):
  31. model_output = torch.unsqueeze(model_output, dim=0)
  32. model_output = F.interpolate(
  33. model_output, size=self.size, mode='bilinear')
  34. model_output = torch.squeeze(model_output, dim=0)
  35. return (model_output[self.category, :, :] * self.mask).sum()
  36. def main():
  37. parser = ArgumentParser()
  38. parser.add_argument('img', help='Image file')
  39. parser.add_argument('config', help='Config file')
  40. parser.add_argument('checkpoint', help='Checkpoint file')
  41. parser.add_argument(
  42. '--out-file',
  43. default='prediction.png',
  44. help='Path to output prediction file')
  45. parser.add_argument(
  46. '--cam-file', default='vis_cam.png', help='Path to output cam file')
  47. parser.add_argument(
  48. '--target-layers',
  49. default='backbone.layer4[2]',
  50. help='Target layers to visualize CAM')
  51. parser.add_argument(
  52. '--category-index', default='7', help='Category to visualize CAM')
  53. parser.add_argument(
  54. '--device', default='cuda:0', help='Device used for inference')
  55. args = parser.parse_args()
  56. # build the model from a config file and a checkpoint file
  57. register_all_modules()
  58. model = init_model(args.config, args.checkpoint, device=args.device)
  59. if args.device == 'cpu':
  60. model = revert_sync_batchnorm(model)
  61. # test a single image
  62. result = inference_model(model, args.img)
  63. # show the results
  64. show_result_pyplot(
  65. model,
  66. args.img,
  67. result,
  68. draw_gt=False,
  69. show=False if args.out_file is not None else True,
  70. out_file=args.out_file)
  71. # result data conversion
  72. prediction_data = result.pred_sem_seg.data
  73. pre_np_data = prediction_data.cpu().numpy().squeeze(0)
  74. target_layers = args.target_layers
  75. target_layers = [eval(f'model.{target_layers}')]
  76. category = int(args.category_index)
  77. mask_float = np.float32(pre_np_data == category)
  78. # data processing
  79. image = np.array(Image.open(args.img).convert('RGB'))
  80. height, width = image.shape[0], image.shape[1]
  81. rgb_img = np.float32(image) / 255
  82. config = Config.fromfile(args.config)
  83. image_mean = config.data_preprocessor['mean']
  84. image_std = config.data_preprocessor['std']
  85. input_tensor = preprocess_image(
  86. rgb_img,
  87. mean=[x / 255 for x in image_mean],
  88. std=[x / 255 for x in image_std])
  89. # Grad CAM(Class Activation Maps)
  90. # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
  91. targets = [
  92. SemanticSegmentationTarget(category, mask_float, (height, width))
  93. ]
  94. with GradCAM(
  95. model=model,
  96. target_layers=target_layers,
  97. use_cuda=torch.cuda.is_available()) as cam:
  98. grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
  99. cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
  100. # save cam file
  101. Image.fromarray(cam_image).save(args.cam_file)
  102. if __name__ == '__main__':
  103. main()