pytorch2torchscript.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. import argparse
  3. import numpy as np
  4. import torch
  5. import torch._C
  6. import torch.serialization
  7. from mmengine import Config
  8. from mmengine.runner import load_checkpoint
  9. from torch import nn
  10. from mmseg.models import build_segmentor
# Seed PyTorch's global RNG at import time so that anything drawing from it
# (e.g. random weight initialisation when no --checkpoint is loaded) is
# reproducible across runs.
torch.manual_seed(3)
  12. def digit_version(version_str):
  13. digit_version = []
  14. for x in version_str.split('.'):
  15. if x.isdigit():
  16. digit_version.append(int(x))
  17. elif x.find('rc') != -1:
  18. patch_version = x.split('rc')
  19. digit_version.append(int(patch_version[0]) - 1)
  20. digit_version.append(int(patch_version[1]))
  21. return digit_version
  22. def check_torch_version():
  23. torch_minimum_version = '1.8.0'
  24. torch_version = digit_version(torch.__version__)
  25. assert (torch_version >= digit_version(torch_minimum_version)), \
  26. f'Torch=={torch.__version__} is not support for converting to ' \
  27. f'torchscript. Please install pytorch>={torch_minimum_version}.'
  28. def _convert_batchnorm(module):
  29. module_output = module
  30. if isinstance(module, torch.nn.SyncBatchNorm):
  31. module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
  32. module.momentum, module.affine,
  33. module.track_running_stats)
  34. if module.affine:
  35. module_output.weight.data = module.weight.data.clone().detach()
  36. module_output.bias.data = module.bias.data.clone().detach()
  37. # keep requires_grad unchanged
  38. module_output.weight.requires_grad = module.weight.requires_grad
  39. module_output.bias.requires_grad = module.bias.requires_grad
  40. module_output.running_mean = module.running_mean
  41. module_output.running_var = module.running_var
  42. module_output.num_batches_tracked = module.num_batches_tracked
  43. for name, child in module.named_children():
  44. module_output.add_module(name, _convert_batchnorm(child))
  45. del module
  46. return module_output
  47. def _demo_mm_inputs(input_shape, num_classes):
  48. """Create a superset of inputs needed to run test or train batches.
  49. Args:
  50. input_shape (tuple):
  51. input batch dimensions
  52. num_classes (int):
  53. number of semantic classes
  54. """
  55. (N, C, H, W) = input_shape
  56. rng = np.random.RandomState(0)
  57. imgs = rng.rand(*input_shape)
  58. segs = rng.randint(
  59. low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
  60. img_metas = [{
  61. 'img_shape': (H, W, C),
  62. 'ori_shape': (H, W, C),
  63. 'pad_shape': (H, W, C),
  64. 'filename': '<demo>.png',
  65. 'scale_factor': 1.0,
  66. 'flip': False,
  67. } for _ in range(N)]
  68. mm_inputs = {
  69. 'imgs': torch.FloatTensor(imgs).requires_grad_(True),
  70. 'img_metas': img_metas,
  71. 'gt_semantic_seg': torch.LongTensor(segs)
  72. }
  73. return mm_inputs
  74. def pytorch2libtorch(model,
  75. input_shape,
  76. show=False,
  77. output_file='tmp.pt',
  78. verify=False):
  79. """Export Pytorch model to TorchScript model and verify the outputs are
  80. same between Pytorch and TorchScript.
  81. Args:
  82. model (nn.Module): Pytorch model we want to export.
  83. input_shape (tuple): Use this input shape to construct
  84. the corresponding dummy input and execute the model.
  85. show (bool): Whether print the computation graph. Default: False.
  86. output_file (string): The path to where we store the
  87. output TorchScript model. Default: `tmp.pt`.
  88. verify (bool): Whether compare the outputs between
  89. Pytorch and TorchScript. Default: False.
  90. """
  91. if isinstance(model.decode_head, nn.ModuleList):
  92. num_classes = model.decode_head[-1].num_classes
  93. else:
  94. num_classes = model.decode_head.num_classes
  95. mm_inputs = _demo_mm_inputs(input_shape, num_classes)
  96. imgs = mm_inputs.pop('imgs')
  97. # replace the original forword with forward_dummy
  98. model.forward = model.forward_dummy
  99. model.eval()
  100. traced_model = torch.jit.trace(
  101. model,
  102. example_inputs=imgs,
  103. check_trace=verify,
  104. )
  105. if show:
  106. print(traced_model.graph)
  107. traced_model.save(output_file)
  108. print(f'Successfully exported TorchScript model: {output_file}')
  109. def parse_args():
  110. parser = argparse.ArgumentParser(
  111. description='Convert MMSeg to TorchScript')
  112. parser.add_argument('config', help='test config file path')
  113. parser.add_argument('--checkpoint', help='checkpoint file', default=None)
  114. parser.add_argument(
  115. '--show', action='store_true', help='show TorchScript graph')
  116. parser.add_argument(
  117. '--verify', action='store_true', help='verify the TorchScript model')
  118. parser.add_argument('--output-file', type=str, default='tmp.pt')
  119. parser.add_argument(
  120. '--shape',
  121. type=int,
  122. nargs='+',
  123. default=[512, 512],
  124. help='input image size (height, width)')
  125. args = parser.parse_args()
  126. return args
  127. if __name__ == '__main__':
  128. args = parse_args()
  129. check_torch_version()
  130. if len(args.shape) == 1:
  131. input_shape = (1, 3, args.shape[0], args.shape[0])
  132. elif len(args.shape) == 2:
  133. input_shape = (
  134. 1,
  135. 3,
  136. ) + tuple(args.shape)
  137. else:
  138. raise ValueError('invalid input shape')
  139. cfg = Config.fromfile(args.config)
  140. cfg.model.pretrained = None
  141. # build the model and load checkpoint
  142. cfg.model.train_cfg = None
  143. segmentor = build_segmentor(
  144. cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
  145. # convert SyncBN to BN
  146. segmentor = _convert_batchnorm(segmentor)
  147. if args.checkpoint:
  148. load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
  149. # convert the PyTorch model to LibTorch model
  150. pytorch2libtorch(
  151. segmentor,
  152. input_shape,
  153. show=args.show,
  154. output_file=args.output_file,
  155. verify=args.verify)