| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- # Copyright (c) OpenMMLab. All rights reserved.
- """Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
- requirement: pip install grad-cam
- """
- from argparse import ArgumentParser
- import numpy as np
- import torch
- import torch.nn.functional as F
- from mmengine import Config
- from mmengine.model import revert_sync_batchnorm
- from PIL import Image
- from pytorch_grad_cam import GradCAM
- from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
- from mmseg.apis import inference_model, init_model, show_result_pyplot
- from mmseg.utils import register_all_modules
class SemanticSegmentationTarget:
    """Scalar CAM target for a semantic segmentation model.

    pytorch-grad-cam calls a target with the model output of a single sample
    and back-propagates from the scalar it returns. This target resizes the
    output to ``size`` and sums the scores of channel ``category`` weighted
    by ``mask``.

    requirement: pip install grad-cam

    Args:
        category (int): Visualization class, used as the channel index into
            the model output.
        mask (ndarray): Per-pixel weights for ``category`` (e.g. a float 0/1
            mask of the pixels predicted as that class). Assumed to have
            shape ``size`` -- TODO confirm for all callers.
        size (tuple): ``(height, width)`` the model output is resized to.
    """

    def __init__(self, category, mask, size):
        self.category = category
        self.mask = torch.from_numpy(mask)
        self.size = size
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output):
        """Return the mask-weighted sum of the ``category`` scores.

        Args:
            model_output (Tensor): Output for one sample, indexed below as
                ``(num_classes, H, W)``.

        Returns:
            Tensor: 0-dim tensor to back-propagate from.
        """
        # The mask goes to the default CUDA device whenever CUDA exists, but
        # the model output can live elsewhere (CPU, or another GPU). Move the
        # mask once so the product below does not fail on a device mismatch.
        if self.mask.device != model_output.device:
            self.mask = self.mask.to(model_output.device)
        # F.interpolate needs a batch dimension; add it and strip it again.
        resized = F.interpolate(
            model_output.unsqueeze(0),
            size=self.size,
            mode='bilinear',
            align_corners=False).squeeze(0)
        return (resized[self.category, :, :] * self.mask).sum()
def parse_args(argv=None):
    """Parse the command-line arguments of this tool.

    Args:
        argv (list[str], optional): Arguments to parse. ``None`` makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file', default='vis_cam.png', help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    parser.add_argument(
        '--category-index',
        type=int,
        default=7,
        help='Category to visualize CAM')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    return parser.parse_args(argv)


def main():
    """Segment one image, then save its prediction and a Grad-CAM overlay."""
    args = parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        # Convert SyncBN layers to plain BN for CPU inference.
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)
    # Save the prediction overlay to --out-file (only shown if it is None).
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=args.out_file is None,
        out_file=args.out_file)

    # Predicted label map as numpy; squeeze(0) assumes a leading dim of 1
    # (presumably (1, H, W) -- TODO confirm).
    pred_label = result.pred_sem_seg.data.cpu().numpy().squeeze(0)

    # NOTE(review): --target-layers is evaluated as Python code against
    # ``model`` (e.g. 'backbone.layer4[2]'); only pass trusted input.
    target_layers = [eval(f'model.{args.target_layers}')]
    category = args.category_index
    # float32 0/1 mask of the pixels predicted as ``category``
    mask_float = np.float32(pred_label == category)

    # data processing
    with Image.open(args.img) as img:
        image = np.array(img.convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    # rgb_img is scaled to [0, 1], so mean/std are divided by 255 to match.
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM(Class Activation Maps)
    # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    # Follow the model's actual device instead of CUDA availability, so
    # ``--device cpu`` is honored on machines that also have a GPU.
    use_cuda = next(model.parameters()).is_cuda
    with GradCAM(
            model=model, target_layers=target_layers,
            use_cuda=use_cuda) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
    cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # save cam file
    Image.fromarray(cam_image).save(args.cam_file)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|