flops.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # this is only a script
  2. import os
  3. import sys
  4. import torch
  5. import torch.nn as nn
  6. from utils import FLOPs, BuildModels, import_abspy
  7. HOME = os.environ["HOME"].rstrip("/")
  8. if __name__ == '__main__':
  9. from utils import FLOPs, BuildModels, import_abspy
  10. build = import_abspy("models", os.path.join(os.path.dirname(os.path.abspath(__file__)), "../classification/"),)
  11. Backbone_VSSM: nn.Module = build.vmamba.Backbone_VSSM
  12. def mmdet_mmseg_vssm():
  13. from mmengine.model import BaseModule
  14. from mmdet.registry import MODELS as MODELS_MMDET
  15. from mmseg.registry import MODELS as MODELS_MMSEG
  16. @MODELS_MMSEG.register_module()
  17. @MODELS_MMDET.register_module()
  18. class MM_VSSM(BaseModule, Backbone_VSSM):
  19. def __init__(self, *args, **kwargs):
  20. BaseModule.__init__(self)
  21. Backbone_VSSM.__init__(self, *args, **kwargs)
  22. # FLOPs.fvcore_flop_count(BuildModels.build_xcit(scale="tiny").cuda())
  23. # FLOPs.fvcore_flop_count(BuildModels.build_xcit(scale="small").cuda())
  24. # FLOPs.fvcore_flop_count(BuildModels.build_xcit(scale="base").cuda())
  25. segpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../segmentation/configs")
  26. detpath = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../detection/configs")
  27. mmdet_mmseg_vssm()
  28. if False:
  29. # FLOPs.mmseg_flops(config=f"{segpath}/upernet/upernet_r50_4xb4-160k_ade20k-512x512.py", input_shape=(3, 512, 2048)) # GFlops: 952.616667136 Params: 66516108
  30. # FLOPs.mmseg_flops(config=f"{segpath}/upernet/upernet_r101_4xb4-160k_ade20k-512x512.py", input_shape=(3, 512, 2048)) # GFlops: 1030.4084234239997 Params: 85508236
  31. # FLOPs.mmseg_flops(config=f"{segpath}/vit/vit_deit-s16_mln_upernet_8xb2-160k_ade20k-512x512.py", input_shape=(3, 512, 2048)) # GFlops: 1216.821829632 Params: 57994796
  32. # FLOPs.mmseg_flops(config=f"{segpath}/vit/vit_deit-b16_mln_upernet_8xb2-160k_ade20k-512x512.py", input_shape=(3, 512, 2048)) # GFlops: 2006.545496064 Params: 144172844
  33. FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_vssm_4xb4-160k_ade20k-512x512_tiny.py", input_shape=(3, 512, 2048)) # GFlops: 939.4933174400002 Params: 54546956
  34. FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_vssm_4xb4-160k_ade20k-512x512_small.py", input_shape=(3, 512, 2048)) # GFlops: 1036.6845167359998 Params: 76070924
  35. FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_vssm_4xb4-160k_ade20k-512x512_base.py", input_shape=(3, 512, 2048)) # GFlops: 1166.887735664 Params: 109765548
  36. # FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_swin_4xb4-160k_ade20k-640x640_small.py", input_shape=(3, 640, 2560)) # GFlops: 1614.082896384 Params: 81259766
  37. # FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_convnext_4xb4-160k_ade20k-640x640_small.py", input_shape=(3, 640, 2560)) # GFlops: 1606.538496 Params: 81877196
  38. # FLOPs.mmseg_flops(config=f"{segpath}/vssm/upernet_vssm_4xb4-160k_ade20k-640x640_small.py", input_shape=(3, 640, 2560)) # GFlops: 1619.8110944 Params: 76070924
  39. if True:
  40. FLOPs.mmdet_flops(config=f"{detpath}/vssm/mask_rcnn_vssm_fpn_coco_tiny.py") # 42.4M 262093532640.0 285883020640.0
  41. FLOPs.mmdet_flops(config=f"{detpath}/vssm/mask_rcnn_vssm_fpn_coco_small.py") # 63.924M 357006236640.0 400260276640.0
  42. FLOPs.mmdet_flops(config=f"{detpath}/vssm/mask_rcnn_vssm_fpn_coco_base.py") # 95.628M 482127568640.0 539797328640.0
  43. # FLOPs.mmdet_flops(config=f"{detpath}/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py") # 44.396M 260152304640.0
  44. # FLOPs.mmdet_flops(config=f"{detpath}/mask_rcnn/mask-rcnn_r101_fpn_1x_coco.py") # 63.388M 336434160640.0
  45. if False:
  46. FLOPs.mmseg_flops(config=f"{segpath}/vssm1/upernet_vssm_4xb4-160k_ade20k-512x512_tiny1.py", input_shape=(3, 512, 2048)) # GFlops: 947.779848192 Params: 62359340
  47. FLOPs.mmseg_flops(config=f"{segpath}/vssm1/upernet_vssm_4xb4-160k_ade20k-512x512_tiny.py", input_shape=(3, 512, 2048)) # GFlops: 948.7801896960001 Params: 61902572
  48. FLOPs.mmseg_flops(config=f"{segpath}/vssm1/upernet_vssm_4xb4-160k_ade20k-512x512_small.py", input_shape=(3, 512, 2048)) # GFlops: 1028.404888464 Params: 81801260
  49. FLOPs.mmseg_flops(config=f"{segpath}/vssm1/upernet_vssm_4xb4-160k_ade20k-512x512_base.py", input_shape=(3, 512, 2048)) # GFlops: 1170.3442882240001 Params: 122069292
  50. FLOPs.mmseg_flops(config=f"{segpath}/vssm1/upernet_vssm_4xb4-160k_ade20k-640x640_small.py", input_shape=(3, 640, 2560)) # GFlops: 1606.8682596 Params: 81801260
  51. if False:
  52. FLOPs.mmdet_flops(config=f"{detpath}/vssm1/mask_rcnn_vssm_fpn_coco_tiny1.py") # 50.212M 270186480640.0
  53. FLOPs.mmdet_flops(config=f"{detpath}/vssm1/mask_rcnn_vssm_fpn_coco_tiny.py") # 49.755M 271163376640.0
  54. FLOPs.mmdet_flops(config=f"{detpath}/vssm1/mask_rcnn_vssm_fpn_coco_small.py") # 69.654M 348921708640.0
  55. FLOPs.mmdet_flops(config=f"{detpath}/vssm1/mask_rcnn_vssm_fpn_coco_base.py") # 108M 485496108640.0
  56. # xcit det
  57. if False:
  58. lines = open(f"{HOME}/packs/xcit/detection/backbone/xcit.py").readlines()
  59. for i, l in enumerate(lines):
  60. if "from mmcv.runner import load_checkpoint\n" in l:
  61. lines[i] = "from mmengine.runner import load_checkpoint\n"
  62. elif "from mmdet.utils import get_root_logger\n" in l:
  63. lines[i] = "from mmengine.logging.logger import MMLogger as get_root_logger\n"
  64. elif "from mmdet.models.builder import BACKBONES\n" in l:
  65. lines[i] = "from mmdet.registry import MODELS as BACKBONES\n"
  66. file = open("/tmp/mmdet_backbone_xcit.py", "w+")
  67. file.write("".join(lines))
  68. file.close()
  69. xcit_det = import_abspy("mmdet_backbone_xcit", "/tmp")
  70. FLOPs.mmdet_flops(config=f"{HOME}/packs/xcit/detection/configs/xcit/mask_rcnn_xcit_small_12_p16_3x_coco.py", extra_config=f"{detpath}/mask_rcnn/mask-rcnn_r50-caffe_fpn_ms-poly-3x_coco.py") # 44.387M 286517232640.0
  71. FLOPs.mmdet_flops(config=f"{HOME}/packs/xcit/detection/configs/xcit/mask_rcnn_xcit_small_24_p16_3x_coco.py", extra_config=f"{detpath}/mask_rcnn/mask-rcnn_r50-caffe_fpn_ms-poly-3x_coco.py") # 65.805M 373921776640.0
  72. FLOPs.mmdet_flops(config=f"{HOME}/packs/xcit/detection/configs/xcit/mask_rcnn_xcit_medium_24_p16_3x_coco.py", extra_config=f"{detpath}/mask_rcnn/mask-rcnn_r50-caffe_fpn_ms-poly-3x_coco.py") # 98.981M 1476021744640.0
  73. # xcit seg
  74. if False:
  75. from mmengine.model import BaseModule
  76. lines = open(f"{HOME}/packs/xcit/semantic_segmentation/backbone/xcit.py").readlines()
  77. for i, l in enumerate(lines):
  78. if "from mmcv.runner import load_checkpoint\n" in l:
  79. lines[i] = "from mmengine.runner import load_checkpoint\n"
  80. elif "from mmseg.utils import get_root_logger\n" in l:
  81. lines[i] = "from mmengine.logging.logger import MMLogger as get_root_logger\n"
  82. elif "from mmseg.models.builder import BACKBONES\n" in l:
  83. lines[i] = "from mmseg.registry import MODELS as BACKBONES\n"
  84. file = open("/tmp/mmseg_backbone_xcit.py", "w+")
  85. file.write("".join(lines))
  86. file.close()
  87. xcit_seg = import_abspy("mmseg_backbone_xcit", "/tmp")
  88. FLOPs.mmseg_flops(config=f"{HOME}/packs/xcit/semantic_segmentation/configs/xcit/upernet/upernet_xcit_small_12_p16_160k_ade20k.py", input_shape=(3, 512, 2048)) # GFlops: 968.270727168 Params: 54199100
  89. FLOPs.mmseg_flops(config=f"{HOME}/packs/xcit/semantic_segmentation/configs/xcit/upernet/upernet_xcit_small_24_p16_160k_ade20k.py", input_shape=(3, 512, 2048)) # GFlops: 1057.7163571199999 Params: 75617180
  90. FLOPs.mmseg_flops(config=f"{HOME}/packs/xcit/semantic_segmentation/configs/xcit/upernet/upernet_xcit_medium_24_p16_160k_ade20k.py", input_shape=(3, 512, 2048)) # GFlops: 1220.2945269759998 Params: 112177196