2 Commits 235cfeb50c ... 7a65a1d472

Autor SHA1 Mensaje Fecha
  kekeZack 7a65a1d472 chore(project): 初始化项目配置文件和代码结构 hace 1 mes
  kekezack 235cfeb50c Initial commit hace 1 mes
Se han modificado 100 ficheros con 8462 adiciones y 51 borrados
  1. 36 51
      .gitignore
  2. 92 0
      configs/segmentation/train_seg_b.yaml
  3. 92 0
      configs/segmentation/train_seg_lun.yaml
  4. 95 0
      configs/segmentation/train_seg_multiclass_template.yaml
  5. 92 0
      configs/segmentation/train_seg_pe.yaml
  6. 92 0
      configs/segmentation/train_seg_template.yaml
  7. 19 0
      configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml
  8. 19 0
      configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml
  9. 21 0
      configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml
  10. 11 0
      configs/swinv2/swinv2_base_patch4_window16_256.yaml
  11. 11 0
      configs/swinv2/swinv2_base_patch4_window8_256.yaml
  12. 19 0
      configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml
  13. 19 0
      configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml
  14. 21 0
      configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml
  15. 11 0
      configs/swinv2/swinv2_small_patch4_window16_256.yaml
  16. 11 0
      configs/swinv2/swinv2_small_patch4_window8_256.yaml
  17. 11 0
      configs/swinv2/swinv2_tiny_patch4_window16_256.yaml
  18. 11 0
      configs/swinv2/swinv2_tiny_patch4_window8_256.yaml
  19. 135 0
      lib/SwinTransformer/.gitignore
  20. 9 0
      lib/SwinTransformer/CODE_OF_CONDUCT.md
  21. 21 0
      lib/SwinTransformer/LICENSE
  22. 159 0
      lib/SwinTransformer/MODELHUB.md
  23. 310 0
      lib/SwinTransformer/README.md
  24. 41 0
      lib/SwinTransformer/SECURITY.md
  25. 25 0
      lib/SwinTransformer/SUPPORT.md
  26. 3 0
      lib/SwinTransformer/__init__.py
  27. 359 0
      lib/SwinTransformer/config.py
  28. 22 0
      lib/SwinTransformer/configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml
  29. 23 0
      lib/SwinTransformer/configs/simmim/simmim_finetune__swinv2_base__img224_window14__800ep.yaml
  30. 26 0
      lib/SwinTransformer/configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml
  31. 30 0
      lib/SwinTransformer/configs/simmim/simmim_pretrain__swinv2_base__img192_window12__800ep.yaml
  32. 20 0
      lib/SwinTransformer/configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml
  33. 20 0
      lib/SwinTransformer/configs/swin/swin_base_patch4_window12_384_finetune.yaml
  34. 9 0
      lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224.yaml
  35. 18 0
      lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224_22k.yaml
  36. 16 0
      lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml
  37. 20 0
      lib/SwinTransformer/configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml
  38. 18 0
      lib/SwinTransformer/configs/swin/swin_large_patch4_window7_224_22k.yaml
  39. 16 0
      lib/SwinTransformer/configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml
  40. 9 0
      lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224.yaml
  41. 18 0
      lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224_22k.yaml
  42. 16 0
      lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml
  43. 11 0
      lib/SwinTransformer/configs/swin/swin_tiny_c24_patch4_window8_256.yaml
  44. 9 0
      lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224.yaml
  45. 18 0
      lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224_22k.yaml
  46. 16 0
      lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml
  47. 9 0
      lib/SwinTransformer/configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml
  48. 11 0
      lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml
  49. 11 0
      lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml
  50. 11 0
      lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml
  51. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml
  52. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml
  53. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml
  54. 32 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml
  55. 26 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml
  56. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml
  57. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml
  58. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml
  59. 31 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml
  60. 32 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml
  61. 26 0
      lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml
  62. 19 0
      lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml
  63. 19 0
      lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml
  64. 21 0
      lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml
  65. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window16_256.yaml
  66. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window8_256.yaml
  67. 19 0
      lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml
  68. 19 0
      lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml
  69. 21 0
      lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml
  70. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_small_patch4_window16_256.yaml
  71. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_small_patch4_window8_256.yaml
  72. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_tiny_patch4_window16_256.yaml
  73. 11 0
      lib/SwinTransformer/configs/swinv2/swinv2_tiny_patch4_window8_256.yaml
  74. 12 0
      lib/SwinTransformer/data/__init__.py
  75. 162 0
      lib/SwinTransformer/data/build.py
  76. 252 0
      lib/SwinTransformer/data/cached_image_folder.py
  77. 112 0
      lib/SwinTransformer/data/data_simmim_ft.py
  78. 99 0
      lib/SwinTransformer/data/data_simmim_pt.py
  79. 55 0
      lib/SwinTransformer/data/imagenet22k_dataset.py
  80. 1000 0
      lib/SwinTransformer/data/map22kto1k.txt
  81. 29 0
      lib/SwinTransformer/data/samplers.py
  82. 103 0
      lib/SwinTransformer/data/zipreader.py
  83. BIN
      lib/SwinTransformer/figures/teaser.png
  84. 310 0
      lib/SwinTransformer/get_started.md
  85. 12 0
      lib/SwinTransformer/kernels/window_process/setup.py
  86. 132 0
      lib/SwinTransformer/kernels/window_process/swin_window_process.cpp
  87. 323 0
      lib/SwinTransformer/kernels/window_process/swin_window_process_kernel.cu
  88. 250 0
      lib/SwinTransformer/kernels/window_process/unit_test.py
  89. 63 0
      lib/SwinTransformer/kernels/window_process/window_process.py
  90. 41 0
      lib/SwinTransformer/logger.py
  91. 152 0
      lib/SwinTransformer/lr_scheduler.py
  92. 354 0
      lib/SwinTransformer/main.py
  93. 373 0
      lib/SwinTransformer/main_moe.py
  94. 342 0
      lib/SwinTransformer/main_simmim_ft.py
  95. 234 0
      lib/SwinTransformer/main_simmim_pt.py
  96. 1 0
      lib/SwinTransformer/models/__init__.py
  97. 121 0
      lib/SwinTransformer/models/build.py
  98. 209 0
      lib/SwinTransformer/models/simmim.py
  99. 468 0
      lib/SwinTransformer/models/swin_mlp.py
  100. 614 0
      lib/SwinTransformer/models/swin_transformer.py

+ 36 - 51
.gitignore

@@ -1,60 +1,45 @@
-# ---> Python
-# Byte-compiled / optimized / DLL files
+# Python
 __pycache__/
 *.py[cod]
 *$py.class
-
-# C extensions
 *.so
-
-# Distribution / packaging
-.Python
-env/
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-*.egg-info/
-.installed.cfg
 *.egg
+*.egg-info/
+dist/
+build/
+*.whl
 
-# PyInstaller
-#  Usually these files are written by a python script from a template
-#  before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*,cover
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Reference code & papers (do not upload)
+ref/
+tmp/
 
-# Sphinx documentation
-docs/_build/
+# Weights & checkpoints
+*.pth
+*.pt
+*.ckpt
+*.onnx
+
+
+# Logs & outputs
+*.log
+outputs/
+runs/
+lightning_logs/
 
-# PyBuilder
-target/
+# Jupyter
+.ipynb_checkpoints/
 
+# Environment
+.env
+.venv/
+venv/

+ 92 - 0
configs/segmentation/train_seg_b.yaml

@@ -0,0 +1,92 @@
+train:
+  seed: 42
+  epochs: 100
+  batch_size: 8
+  accum_steps: 1
+  amp: true
+  num_workers: 8
+  device: cuda
+
+dataset:
+  name: lung_ultrasound_seg
+  root: data/lung_ultrasound
+  task_name: b
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+  train_split: train
+  val_split: val
+  test_split: test
+  mask_suffix: .png
+  image_suffix: .png
+
+model:
+  name: swin_unet
+  encoder_name: swinv2_base_patch4_window12_192_22k
+  in_channels: 3
+  out_channels: 1
+  img_size: 256
+  drop_rate: 0.0
+  drop_path_rate: 0.2
+
+pretrain:
+  enabled: true
+  source: imagenet22k
+  checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
+  strict: false
+
+loss:
+  task_name: b
+  task_mode: binary
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5
+  weight_decay: 0.05
+  betas: [0.9, 0.999]
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 100
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_resized_crop: false
+    random_brightness_contrast: true
+    random_gaussian_noise: true
+  val:
+    center_crop: false
+
+validation:
+  enabled: true
+  interval: 1
+  metrics: [dice, iou]
+  save_best: true
+  monitor: dice
+  mode: max
+
+checkpoint:
+  dir: outputs/segmentation/train_seg_b
+  save_last: true
+  save_best_only: false
+  keep_top_k: 3
+
+logging:
+  log_interval: 20
+  use_tensorboard: true
+  tensorboard_dir: outputs/tensorboard/train_seg_b

+ 92 - 0
configs/segmentation/train_seg_lun.yaml

@@ -0,0 +1,92 @@
+train:
+  seed: 42
+  epochs: 100
+  batch_size: 8
+  accum_steps: 1
+  amp: true
+  num_workers: 8
+  device: cuda
+
+dataset:
+  name: lung_ultrasound_seg
+  root: data/lung_ultrasound
+  task_name: lun
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+  train_split: train
+  val_split: val
+  test_split: test
+  mask_suffix: .png
+  image_suffix: .png
+
+model:
+  name: swin_unet
+  encoder_name: swinv2_base_patch4_window12_192_22k
+  in_channels: 3
+  out_channels: 1
+  img_size: 256
+  drop_rate: 0.0
+  drop_path_rate: 0.2
+
+pretrain:
+  enabled: true
+  source: imagenet22k
+  checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
+  strict: false
+
+loss:
+  task_name: lun
+  task_mode: binary
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5
+  weight_decay: 0.05
+  betas: [0.9, 0.999]
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 100
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_resized_crop: false
+    random_brightness_contrast: true
+    random_gaussian_noise: true
+  val:
+    center_crop: false
+
+validation:
+  enabled: true
+  interval: 1
+  metrics: [dice, iou]
+  save_best: true
+  monitor: dice
+  mode: max
+
+checkpoint:
+  dir: outputs/segmentation/train_seg_lun
+  save_last: true
+  save_best_only: false
+  keep_top_k: 3
+
+logging:
+  log_interval: 20
+  use_tensorboard: true
+  tensorboard_dir: outputs/tensorboard/train_seg_lun

+ 95 - 0
configs/segmentation/train_seg_multiclass_template.yaml

@@ -0,0 +1,95 @@
+train:
+  seed: 42
+  epochs: 100
+  batch_size: 8
+  accum_steps: 1
+  amp: true
+  num_workers: 8
+  device: cuda
+
+dataset:
+  name: lung_ultrasound_seg_multiclass
+  root: data/lung_ultrasound
+  task_name: multiclass
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 4
+  class_names: [background, lun, pe, b]
+  train_split: train
+  val_split: val
+  test_split: test
+  mask_suffix: .png
+  image_suffix: .png
+
+model:
+  name: swin_unet
+  encoder_name: swinv2_base_patch4_window12_192_22k
+  in_channels: 3
+  out_channels: 4
+  img_size: 256
+  drop_rate: 0.0
+  drop_path_rate: 0.2
+
+pretrain:
+  enabled: true
+  source: imagenet22k
+  checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
+  strict: false
+
+loss:
+  name: generalized_dice_focal
+  task_mode: multiclass
+  params:
+    include_background: false
+
+metrics:
+  task_mode: multiclass
+  metrics:
+    - name: dice
+    - name: miou
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5
+  weight_decay: 0.05
+  betas: [0.9, 0.999]
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 100
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_resized_crop: false
+    random_brightness_contrast: true
+    random_gaussian_noise: true
+  val:
+    center_crop: false
+
+validation:
+  enabled: true
+  interval: 1
+  metrics: [dice, miou]
+  save_best: true
+  monitor: dice
+  mode: max
+
+checkpoint:
+  dir: outputs/segmentation/train_seg_multiclass
+  save_last: true
+  save_best_only: false
+  keep_top_k: 3
+
+logging:
+  log_interval: 20
+  use_tensorboard: true
+  tensorboard_dir: outputs/tensorboard/train_seg_multiclass

+ 92 - 0
configs/segmentation/train_seg_pe.yaml

@@ -0,0 +1,92 @@
+train:
+  seed: 42
+  epochs: 100
+  batch_size: 8
+  accum_steps: 1
+  amp: true
+  num_workers: 8
+  device: cuda
+
+dataset:
+  name: lung_ultrasound_seg
+  root: data/lung_ultrasound
+  task_name: pe
+  image_size: [256, 256]
+  in_channels: 3
+  num_classes: 1
+  train_split: train
+  val_split: val
+  test_split: test
+  mask_suffix: .png
+  image_suffix: .png
+
+model:
+  name: swin_unet
+  encoder_name: swinv2_base_patch4_window12_192_22k
+  in_channels: 3
+  out_channels: 1
+  img_size: 256
+  drop_rate: 0.0
+  drop_path_rate: 0.2
+
+pretrain:
+  enabled: true
+  source: imagenet22k
+  checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
+  strict: false
+
+loss:
+  task_name: pe
+  task_mode: binary
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5
+  weight_decay: 0.05
+  betas: [0.9, 0.999]
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 100
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_resized_crop: false
+    random_brightness_contrast: true
+    random_gaussian_noise: true
+  val:
+    center_crop: false
+
+validation:
+  enabled: true
+  interval: 1
+  metrics: [dice, iou]
+  save_best: true
+  monitor: dice
+  mode: max
+
+checkpoint:
+  dir: outputs/segmentation/train_seg_pe
+  save_last: true
+  save_best_only: false
+  keep_top_k: 3
+
+logging:
+  log_interval: 20
+  use_tensorboard: true
+  tensorboard_dir: outputs/tensorboard/train_seg_pe

+ 92 - 0
configs/segmentation/train_seg_template.yaml

@@ -0,0 +1,92 @@
+train:
+  seed: 42
+  epochs: 100
+  batch_size: 8
+  accum_steps: 1
+  amp: true
+  num_workers: 8
+  device: cuda
+
+dataset:
+  name: lung_ultrasound_seg
+  root: data/lung_ultrasound
+  task_name: lun
+  image_size: [ 256, 256 ]
+  in_channels: 3
+  num_classes: 1
+  train_split: train
+  val_split: val
+  test_split: test
+  mask_suffix: .png
+  image_suffix: .png
+
+model:
+  name: swin_unet
+  encoder_name: swinv2_base_patch4_window12_192_22k
+  in_channels: 3
+  out_channels: 1
+  img_size: 256
+  drop_rate: 0.0
+  drop_path_rate: 0.2
+
+pretrain:
+  enabled: true
+  source: imagenet22k
+  checkpoint: weights/swinv2_base_patch4_window12_192_22k.pth
+  strict: false
+
+loss:
+  task_name: lun
+  task_mode: binary
+
+metrics:
+  task_mode: binary
+  metrics:
+    - name: dice
+    - name: iou
+
+optimizer:
+  name: adamw
+  lr: 5.0e-5
+  weight_decay: 0.05
+  betas: [ 0.9, 0.999 ]
+
+scheduler:
+  name: cosine
+  warmup:
+    name: linear
+    params:
+      start_factor: 0.1
+      total_iters: 10
+  params:
+    T_max: 100
+    eta_min: 1.0e-6
+
+augmentation:
+  train:
+    random_flip: true
+    random_rotate_90: true
+    random_resized_crop: false
+    random_brightness_contrast: true
+    random_gaussian_noise: true
+  val:
+    center_crop: false
+
+validation:
+  enabled: true
+  interval: 1
+  metrics: [ dice, iou ]
+  save_best: true
+  monitor: dice
+  mode: max
+
+checkpoint:
+  dir: outputs/segmentation/train_seg_template
+  save_last: true
+  save_best_only: false
+  keep_top_k: 3
+
+logging:
+  log_interval: 20
+  use_tensorboard: true
+  tensorboard_dir: outputs/tensorboard/train_seg_template

+ 19 - 0
configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml

@@ -0,0 +1,19 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 19 - 0
configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -0,0 +1,19 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 21 - 0
configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -0,0 +1,21 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 11 - 0
configs/swinv2/swinv2_base_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window16_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 16

+ 11 - 0
configs/swinv2/swinv2_base_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window8_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 8

+ 19 - 0
configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml

@@ -0,0 +1,19 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 19 - 0
configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -0,0 +1,19 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 21 - 0
configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -0,0 +1,21 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 11 - 0
configs/swinv2/swinv2_small_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window16_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 16

+ 11 - 0
configs/swinv2/swinv2_small_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window8_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 8

+ 11 - 0
configs/swinv2/swinv2_tiny_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window16_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 16

+ 11 - 0
configs/swinv2/swinv2_tiny_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 8

+ 135 - 0
lib/SwinTransformer/.gitignore

@@ -0,0 +1,135 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# launch bash
+*.sh
+# nsight system report files
+*.nsys-rep
+*.sqlite
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/

+ 9 - 0
lib/SwinTransformer/CODE_OF_CONDUCT.md

@@ -0,0 +1,9 @@
+# Microsoft Open Source Code of Conduct
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+
+Resources:
+
+- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
+- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
+- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns

+ 21 - 0
lib/SwinTransformer/LICENSE

@@ -0,0 +1,21 @@
+    MIT License
+
+    Copyright (c) Microsoft Corporation.
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in all
+    copies or substantial portions of the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+    SOFTWARE

+ 159 - 0
lib/SwinTransformer/MODELHUB.md

@@ -0,0 +1,159 @@
+Access code for `baidu` is `swin`.
+
+## ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models
+
+| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
+| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) |
+| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) |
+| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278  | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) |
+| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) |
+| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) |
+| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) |
+
+## ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models
+
+| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model |
+|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: |
+| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) |
+| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) |
+| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) |
+| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) |
+| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G  | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) |
+| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) |
+| SwinV2-B<sup>\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 |  88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) |
+| SwinV2-B<sup>\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) |
+| SwinV2-L<sup>\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) |
+| SwinV2-L<sup>\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) |
+
+Note:
+
+- SwinV2-B<sup>\*</sup>  (SwinV2-L<sup>\*</sup>) with input resolution of 256x256 and 384x384 both fine-tuned from the
+  same pre-training model using a smaller input resolution of 192x192.
+- SwinV2-B<sup>\*</sup> (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L<sup>\*</sup> (384x384) achieves
+  78.31.
+
+## ImageNet-1K Pretrained Swin MLP Models
+
+| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS |  1K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) |
+| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) |
+| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G |  231 | [timm](https://github.com/rwightman/pytorch-image-models) |
+| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) |
+| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) |
+| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) |
+| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) |
+| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) |
+
+Note: C24 means each head has 24 channels.
+
+## ImageNet-22K Pretrained Swin-MoE Models
+
+| name | #experts | k | router | resolution | window | IN-22K acc@1 | IN-1K/ft acc@1 | IN-1K/5-shot acc@1 | 22K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-MoE-S | 1 (dense) | - | - | 192x192 | 8x8 | 35.5| 83.5 | 70.3 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_densebaseline_22k.zip)/[baidu](https://pan.baidu.com/s/1O1m9jT2pGoago_RiRX914w?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml) |
+| Swin-MoE-S | 8 | 1 | Linear | 192x192 | 8x8 | 36.8 | 84.5 | 75.2 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/198IlYUrWOxEUp7wNdoJT5Q?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml) |
+| Swin-MoE-S | 16 | 1 | Linear |192x192 | 8x8 | 37.6 | 84.9 | 76.5 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1vRQweedtT42VwMTqe9-r2A?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml) |
+| Swin-MoE-S | 32 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.7 | 75.9 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1i7rImt5pwO8gJC-PRRuZwQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml) |
+| Swin-MoE-S | 32 | 1 | Cosine | 192x192 | 8x8 | 37.2 | 84.3 | 75.2 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.zip)/[baidu](https://pan.baidu.com/s/1Yghr_12ntSrv01I9yatPDQ?pwd=swin)/[config](configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) |
+| Swin-MoE-S | 64 | 1 | Linear | 192x192 | 8x8 | 37.8 | 84.7 | 75.7 | - |
+| Swin-MoE-S | 128 | 1 | Linear | 192x192 | 8x8 | 37.4 | 84.5 | 75.4 | - |
+| Swin-MoE-B | 1 (dense) | - | - | 192x192 | 8x8 | 37.3 | 85.1 | 75.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml) |
+| Swin-MoE-B | 8 | 1 | Linear | 192x192 | 8x8 | 38.1 | 85.3 | 77.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml) |
+| Swin-MoE-B | 16 | 1 | Linear | 192x192 | 8x8 | 38.7 | 85.5 | 78.2 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml) |
+| Swin-MoE-B | 32 | 1 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 77.9 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml) |
+| Swin-MoE-B | 32 | 1 | Cosine | 192x192 | 8x8 | 38.5 | 85.3 | 77.3 | [config](configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml) |
+| Swin-MoE-B | 32 | 2 | Linear | 192x192 | 8x8 | 38.6 | 85.5 | 78.7 | - |
+
+## SimMIM Pretrained Swin-V2 Models
+
+> Please note that all SimMIM pretrained Swin-V2 models will be stored in the Huggingface repository starting July 2024. For more details, refer to the [huggingface repository](https://huggingface.co/zdaxie/SimMIM).
+
+- **Model size** only includes the backbone weights and excludes weights in the decoders/classification heads.
+- **Batch size** for all models is set to 2048.
+- **Validation loss** is calculated on the ImageNet-1K validation set.
+- **Fine-tuned acc@1** refers to the top-1 accuracy on the ImageNet-1K validation set after fine-tuning.
+
+| name | model size | pre-train dataset | pre-train iterations | validation loss | fine-tuned acc@1 | pre-trained model | fine-tuned model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| SwinV2-Small | 49M | ImageNet-1K 10% | 125k | 0.4820 | 82.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_125k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 10% | 250k | 0.4961 | 83.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_250k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 10% | 500k | 0.5115 | 83.17 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper10_500k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 20% | 125k | 0.4751 | 83.05 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_125k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 20% | 250k | 0.4722 | 83.56 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_250k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 20% | 500k | 0.4734 | 83.75 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper20_500k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 50% | 125k | 0.4732 | 83.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_125k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 50% | 250k | 0.4681 | 83.67 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_250k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K 50% | 500k | 0.4646 | 83.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1kper50_500k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K | 125k | 0.4728 | 82.92 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_125k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K | 250k | 0.4674 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_250k.pth?download=true) |
+| SwinV2-Small | 49M | ImageNet-1K | 500k | 0.4641 | 84.08 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_small_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_small_1k_500k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 10% | 125k | 0.4822 | 83.33 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_125k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 10% | 250k | 0.4997 | 83.60 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_250k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 10% | 500k | 0.5112 | 83.41 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper10_500k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 20% | 125k | 0.4703 | 83.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_125k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 20% | 250k | 0.4679 | 84.37 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_250k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 20% | 500k | 0.4711 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper20_500k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 50% | 125k | 0.4683 | 84.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_125k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 50% | 250k | 0.4633 | 84.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_250k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K 50% | 500k | 0.4598 | 84.95 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1kper50_500k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K | 125k | 0.4680 | 84.13 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_125k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K | 250k | 0.4626 | 84.65 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_250k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-1K | 500k | 0.4588 | 85.04 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_1k_500k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-22K | 125k | 0.4695 | 84.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_125k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-22K | 250k | 0.4649 | 84.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_250k.pth?download=true) |
+| SwinV2-Base | 87M | ImageNet-22K | 500k | 0.4614 | 85.11 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_base_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_base_22k_500k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 10% | 125k | 0.4995 | 83.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_125k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 10% | 250k | 0.5140 | 83.66 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_250k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 10% | 500k | 0.5150 | 83.50 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper10_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper10_500k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 20% | 125k | 0.4675 | 84.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_125k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 20% | 250k | 0.4746 | 84.71 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_250k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 20% | 500k | 0.4960 | 84.59 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper20_500k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 50% | 125k | 0.4622 | 84.78 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_125k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 50% | 250k | 0.4566 | 85.38 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_250k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K 50% | 500k | 0.4530 | 85.80 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1kper50_500k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K | 125k | 0.4611 | 84.98 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_125k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K | 250k | 0.4552 | 85.45 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_250k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-1K | 500k | 0.4507 | 85.91 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_1k_500k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-22K | 125k | 0.4649 | 84.61 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_125k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-22K | 250k | 0.4586 | 85.39 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_250k.pth?download=true) |
+| SwinV2-Large | 195M | ImageNet-22K | 500k | 0.4536 | 85.81 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_large_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_large_22k_500k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 20% | 125k | 0.4789 | 84.35 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_125k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 20% | 250k | 0.5038 | 84.16 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_250k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 20% | 500k | 0.5071 | 83.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper20_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper20_500k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 50% | 125k | 0.4549 | 85.09 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_125k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 50% | 250k | 0.4511 | 85.64 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_250k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K 50% | 500k | 0.4559 | 85.69 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1kper50_500k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K | 125k | 0.4531 | 85.23 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_125k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K | 250k | 0.4464 | 85.90 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_250k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-1K | 500k | 0.4416 | 86.34 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_1k_500k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-22K | 125k | 0.4564 | 85.14 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_125k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-22K | 250k | 0.4499 | 85.86 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_250k.pth?download=true) |
+| SwinV2-Huge | 655M | ImageNet-22K | 500k | 0.4444 | 86.27 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_huge_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_huge_22k_500k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K 50% | 125k | 0.4534 | 85.44 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_125k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K 50% | 250k | 0.4515 | 85.76 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_250k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K 50% | 500k | 0.4719 | 85.51 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1kper50_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1kper50_500k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K | 125k | 0.4513 | 85.57 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_125k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K | 250k | 0.4442 | 86.12 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_250k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-1K | 500k | 0.4395 | 86.46 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_1k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_1k_500k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-22K | 125k | 0.4544 | 85.39 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_125k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_125k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-22K | 250k | 0.4475 | 85.96 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_250k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_250k.pth?download=true) |
+| SwinV2-giant | 1.06B | ImageNet-22K | 500k | 0.4416 | 86.53 | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_pretrain_models/swinv2_giant_22k_500k.pth?download=true) | [huggingface](https://huggingface.co/zdaxie/SimMIM/resolve/main/simmim_swinv2_finetune_models/finetune_swinv2_giant_22k_500k.pth?download=true) |
+
+## SimMIM Pretrained Swin-V1 Models
+
+**ImageNet-1K Pre-trained and Fine-tuned Models**
+
+| name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-Base | 100 | 192x192 | 192x192 | 82.8 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1RsgHfjB4B1ZYblXEQVT-FPX3WSvBrxcs/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml) |
+| Swin-Base | 100 | 192x192 | 224x224 | 83.5 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1mb43BkW56F5smwiX-g7QUUD7f1Rftq8u/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml) |
+| Swin-Base | 800 | 192x192 | 224x224 | 84.0 | [google](https://drive.google.com/file/d/15zENvGjHlM71uKQ3d2FbljWPubtrPtjl/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml) | [google](https://drive.google.com/file/d/1xEKyfMTsdh6TfnYhk5vbw0Yz7a-viZ0w/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml) |
+| Swin-Large | 800 | 192x192 | 224x224 | 85.4 | [google](https://drive.google.com/file/d/1qDxrTl2YUDB0505_4QrU5LU2R1kKmcBP/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml) | [google](https://drive.google.com/file/d/1mf0ZpXttEvFsH87Www4oQ-t8Kwr0x485/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml) |
+| SwinV2-Huge | 800 | 192x192 | 224x224 | 85.7 | / | / |
+| SwinV2-Huge | 800 | 192x192 | 512x512 | 87.1 | / | / |

+ 310 - 0
lib/SwinTransformer/README.md

@@ -0,0 +1,310 @@
+# Swin Transformer
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/object-detection-on-coco)](https://paperswithcode.com/sota/object-detection-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/instance-segmentation-on-coco)](https://paperswithcode.com/sota/instance-segmentation-on-coco?p=swin-transformer-v2-scaling-up-capacity-and)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/semantic-segmentation-on-ade20k)](https://paperswithcode.com/sota/semantic-segmentation-on-ade20k?p=swin-transformer-v2-scaling-up-capacity-and)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/swin-transformer-v2-scaling-up-capacity-and/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=swin-transformer-v2-scaling-up-capacity-and)
+
+This repo is the official implementation of ["Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"](https://arxiv.org/pdf/2103.14030.pdf) as well as the follow-ups. It currently includes code and models for the following tasks:
+
+> **Image Classification**: Included in this repo. See [get_started.md](get_started.md) for a quick start.
+
+> **Object Detection and Instance Segmentation**: See [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
+
+> **Semantic Segmentation**: See [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
+
+> **Video Action Recognition**: See [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
+
+> **Semi-Supervised Object Detection**: See [Soft Teacher](https://github.com/microsoft/SoftTeacher).
+
+> **SSL: Contrasitive Learning**: See [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).
+
+> **SSL: Masked Image Modeling**: See [get_started.md#simmim-support](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md#simmim-support).
+
+> **Mixture-of-Experts**: See [get_started](get_started.md#mixture-of-experts-support) for more instructions.
+
+> **Feature-Distillation**: See [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation).
+
+## Updates
+
+***12/29/2022***
+
+1. **Nvidia**'s [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md) now supports Swin Transformer V2 inference, which have significant speed improvements on `T4 and A100 GPUs`.
+
+***11/30/2022***
+
+1. Models and codes of **Feature Distillation** are released. Please refer to [Feature-Distillation](https://github.com/SwinTransformer/Feature-Distillation) for details, and the checkpoints (FD-EsViT-Swin-B, FD-DeiT-ViT-B, FD-DINO-ViT-B, FD-CLIP-ViT-B, FD-CLIP-ViT-L).
+
+***09/24/2022***
+
+1. Merged [SimMIM](https://github.com/microsoft/SimMIM), which is a **Masked Image Modeling** based pre-training approach applicable to Swin and SwinV2 (and also applicable for ViT and ResNet). Please refer to [get started with SimMIM](get_started.md#simmim-support) to play with SimMIM pre-training.
+
+2. Released a series of Swin and SwinV2 models pre-trained using the SimMIM approach (see [MODELHUB for SimMIM](MODELHUB.md#simmim-pretrained-swin-v2-models)), with model size ranging from SwinV2-Small-50M to SwinV2-giant-1B, data size ranging from ImageNet-1K-10% to ImageNet-22K, and iterations from 125k to 500k. You may leverage these models to study the properties of MIM methods. Please look into the [data scaling](https://arxiv.org/abs/2206.04664) paper for more details.
+
+***07/09/2022***
+
+`News`: 
+
+1. SwinV2-G achieves `61.4 mIoU` on ADE20K semantic segmentation (+1.5 mIoU over the previous SwinV2-G model), using an additional [feature distillation (FD)](https://github.com/SwinTransformer/Feature-Distillation) approach, **setting a new recrod** on this benchmark. FD is an approach that can generally improve the fine-tuning performance of various pre-trained models, including DeiT, DINO, and CLIP. Particularly, it improves CLIP pre-trained ViT-L by +1.6% to reach `89.0%` on ImageNet-1K image classification, which is **the most accurate ViT-L model**.
+2. Merged a PR from **Nvidia** that links to faster Swin Transformer inference that have significant speed improvements on `T4 and A100 GPUs`.
+3. Merged a PR from **Nvidia** that enables an option to use `pure FP16 (Apex O2)` in training, while almost maintaining the accuracy.
+
+***06/03/2022***
+
+1. Added **Swin-MoE**, the Mixture-of-Experts variant of Swin Transformer implemented using [Tutel](https://github.com/microsoft/tutel) (an optimized Mixture-of-Experts implementation). **Swin-MoE** is introduced in the [TuTel](https://arxiv.org/abs/2206.03382) paper.
+
+***05/12/2022***
+
+1. Pretrained models of [Swin Transformer V2](https://arxiv.org/abs/2111.09883) on ImageNet-1K and ImageNet-22K are released. 
+2. ImageNet-22K pretrained models for Swin-V1-Tiny and Swin-V2-Small are released.
+
+***03/02/2022***
+
+1. Swin Transformer V2 and SimMIM got accepted by CVPR 2022. [SimMIM](https://github.com/microsoft/SimMIM) is a self-supervised pre-training approach based on masked image modeling, a key technique that works out the 3-billion-parameter Swin V2 model using `40x less labelled data` than that of previous billion-scale models based on JFT-3B. 
+
+***02/09/2022***
+
+1. Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). Try out the Web Demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/Swin-Transformer)
+
+***10/12/2021***
+
+1. Swin Transformer received ICCV 2021 best paper award (Marr Prize).
+
+***08/09/2021***
+1. [Soft Teacher](https://arxiv.org/pdf/2106.09018v2.pdf) will appear at ICCV2021. The code will be released at [GitHub Repo](https://github.com/microsoft/SoftTeacher). `Soft Teacher` is an end-to-end semi-supervisd object detection method, achieving a new record on the COCO test-dev: `61.3 box AP` and `53.0 mask AP`.
+ 
+***07/03/2021***
+1. Add **Swin MLP**, which is an adaption of `Swin Transformer` by replacing all multi-head self-attention (MHSA) blocks by MLP layers (more precisely it is a group linear layer). The shifted window configuration can also significantly improve the performance of vanilla MLP architectures. 
+
+***06/25/2021***
+1. [Video Swin Transformer](https://arxiv.org/abs/2106.13230) is released at [Video-Swin-Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
+`Video Swin Transformer` achieves state-of-the-art accuracy on a broad range of video recognition benchmarks, including action recognition (`84.9` top-1 accuracy on Kinetics-400 and `86.1` top-1 accuracy on Kinetics-600 with `~20x` less pre-training data and `~3x` smaller model size) and temporal modeling (`69.6` top-1 accuracy on Something-Something v2).
+
+***05/12/2021***
+1. Used as a backbone for `Self-Supervised Learning`: [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL)
+
+Using Swin-Transformer as the backbone for self-supervised learning enables us to evaluate the transferring performance of the learnt representations on down-stream tasks, which is missing in previous works due to the use of ViT/DeiT, which has not been well tamed for down-stream tasks.
+
+***04/12/2021***
+
+Initial commits:
+
+1. Pretrained models on ImageNet-1K ([Swin-T-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), [Swin-S-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth), [Swin-B-IN1K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)) and ImageNet-22K ([Swin-B-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth), [Swin-L-IN22K](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)) are provided.
+2. The supported code and models for ImageNet-1K image classification, COCO object detection and ADE20K semantic segmentation are provided.
+3. The cuda kernel implementation for the [local relation layer](https://arxiv.org/pdf/1904.11491.pdf) is provided in branch [LR-Net](https://github.com/microsoft/Swin-Transformer/tree/LR-Net).
+
+## Introduction
+
+**Swin Transformer** (the name `Swin` stands for **S**hifted **win**dow) is initially described in [arxiv](https://arxiv.org/abs/2103.14030), which capably serves as a
+general-purpose backbone for computer vision. It is basically a hierarchical Transformer whose representation is
+computed with shifted windows. The shifted windowing scheme brings greater efficiency by limiting self-attention
+computation to non-overlapping local windows while also allowing for cross-window connection.
+
+Swin Transformer achieves strong performance on COCO object detection (`58.7 box AP` and `51.1 mask AP` on test-dev) and
+ADE20K semantic segmentation (`53.5 mIoU` on val), surpassing previous models by a large margin.
+
+![teaser](figures/teaser.png)
+
+## Main Results on ImageNet with Pretrained Models
+
+**ImageNet-1K and ImageNet-22K Pretrained Swin-V1 Models**
+
+| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS| 22K model | 1K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---: |:---: |
+| Swin-T | ImageNet-1K | 224x224 | 81.2 | 95.5 | 28M | 4.5G | 755 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/156nWJy4Q28rDlrX-rRbI3w)/[config](configs/swin/swin_tiny_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745562/log_swin_tiny_patch4_window7_224.txt) |
+| Swin-S | ImageNet-1K | 224x224 | 83.2 | 96.2 | 50M | 8.7G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1KFjpj3Efey3LmtE1QqPeQg)/[config](configs/swin/swin_small_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745563/log_swin_small_patch4_window7_224.txt) |
+| Swin-B | ImageNet-1K | 224x224 | 83.5 | 96.5 | 88M | 15.4G | 278  | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/16bqCTEc70nC_isSsgBSaqQ)/[config](configs/swin/swin_base_patch4_window7_224.yaml)/[log](https://github.com/SwinTransformer/storage/files/7745564/log_swin_base_patch4_window7_224.txt) |
+| Swin-B | ImageNet-1K | 384x384 | 84.5 | 97.0 | 88M | 47.1G | 85 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth)/[baidu](https://pan.baidu.com/s/1xT1cu740-ejW7htUdVLnmw)/[config](configs/swin/swin_base_patch4_window12_384_finetune.yaml) |
+| Swin-T | ImageNet-22K | 224x224 | 80.9 | 96.0 | 28M | 4.5G | 755 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1vct0VYwwQQ8PYkBjwSSBZQ?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/1K0OO-nGZDPkR8fm_r83e8Q?pwd=swin)/[config](configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-S | ImageNet-22K | 224x224 | 83.2 | 97.0 | 50M | 8.7G | 437 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/11NC1xdT5BAGBgazdTme5Sg?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth)/[baidu](https://pan.baidu.com/s/10RFVfjQJhwPfeHrmxQUaLw?pwd=swin)/[config](configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-B | ImageNet-22K | 224x224 | 85.2 | 97.5 | 88M | 15.4G | 278 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1y1Ec3UlrKSI8IMtEs-oBXA)/[config](configs/swin/swin_base_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1n_wNkcbRxVXit8r_KrfAVg)/[config](configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-B | ImageNet-22K | 384x384 | 86.4 | 98.0 | 88M | 47.1G | 85 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1vwJxnJcVqcLZAw9HaqiR6g) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1caKTSdoLJYoi4WBcnmWuWg)/[config](configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml) |
+| Swin-L | ImageNet-22K | 224x224 | 86.3 | 97.9 | 197M | 34.5G | 141 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth)/[baidu](https://pan.baidu.com/s/1pws3rOTFuOebBYP3h6Kx8w)/[config](configs/swin/swin_large_patch4_window7_224_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1NkQApMWUhxBGjk1ne6VqBQ)/[config](configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml) |
+| Swin-L | ImageNet-22K | 384x384 | 87.3 | 98.2 | 197M | 103.9G | 42 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth)/[baidu](https://pan.baidu.com/s/1sl7o_bJA143OD7UqSLAMoA) | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth)/[baidu](https://pan.baidu.com/s/1X0FLHQyPOC6Kmv2CmgxJvA)/[config](configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml) |
+
+**ImageNet-1K and ImageNet-22K Pretrained Swin-V2 Models**
+
+| name | pretrain | resolution | window |acc@1 | acc@5 | #params | FLOPs | FPS |22K model | 1K model |
+|:---------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: |:---:|:---: |:---: |
+| SwinV2-T | ImageNet-1K | 256x256 | 8x8 | 81.8 | 95.9 | 28M | 5.9G | 572 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1RzLkAH_5OtfRCJe6Vlg6rg?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window8_256.yaml) |
+| SwinV2-S | ImageNet-1K | 256x256 | 8x8 | 83.7 | 96.6 | 50M | 11.5G | 327 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/195PdA41szEduW3jEtRSa4Q?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window8_256.yaml) |
+| SwinV2-B | ImageNet-1K | 256x256 | 8x8 | 84.2 | 96.9 | 88M | 20.3G | 217 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/18AfMSz3dPyzIvP1dKuERvQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window8_256.yaml) |
+| SwinV2-T | ImageNet-1K | 256x256 | 16x16 | 82.8 | 96.2 | 28M | 6.6G | 437 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dyK3cK9Xipmv6RnTtrPocw?pwd=swin)/[config](configs/swinv2/swinv2_tiny_patch4_window16_256.yaml) |
+| SwinV2-S | ImageNet-1K | 256x256 | 16x16 | 84.1 | 96.8 | 50M | 12.6G  | 257 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1ZIPiSfWNKTPp821Ka-Mifw?pwd=swin)/[config](configs/swinv2/swinv2_small_patch4_window16_256.yaml) |
+| SwinV2-B | ImageNet-1K | 256x256 | 16x16 | 84.6 | 97.0 | 88M | 21.8G | 174 | - | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth)/[baidu](https://pan.baidu.com/s/1dlDQGn8BXCmnh7wQSM5Nhw?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window16_256.yaml) |
+| SwinV2-B<sup>\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.2 | 97.9 |  88M | 21.8G | 174 | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1sgstld4MgGsZxhUAW7MlmQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml) |
+| SwinV2-B<sup>\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.1 | 98.2 | 88M | 54.7G | 57  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/1Xc2rsSsRQz_sy5mjgfxrMQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/17u3sEQaUYlvfL195rrORzQ?pwd=swin)/[config](configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml) |
+| SwinV2-L<sup>\*</sup> | ImageNet-22K | 256x256 | 16x16 | 86.9 | 98.0 | 197M | 47.5G | 95  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/1pqp31N80qIWjFPbudzB6Bw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml) |
+| SwinV2-L<sup>\*</sup> | ImageNet-22K | 384x384 | 24x24 | 87.6 | 98.3 | 197M | 115.4G | 33  | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth)/[baidu](https://pan.baidu.com/s/11PhCV7qAGXtZ8dXNgyiGOw?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml) | [github](https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth)/[baidu](https://pan.baidu.com/s/13URdNkygr3Xn0N3e6IwjgA?pwd=swin)/[config](configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml) |
+
+Note: 
+- SwinV2-B<sup>\*</sup>  (SwinV2-L<sup>\*</sup>) with input resolution of 256x256 and 384x384 both fine-tuned from the same pre-training model using a smaller input resolution of 192x192.
+- SwinV2-B<sup>\*</sup> (384x384) achieves 78.08 acc@1 on ImageNet-1K-V2 while SwinV2-L<sup>\*</sup> (384x384) achieves 78.31.
+
+**ImageNet-1K Pretrained Swin MLP Models**
+
+| name | pretrain | resolution |acc@1 | acc@5 | #params | FLOPs | FPS |  1K model |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| [Mixer-B/16](https://arxiv.org/pdf/2105.01601.pdf) | ImageNet-1K | 224x224 | 76.4 | - | 59M | 12.7G | - | [official repo](https://github.com/google-research/vision_transformer) |
+| [ResMLP-S24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 79.4 | - | 30M | 6.0G | 715 | [timm](https://github.com/rwightman/pytorch-image-models) |
+| [ResMLP-B24](https://arxiv.org/abs/2105.03404) | ImageNet-1K | 224x224 | 81.0 | - | 116M | 23.0G |  231 | [timm](https://github.com/rwightman/pytorch-image-models) |
+| Swin-T/C24 | ImageNet-1K | 256x256 | 81.6 | 95.7 | 28M | 5.9G | 563 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/17k-7l6Sxt7uZ7IV0f26GNQ)/[config](configs/swin/swin_tiny_c24_patch4_window8_256.yaml) |
+| SwinMLP-T/C24 | ImageNet-1K | 256x256 | 79.4 | 94.6 | 20M | 4.0G | 807 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c24_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1Sa4vP5R0M2RjfIe9HIga-Q)/[config](configs/swin/swin_mlp_tiny_c24_patch4_window8_256.yaml) |
+| SwinMLP-T/C12 | ImageNet-1K | 256x256 | 79.6 | 94.7 | 21M | 4.0G | 792 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c12_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1mM9J2_DEVZHUB5ASIpFl0w)/[config](configs/swin/swin_mlp_tiny_c12_patch4_window8_256.yaml) |
+| SwinMLP-T/C6 | ImageNet-1K | 256x256 | 79.7 | 94.9 | 23M | 4.0G | 766 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_tiny_c6_patch4_window8_256.pth)/[baidu](https://pan.baidu.com/s/1hUTYVT2W1CsjICw-3W-Vjg)/[config](configs/swin/swin_mlp_tiny_c6_patch4_window8_256.yaml) |
+| SwinMLP-B | ImageNet-1K | 224x224 | 81.3 | 95.3 | 61M | 10.4G | 409 | [github](https://github.com/SwinTransformer/storage/releases/download/v1.0.5/swin_mlp_base_patch4_window7_224.pth)/[baidu](https://pan.baidu.com/s/1zww3dnbX3GxNiGfb-GwyUg)/[config](configs/swin/swin_mlp_base_patch4_window7_224.yaml) |
+
+Note: access code for `baidu` is `swin`. C24 means each head has 24 channels.
+
+**ImageNet-22K Pretrained Swin-MoE Models**
+
+- Please refer to [get_started](get_started.md#mixture-of-experts-support) for instructions on running Swin-MoE. 
+- Pretrained models for Swin-MoE can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models)
+
+## Main Results on Downstream Tasks
+
+**COCO Object Detection (2017 val)**
+
+| Backbone | Method | pretrain | Lr Schd | box mAP | mask mAP | #params | FLOPs |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-T | Mask R-CNN | ImageNet-1K | 3x | 46.0 | 41.6 | 48M | 267G |
+| Swin-S | Mask R-CNN | ImageNet-1K | 3x | 48.5 | 43.3 | 69M | 359G |
+| Swin-T | Cascade Mask R-CNN | ImageNet-1K | 3x | 50.4 | 43.7 | 86M | 745G |
+| Swin-S | Cascade Mask R-CNN | ImageNet-1K |  3x | 51.9 | 45.0 | 107M | 838G |
+| Swin-B | Cascade Mask R-CNN | ImageNet-1K |  3x | 51.9 | 45.0 | 145M | 982G |
+| Swin-T | RepPoints V2 | ImageNet-1K | 3x | 50.0 | - | 45M | 283G |
+| Swin-T | Mask RepPoints V2 | ImageNet-1K | 3x | 50.3 | 43.6 | 47M | 292G |
+| Swin-B | HTC++ | ImageNet-22K | 6x | 56.4 | 49.1 | 160M | 1043G |
+| Swin-L | HTC++ | ImageNet-22K | 3x | 57.1 | 49.5 | 284M | 1470G |
+| Swin-L | HTC++<sup>*</sup> | ImageNet-22K | 3x | 58.0 | 50.4 | 284M | - |
+
+Note: <sup>*</sup> indicates multi-scale testing.
+
+**ADE20K Semantic Segmentation (val)**
+
+| Backbone | Method | pretrain | Crop Size | Lr Schd | mIoU | mIoU (ms+flip) | #params | FLOPs |
+| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
+| Swin-T | UPerNet | ImageNet-1K | 512x512 | 160K | 44.51 | 45.81 | 60M | 945G |
+| Swin-S | UperNet | ImageNet-1K | 512x512 | 160K | 47.64 | 49.47 | 81M | 1038G |
+| Swin-B | UperNet | ImageNet-1K | 512x512 | 160K | 48.13 | 49.72 | 121M | 1188G |
+| Swin-B | UPerNet | ImageNet-22K | 640x640 | 160K | 50.04 | 51.66 | 121M | 1841G |
+| Swin-L | UperNet | ImageNet-22K | 640x640 | 160K | 52.05 | 53.53 | 234M | 3230G |
+
+## Citing Swin Transformer
+
+```
+@inproceedings{liu2021Swin,
+  title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
+  author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
+  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+  year={2021}
+}
+```
+## Citing Local Relation Networks (the first full-attention visual backbone)
+```
+@inproceedings{hu2019local,
+  title={Local Relation Networks for Image Recognition},
+  author={Hu, Han and Zhang, Zheng and Xie, Zhenda and Lin, Stephen},
+  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+  pages={3464--3473},
+  year={2019}
+}
+```
+## Citing Swin Transformer V2
+```
+@inproceedings{liu2021swinv2,
+  title={Swin Transformer V2: Scaling Up Capacity and Resolution}, 
+  author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
+  booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2022}
+}
+```
+## Citing SimMIM (a self-supervised approach that enables SwinV2-G)
+```
+@inproceedings{xie2021simmim,
+  title={SimMIM: A Simple Framework for Masked Image Modeling},
+  author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han},
+  booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2022}
+}
+```
+## Citing SimMIM-data-scaling
+```
+@article{xie2022data,
+  title={On Data Scaling in Masked Image Modeling},
+  author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Wei, Yixuan and Dai, Qi and Hu, Han},
+  journal={arXiv preprint arXiv:2206.04664},
+  year={2022}
+}
+```
+## Citing Swin-MoE
+```
+@misc{hwang2022tutel,
+      title={Tutel: Adaptive Mixture-of-Experts at Scale}, 
+      author={Changho Hwang and Wei Cui and Yifan Xiong and Ziyue Yang and Ze Liu and Han Hu and Zilong Wang and Rafael Salas and Jithin Jose and Prabhat Ram and Joe Chau and Peng Cheng and Fan Yang and Mao Yang and Yongqiang Xiong},
+      year={2022},
+      eprint={2206.03382},
+      archivePrefix={arXiv}
+}
+```
+
+## Getting Started
+
+- For **Image Classification**, please see [get_started.md](get_started.md) for detailed instructions.
+- For **Object Detection and Instance Segmentation**, please see [Swin Transformer for Object Detection](https://github.com/SwinTransformer/Swin-Transformer-Object-Detection).
+- For **Semantic Segmentation**, please see [Swin Transformer for Semantic Segmentation](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation).
+- For **Self-Supervised Learning**, please see [Transformer-SSL](https://github.com/SwinTransformer/Transformer-SSL).
+- For **Video Recognition**, please see [Video Swin Transformer](https://github.com/SwinTransformer/Video-Swin-Transformer).
+
+## Third-party Usage and Experiments
+
+***In this pargraph, we cross link third-party repositories which use Swin and report results. You can let us know by raising an issue*** 
+
+(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)
+
+[12/29/2022] Swin Transformers (V2) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md)
+
+[06/30/2022] Swin Transformers (V1) inference implemented in FasterTransformer: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer/blob/main/docs/swin_guide.md)
+
+[05/12/2022] Swin Transformers (V1) implemented in TensorFlow with the pre-trained parameters ported into them. Find the implementation,
+TensorFlow weights, code example here in [this repository](https://github.com/sayakpaul/swin-transformers-tf/).
+
+[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer).
+
+[12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin)
+
+[12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo)
+
+[08/29/2021] Swin Transformer for Image Restoration: [SwinIR](https://github.com/JingyunLiang/SwinIR)
+
+[08/12/2021] Swin Transformer for person reID: [https://github.com/layumi/Person_reID_baseline_pytorch](https://github.com/layumi/Person_reID_baseline_pytorch)
+
+[06/29/2021] Swin-Transformer in PaddleClas and inference based on whl package: [https://github.com/PaddlePaddle/PaddleClas](https://github.com/PaddlePaddle/PaddleClas)
+
+[04/14/2021] Swin for RetinaNet in Detectron: https://github.com/xiaohu2015/SwinT_detectron2.
+
+[04/16/2021] Included in a famous model zoo: https://github.com/rwightman/pytorch-image-models.
+
+[04/20/2021] Swin-Transformer classifier inference using TorchServe: https://github.com/kamalkraj/Swin-Transformer-Serve
+
+## Contributing
+
+This project welcomes contributions and suggestions.  Most contributions require you to agree to a
+Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
+the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
+
+When you submit a pull request, a CLA bot will automatically determine whether you need to provide
+a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
+provided by the bot. You will only need to do this once across all repos using our CLA.
+
+This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
+For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
+contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
+
+## Trademarks
+
+This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 
+trademarks or logos is subject to and must follow 
+[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
+Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
+Any use of third-party trademarks or logos are subject to those third-party's policies.

+ 41 - 0
lib/SwinTransformer/SECURITY.md

@@ -0,0 +1,41 @@
+<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
+
+## Security
+
+Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
+
+If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
+
+## Reporting Security Issues
+
+**Please do not report security vulnerabilities through public GitHub issues.**
+
+Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
+
+If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com).  If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
+
+You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 
+
+Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
+
+  * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
+  * Full paths of source file(s) related to the manifestation of the issue
+  * The location of the affected source code (tag/branch/commit or direct URL)
+  * Any special configuration required to reproduce the issue
+  * Step-by-step instructions to reproduce the issue
+  * Proof-of-concept or exploit code (if possible)
+  * Impact of the issue, including how an attacker might exploit the issue
+
+This information will help us triage your report more quickly.
+
+If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
+
+## Preferred Languages
+
+We prefer all communications to be in English.
+
+## Policy
+
+Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
+
+<!-- END MICROSOFT SECURITY.MD BLOCK -->

+ 25 - 0
lib/SwinTransformer/SUPPORT.md

@@ -0,0 +1,25 @@
+# TODO: The maintainer of this repo has not yet edited this file
+
+**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
+
+- **No CSS support:** Fill out this template with information about how to file issues and get help.
+- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
+- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
+
+*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
+
+# Support
+
+## How to file issues and get help  
+
+This project uses GitHub Issues to track bugs and feature requests. Please search the existing 
+issues before filing new issues to avoid duplicates.  For new issues, file your bug or 
+feature request as a new Issue.
+
+For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 
+FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
+CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
+
+## Microsoft Support Policy  
+
+Support for this **PROJECT or PRODUCT** is limited to the resources listed above.

+ 3 - 0
lib/SwinTransformer/__init__.py

@@ -0,0 +1,3 @@
+from .models.swin_transformer_v2 import SwinTransformerV2, SwinTransformerBlock
+
+__all__ = ["SwinTransformerV2", "SwinTransformerBlock"]

+ 359 - 0
lib/SwinTransformer/config.py

@@ -0,0 +1,359 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------'
+
+import os
+import torch
+import yaml
+from yacs.config import CfgNode as CN
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = ''
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bicubic'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# [SimMIM] Mask patch size for MaskGenerator
+_C.DATA.MASK_PATCH_SIZE = 32
+# [SimMIM] Mask ratio for MaskGenerator
+_C.DATA.MASK_RATIO = 0.6
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.TYPE = 'swin'
+# Model name
+_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
+# Pretrained weight from checkpoint, could be imagenet22k pretrained weight
+# could be overwritten by command line argument
+_C.MODEL.PRETRAINED = ''
+# Checkpoint to resume, could be overwritten by command line argument
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Dropout rate
+_C.MODEL.DROP_RATE = 0.0
+# Drop path rate
+_C.MODEL.DROP_PATH_RATE = 0.1
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# Swin Transformer parameters
+_C.MODEL.SWIN = CN()
+_C.MODEL.SWIN.PATCH_SIZE = 4
+_C.MODEL.SWIN.IN_CHANS = 3
+_C.MODEL.SWIN.EMBED_DIM = 96
+_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN.WINDOW_SIZE = 7
+_C.MODEL.SWIN.MLP_RATIO = 4.
+_C.MODEL.SWIN.QKV_BIAS = True
+_C.MODEL.SWIN.QK_SCALE = None
+_C.MODEL.SWIN.APE = False
+_C.MODEL.SWIN.PATCH_NORM = True
+
+# Swin Transformer V2 parameters
+_C.MODEL.SWINV2 = CN()
+_C.MODEL.SWINV2.PATCH_SIZE = 4
+_C.MODEL.SWINV2.IN_CHANS = 3
+_C.MODEL.SWINV2.EMBED_DIM = 96
+_C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWINV2.WINDOW_SIZE = 7
+_C.MODEL.SWINV2.MLP_RATIO = 4.
+_C.MODEL.SWINV2.QKV_BIAS = True
+_C.MODEL.SWINV2.APE = False
+_C.MODEL.SWINV2.PATCH_NORM = True
+_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
+
+# Swin Transformer MoE parameters
+_C.MODEL.SWIN_MOE = CN()
+_C.MODEL.SWIN_MOE.PATCH_SIZE = 4
+_C.MODEL.SWIN_MOE.IN_CHANS = 3
+_C.MODEL.SWIN_MOE.EMBED_DIM = 96
+_C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN_MOE.WINDOW_SIZE = 7
+_C.MODEL.SWIN_MOE.MLP_RATIO = 4.
+_C.MODEL.SWIN_MOE.QKV_BIAS = True
+_C.MODEL.SWIN_MOE.QK_SCALE = None
+_C.MODEL.SWIN_MOE.APE = False
+_C.MODEL.SWIN_MOE.PATCH_NORM = True
+_C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True
+_C.MODEL.SWIN_MOE.INIT_STD = 0.02
+_C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
+_C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]]
+_C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1
+_C.MODEL.SWIN_MOE.TOP_VALUE = 1
+_C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25
+_C.MODEL.SWIN_MOE.COSINE_ROUTER = False
+_C.MODEL.SWIN_MOE.NORMALIZE_GATE = False
+_C.MODEL.SWIN_MOE.USE_BPR = True
+_C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False
+_C.MODEL.SWIN_MOE.GATE_NOISE = 1.0
+_C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256
+_C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5
+_C.MODEL.SWIN_MOE.MOE_DROP = 0.0
+_C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01
+
+# Swin MLP parameters
+_C.MODEL.SWIN_MLP = CN()
+_C.MODEL.SWIN_MLP.PATCH_SIZE = 4
+_C.MODEL.SWIN_MLP.IN_CHANS = 3
+_C.MODEL.SWIN_MLP.EMBED_DIM = 96
+_C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2]
+_C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24]
+_C.MODEL.SWIN_MLP.WINDOW_SIZE = 7
+_C.MODEL.SWIN_MLP.MLP_RATIO = 4.
+_C.MODEL.SWIN_MLP.APE = False
+_C.MODEL.SWIN_MLP.PATCH_NORM = True
+
+# [SimMIM] Norm target during training
+_C.MODEL.SIMMIM = CN()
+_C.MODEL.SIMMIM.NORM_TARGET = CN()
+_C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False
+_C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47
+
+# -----------------------------------------------------------------------------
+# Training settings
+# -----------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.START_EPOCH = 0
+_C.TRAIN.EPOCHS = 300
+_C.TRAIN.WARMUP_EPOCHS = 20
+_C.TRAIN.WEIGHT_DECAY = 0.05
+_C.TRAIN.BASE_LR = 5e-4
+_C.TRAIN.WARMUP_LR = 5e-7
+_C.TRAIN.MIN_LR = 5e-6
+# Clip gradient norm
+_C.TRAIN.CLIP_GRAD = 5.0
+# Auto resume from latest checkpoint
+_C.TRAIN.AUTO_RESUME = True
+# Gradient accumulation steps
+# could be overwritten by command line argument
+_C.TRAIN.ACCUMULATION_STEPS = 1
+# Whether to use gradient checkpointing to save memory
+# could be overwritten by command line argument
+_C.TRAIN.USE_CHECKPOINT = False
+
+# LR scheduler
+_C.TRAIN.LR_SCHEDULER = CN()
+_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+# Epoch interval to decay LR, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+# LR decay rate, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+# warmup_prefix used in CosineLRScheduler
+_C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True
+# [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler
+_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
+_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
+
+# Optimizer
+_C.TRAIN.OPTIMIZER = CN()
+_C.TRAIN.OPTIMIZER.NAME = 'adamw'
+# Optimizer Epsilon
+_C.TRAIN.OPTIMIZER.EPS = 1e-8
+# Optimizer Betas
+_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+# SGD momentum
+_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+# [SimMIM] Layer decay for fine-tuning
+_C.TRAIN.LAYER_DECAY = 1.0
+
+# MoE
+_C.TRAIN.MOE = CN()
+# Only save model on master device
+_C.TRAIN.MOE.SAVE_MASTER = False
+# -----------------------------------------------------------------------------
+# Augmentation settings
+# -----------------------------------------------------------------------------
+_C.AUG = CN()
+# Color jitter factor
+_C.AUG.COLOR_JITTER = 0.4
+# Use AutoAugment policy. "v0" or "original"
+_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
+# Random erase prob
+_C.AUG.REPROB = 0.25
+# Random erase mode
+_C.AUG.REMODE = 'pixel'
+# Random erase count
+_C.AUG.RECOUNT = 1
+# Mixup alpha, mixup enabled if > 0
+_C.AUG.MIXUP = 0.8
+# Cutmix alpha, cutmix enabled if > 0
+_C.AUG.CUTMIX = 1.0
+# Cutmix min/max ratio, overrides alpha and enables cutmix if set
+_C.AUG.CUTMIX_MINMAX = None
+# Probability of performing mixup or cutmix when either/both is enabled
+_C.AUG.MIXUP_PROB = 1.0
+# Probability of switching to cutmix when both mixup and cutmix enabled
+_C.AUG.MIXUP_SWITCH_PROB = 0.5
+# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+_C.AUG.MIXUP_MODE = 'batch'
+
+# -----------------------------------------------------------------------------
+# Testing settings
+# -----------------------------------------------------------------------------
+_C.TEST = CN()
+# Whether to use center crop when testing
+_C.TEST.CROP = True
+# Whether to use SequentialSampler as validation sampler
+_C.TEST.SEQUENTIAL = False
+_C.TEST.SHUFFLE = False
+
+# -----------------------------------------------------------------------------
+# Misc
+# -----------------------------------------------------------------------------
+# [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
+_C.ENABLE_AMP = False
+
+# Enable Pytorch automatic mixed precision (amp).
+_C.AMP_ENABLE = True
+# [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2')
+_C.AMP_OPT_LEVEL = ''
+# Path to output folder, overwritten by command line argument
+_C.OUTPUT = ''
+# Tag of experiment, overwritten by command line argument
+_C.TAG = 'default'
+# Frequency to save checkpoint
+_C.SAVE_FREQ = 1
+# Frequency to logging info
+_C.PRINT_FREQ = 10
+# Fixed random seed
+_C.SEED = 0
+# Perform evaluation only, overwritten by command line argument
+_C.EVAL_MODE = False
+# Test throughput only, overwritten by command line argument
+_C.THROUGHPUT_MODE = False
+# local rank for DistributedDataParallel, given by command line argument
+_C.LOCAL_RANK = 0
+# for acceleration
+_C.FUSED_WINDOW_PROCESS = False
+_C.FUSED_LAYERNORM = False
+
+
+def _update_config_from_file(config, cfg_file):
+    config.defrost()
+    with open(cfg_file, 'r') as f:
+        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
+
+    for cfg in yaml_cfg.setdefault('BASE', ['']):
+        if cfg:
+            _update_config_from_file(
+                config, os.path.join(os.path.dirname(cfg_file), cfg)
+            )
+    print('=> merge config from {}'.format(cfg_file))
+    config.merge_from_file(cfg_file)
+    config.freeze()
+
+
+def update_config(config, args):
+    _update_config_from_file(config, args.cfg)
+
+    config.defrost()
+    if args.opts:
+        config.merge_from_list(args.opts)
+
+    def _check_args(name):
+        if hasattr(args, name) and eval(f'args.{name}'):
+            return True
+        return False
+
+    # merge from specific arguments
+    if _check_args('batch_size'):
+        config.DATA.BATCH_SIZE = args.batch_size
+    if _check_args('data_path'):
+        config.DATA.DATA_PATH = args.data_path
+    if _check_args('zip'):
+        config.DATA.ZIP_MODE = True
+    if _check_args('cache_mode'):
+        config.DATA.CACHE_MODE = args.cache_mode
+    if _check_args('pretrained'):
+        config.MODEL.PRETRAINED = args.pretrained
+    if _check_args('resume'):
+        config.MODEL.RESUME = args.resume
+    if _check_args('accumulation_steps'):
+        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
+    if _check_args('use_checkpoint'):
+        config.TRAIN.USE_CHECKPOINT = True
+    if _check_args('amp_opt_level'):
+        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
+        if args.amp_opt_level == 'O0':
+            config.AMP_ENABLE = False
+    if _check_args('disable_amp'):
+        config.AMP_ENABLE = False
+    if _check_args('output'):
+        config.OUTPUT = args.output
+    if _check_args('tag'):
+        config.TAG = args.tag
+    if _check_args('eval'):
+        config.EVAL_MODE = True
+    if _check_args('throughput'):
+        config.THROUGHPUT_MODE = True
+
+    # [SimMIM]
+    if _check_args('enable_amp'):
+        config.ENABLE_AMP = args.enable_amp
+
+    # for acceleration
+    if _check_args('fused_window_process'):
+        config.FUSED_WINDOW_PROCESS = True
+    if _check_args('fused_layernorm'):
+        config.FUSED_LAYERNORM = True
+    ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
+    if _check_args('optim'):
+        config.TRAIN.OPTIMIZER.NAME = args.optim
+
+    # set local rank for distributed training
+    if PYTORCH_MAJOR_VERSION == 1:
+        config.LOCAL_RANK = args.local_rank
+    else:
+        config.LOCAL_RANK = int(os.environ['LOCAL_RANK'])
+
+    # output folder
+    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
+
+    config.freeze()
+
+
+def get_config(args):
+    """Get a yacs CfgNode object with default values."""
+    # Return a clone so that the defaults will not be altered
+    # This is for the "local variable" use pattern
+    config = _C.clone()
+    update_config(config, args)
+
+    return config

+ 22 - 0
lib/SwinTransformer/configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml

@@ -0,0 +1,22 @@
+MODEL:
+  TYPE: swin
+  NAME: simmim_finetune
+  DROP_PATH_RATE: 0.1
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 7
+DATA:
+  IMG_SIZE: 224
+TRAIN:
+  EPOCHS: 100
+  WARMUP_EPOCHS: 20
+  BASE_LR: 1.25e-3
+  WARMUP_LR: 2.5e-7
+  MIN_LR: 2.5e-7
+  WEIGHT_DECAY: 0.05
+  LAYER_DECAY: 0.8
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: simmim_finetune__swin_base__img224_window7__800ep

+ 23 - 0
lib/SwinTransformer/configs/simmim/simmim_finetune__swinv2_base__img224_window14__800ep.yaml

@@ -0,0 +1,23 @@
+MODEL:
+  TYPE: swinv2
+  NAME: simmim_finetune
+  DROP_PATH_RATE: 0.1
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 14
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+DATA:
+  IMG_SIZE: 224
+TRAIN:
+  EPOCHS: 100
+  WARMUP_EPOCHS: 20
+  BASE_LR: 1.25e-3
+  WARMUP_LR: 2.5e-7
+  MIN_LR: 2.5e-7
+  WEIGHT_DECAY: 0.05
+  LAYER_DECAY: 0.75
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: simmim_finetune__swinv2_base__img224_window14__800ep

+ 26 - 0
lib/SwinTransformer/configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml

@@ -0,0 +1,26 @@
+MODEL:
+  TYPE: swin
+  NAME: simmim_pretrain
+  DROP_PATH_RATE: 0.0
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 6
+DATA:
+  IMG_SIZE: 192
+  MASK_PATCH_SIZE: 32
+  MASK_RATIO: 0.6
+TRAIN:
+  EPOCHS: 800
+  WARMUP_EPOCHS: 10
+  BASE_LR: 1e-4
+  WARMUP_LR: 5e-7
+  WEIGHT_DECAY: 0.05
+  LR_SCHEDULER:
+    NAME: 'multistep'
+    GAMMA: 0.1
+    MULTISTEPS: [700,]
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: simmim_pretrain__swin_base__img192_window6__800ep

+ 30 - 0
lib/SwinTransformer/configs/simmim/simmim_pretrain__swinv2_base__img192_window12__800ep.yaml

@@ -0,0 +1,30 @@
+MODEL:
+  TYPE: swinv2
+  NAME: simmim_pretrain
+  DROP_PATH_RATE: 0.1
+  SIMMIM:
+    NORM_TARGET:
+      ENABLE: True
+      PATCH_SIZE: 47
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+DATA:
+  IMG_SIZE: 192
+  MASK_PATCH_SIZE: 32
+  MASK_RATIO: 0.6
+TRAIN:
+  EPOCHS: 800
+  WARMUP_EPOCHS: 10
+  BASE_LR: 1e-4
+  WARMUP_LR: 5e-7
+  WEIGHT_DECAY: 0.05
+  LR_SCHEDULER:
+    NAME: 'multistep'
+    GAMMA: 0.1
+    MULTISTEPS: [700,]
+PRINT_FREQ: 100
+SAVE_FREQ: 5
+TAG: simmim_pretrain__swinv2_base__img192_window12__800ep

+ 20 - 0
lib/SwinTransformer/configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml

@@ -0,0 +1,20 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swin
+  NAME: swin_base_patch4_window12_384_22kto1k_finetune
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 20 - 0
lib/SwinTransformer/configs/swin/swin_base_patch4_window12_384_finetune.yaml

@@ -0,0 +1,20 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swin
+  NAME: swin_base_patch4_window12_384_finetune
+  DROP_PATH_RATE: 0.5
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 9 - 0
lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224.yaml

@@ -0,0 +1,9 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_base_patch4_window7_224
+  DROP_PATH_RATE: 0.5
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 7

+ 18 - 0
lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224_22k.yaml

@@ -0,0 +1,18 @@
+DATA:
+  DATASET: imagenet22K
+MODEL:
+  TYPE: swin
+  NAME: swin_base_patch4_window7_224_22k
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.05
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 16 - 0
lib/SwinTransformer/configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml

@@ -0,0 +1,16 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_base_patch4_window7_224_22kto1k_finetune
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 20 - 0
lib/SwinTransformer/configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml

@@ -0,0 +1,20 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swin
+  NAME: swin_large_patch4_window12_384_22kto1k_finetune
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 18 - 0
lib/SwinTransformer/configs/swin/swin_large_patch4_window7_224_22k.yaml

@@ -0,0 +1,18 @@
+DATA:
+  DATASET: imagenet22K
+MODEL:
+  TYPE: swin
+  NAME: swin_large_patch4_window7_224_22k
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.05
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 16 - 0
lib/SwinTransformer/configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml

@@ -0,0 +1,16 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_large_patch4_window7_224_22kto1k_finetune
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 9 - 0
lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224.yaml

@@ -0,0 +1,9 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_small_patch4_window7_224
+  DROP_PATH_RATE: 0.3
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7

+ 18 - 0
lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224_22k.yaml

@@ -0,0 +1,18 @@
+DATA:
+  DATASET: imagenet22K
+MODEL:
+  TYPE: swin
+  NAME: swin_small_patch4_window7_224_22k
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.05
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 16 - 0
lib/SwinTransformer/configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml

@@ -0,0 +1,16 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_small_patch4_window7_224_22kto1k_finetune
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 11 - 0
lib/SwinTransformer/configs/swin/swin_tiny_c24_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swin
+  NAME: swin_tiny_c24_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 8

+ 9 - 0
lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224.yaml

@@ -0,0 +1,9 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_tiny_patch4_window7_224
+  DROP_PATH_RATE: 0.2
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7

+ 18 - 0
lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224_22k.yaml

@@ -0,0 +1,18 @@
+DATA:
+  DATASET: imagenet22K
+MODEL:
+  TYPE: swin
+  NAME: swin_tiny_patch4_window7_224_22k
+  DROP_PATH_RATE: 0.1
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.05
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 16 - 0
lib/SwinTransformer/configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml

@@ -0,0 +1,16 @@
+MODEL:
+  TYPE: swin
+  NAME: swin_tiny_patch4_window7_224_22kto1k_finetune
+  DROP_PATH_RATE: 0.1
+  SWIN:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 7
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 9 - 0
lib/SwinTransformer/configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml

@@ -0,0 +1,9 @@
+MODEL:
+  TYPE: swin_mlp
+  NAME: swin_mlp_base_patch4_window7_224
+  DROP_PATH_RATE: 0.5
+  SWIN_MLP:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 7

+ 11 - 0
lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swin_mlp
+  NAME: swin_mlp_tiny_c12_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWIN_MLP:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 8, 16, 32, 64 ]
+    WINDOW_SIZE: 8

+ 11 - 0
lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swin_mlp
+  NAME: swin_mlp_tiny_c24_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWIN_MLP:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 8

+ 11 - 0
lib/SwinTransformer/configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swin_mlp
+  NAME: swin_mlp_tiny_c6_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWIN_MLP:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 16, 32, 64, 128 ]
+    WINDOW_SIZE: 8

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_base_patch4_window12_192_16expert_32gpu_22k
+  DROP_PATH_RATE: 0.3
+  SWIN_MOE:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: -2
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_base_patch4_window12_192_32expert_32gpu_22k
+  DROP_PATH_RATE: 0.3
+  SWIN_MOE:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: 1
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_base_patch4_window12_192_8expert_32gpu_22k
+  DROP_PATH_RATE: 0.3
+  SWIN_MOE:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: -4
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 32 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml

@@ -0,0 +1,32 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k
+  DROP_PATH_RATE: 0.3
+  SWIN_MOE:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: 1
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    COSINE_ROUTER: True
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 26 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml

@@ -0,0 +1,26 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_base_patch4_window12_192_densebaseline_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ]
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+  MOE:
+    SAVE_MASTER: True
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_16expert_32gpu_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: -2
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_32expert_32gpu_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: 1
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_64expert_64gpu_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: 1
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 31 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml

@@ -0,0 +1,31 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_8expert_32gpu_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: -4
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 32 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml

@@ -0,0 +1,32 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    INIT_STD: 0.005
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ]
+    NUM_LOCAL_EXPERTS: 1
+    TOP_VALUE: 1
+    CAPACITY_FACTOR: 1.25
+    COSINE_ROUTER: True
+    IS_GSHARD_LOSS: False
+    MOE_DROP: 0.1
+    AUX_LOSS_WEIGHT: 0.01
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+TEST:
+  SHUFFLE: True

+ 26 - 0
lib/SwinTransformer/configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml

@@ -0,0 +1,26 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swin_moe
+  NAME: swin_moe_small_patch4_window12_192_densebaseline_22k
+  DROP_PATH_RATE: 0.2
+  SWIN_MOE:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 12
+    MLP_FC2_BIAS: False
+    MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ]
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 10
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6
+  CLIP_GRAD: 3.0
+  MOE:
+    SAVE_MASTER: True
+TEST:
+  SHUFFLE: True

+ 19 - 0
lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml

@@ -0,0 +1,19 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 19 - 0
lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -0,0 +1,19 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 21 - 0
lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -0,0 +1,21 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window16_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 16

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_base_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window8_256
+  DROP_PATH_RATE: 0.5
+  SWINV2:
+    EMBED_DIM: 128
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 4, 8, 16, 32 ]
+    WINDOW_SIZE: 8

+ 19 - 0
lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml

@@ -0,0 +1,19 @@
+DATA:
+  DATASET: imagenet22K
+  IMG_SIZE: 192
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12_192_22k
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 12
+TRAIN:
+  EPOCHS: 90
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 0.1
+  BASE_LR: 1.25e-4 # 4096 batch-size
+  WARMUP_LR: 1.25e-7
+  MIN_LR: 1.25e-6

+ 19 - 0
lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml

@@ -0,0 +1,19 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 16
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07

+ 21 - 0
lib/SwinTransformer/configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml

@@ -0,0 +1,21 @@
+DATA:
+  IMG_SIZE: 384
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 192
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 6, 12, 24, 48 ]
+    WINDOW_SIZE: 24
+    PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ]
+TRAIN:
+  EPOCHS: 30
+  WARMUP_EPOCHS: 5
+  WEIGHT_DECAY: 1e-8
+  BASE_LR: 2e-05
+  WARMUP_LR: 2e-08
+  MIN_LR: 2e-07
+TEST:
+  CROP: False

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_small_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window16_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 16

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_small_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_small_patch4_window8_256
+  DROP_PATH_RATE: 0.3
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 18, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 8

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_tiny_patch4_window16_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window16_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 16

+ 11 - 0
lib/SwinTransformer/configs/swinv2/swinv2_tiny_patch4_window8_256.yaml

@@ -0,0 +1,11 @@
+DATA:
+  IMG_SIZE: 256
+MODEL:
+  TYPE: swinv2
+  NAME: swinv2_tiny_patch4_window8_256
+  DROP_PATH_RATE: 0.2
+  SWINV2:
+    EMBED_DIM: 96
+    DEPTHS: [ 2, 2, 6, 2 ]
+    NUM_HEADS: [ 3, 6, 12, 24 ]
+    WINDOW_SIZE: 8

+ 12 - 0
lib/SwinTransformer/data/__init__.py

@@ -0,0 +1,12 @@
+from .build import build_loader as _build_loader
+from .data_simmim_pt import build_loader_simmim
+from .data_simmim_ft import build_loader_finetune
+
+
+def build_loader(config, simmim=False, is_pretrain=False):
+    if not simmim:
+        return _build_loader(config)
+    if is_pretrain:
+        return build_loader_simmim(config)
+    else:
+        return build_loader_finetune(config)

+ 162 - 0
lib/SwinTransformer/data/build.py

@@ -0,0 +1,162 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import torch
+import numpy as np
+import torch.distributed as dist
+from torchvision import datasets, transforms
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import Mixup
+from timm.data import create_transform
+
+from .cached_image_folder import CachedImageFolder
+from .imagenet22k_dataset import IN22KDATASET
+from .samplers import SubsetRandomSampler
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+
+    def _pil_interp(method):
+        if method == 'bicubic':
+            return InterpolationMode.BICUBIC
+        elif method == 'lanczos':
+            return InterpolationMode.LANCZOS
+        elif method == 'hamming':
+            return InterpolationMode.HAMMING
+        else:
+            # default bilinear, do we want to allow nearest?
+            return InterpolationMode.BILINEAR
+
+
+    import timm.data.transforms as timm_transforms
+
+    timm_transforms._pil_interp = _pil_interp
+except:
+    from timm.data.transforms import _pil_interp
+
+
+def build_loader(config):
+    config.defrost()
+    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+    config.freeze()
+    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset")
+    dataset_val, _ = build_dataset(is_train=False, config=config)
+    print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset")
+
+    num_tasks = dist.get_world_size()
+    global_rank = dist.get_rank()
+    if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part':
+        indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size())
+        sampler_train = SubsetRandomSampler(indices)
+    else:
+        sampler_train = torch.utils.data.DistributedSampler(
+            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+        )
+
+    if config.TEST.SEQUENTIAL:
+        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+    else:
+        sampler_val = torch.utils.data.distributed.DistributedSampler(
+            dataset_val, shuffle=config.TEST.SHUFFLE
+        )
+
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset_train, sampler=sampler_train,
+        batch_size=config.DATA.BATCH_SIZE,
+        num_workers=config.DATA.NUM_WORKERS,
+        pin_memory=config.DATA.PIN_MEMORY,
+        drop_last=True,
+    )
+
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val, sampler=sampler_val,
+        batch_size=config.DATA.BATCH_SIZE,
+        shuffle=False,
+        num_workers=config.DATA.NUM_WORKERS,
+        pin_memory=config.DATA.PIN_MEMORY,
+        drop_last=False
+    )
+
+    # setup mixup / cutmix
+    mixup_fn = None
+    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+    if mixup_active:
+        mixup_fn = Mixup(
+            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
+def build_dataset(is_train, config):
+    transform = build_transform(is_train, config)
+    if config.DATA.DATASET == 'imagenet':
+        prefix = 'train' if is_train else 'val'
+        if config.DATA.ZIP_MODE:
+            ann_file = prefix + "_map.txt"
+            prefix = prefix + ".zip@/"
+            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
+                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
+        else:
+            root = os.path.join(config.DATA.DATA_PATH, prefix)
+            dataset = datasets.ImageFolder(root, transform=transform)
+        nb_classes = 1000
+    elif config.DATA.DATASET == 'imagenet22K':
+        prefix = 'ILSVRC2011fall_whole'
+        if is_train:
+            ann_file = prefix + "_map_train.txt"
+        else:
+            ann_file = prefix + "_map_val.txt"
+        dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
+        nb_classes = 21841
+    else:
+        raise NotImplementedError("We only support ImageNet Now.")
+
+    return dataset, nb_classes
+
+
+def build_transform(is_train, config):
+    resize_im = config.DATA.IMG_SIZE > 32
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=config.DATA.IMG_SIZE,
+            is_training=True,
+            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
+            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+            re_prob=config.AUG.REPROB,
+            re_mode=config.AUG.REMODE,
+            re_count=config.AUG.RECOUNT,
+            interpolation=config.DATA.INTERPOLATION,
+        )
+        if not resize_im:
+            # replace RandomResizedCropAndInterpolation with
+            # RandomCrop
+            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+        return transform
+
+    t = []
+    if resize_im:
+        if config.TEST.CROP:
+            size = int((256 / 224) * config.DATA.IMG_SIZE)
+            t.append(
+                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+                # to maintain same ratio w.r.t. 224 images
+            )
+            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
+        else:
+            t.append(
+                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))
+            )
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+    return transforms.Compose(t)

+ 252 - 0
lib/SwinTransformer/data/cached_image_folder.py

@@ -0,0 +1,252 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import io
+import os
+import time
+import torch.distributed as dist
+import torch.utils.data as data
+from PIL import Image
+
+from .zipreader import is_zip_path, ZipReader
+
+
+def has_file_allowed_extension(filename, extensions):
+    """Checks if a file is an allowed extension.
+    Args:
+        filename (string): path to a file
+    Returns:
+        bool: True if the filename ends with a known image extension
+    """
+    filename_lower = filename.lower()
+    return any(filename_lower.endswith(ext) for ext in extensions)
+
+
+def find_classes(dir):
+    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
+    classes.sort()
+    class_to_idx = {classes[i]: i for i in range(len(classes))}
+    return classes, class_to_idx
+
+
+def make_dataset(dir, class_to_idx, extensions):
+    images = []
+    dir = os.path.expanduser(dir)
+    for target in sorted(os.listdir(dir)):
+        d = os.path.join(dir, target)
+        if not os.path.isdir(d):
+            continue
+
+        for root, _, fnames in sorted(os.walk(d)):
+            for fname in sorted(fnames):
+                if has_file_allowed_extension(fname, extensions):
+                    path = os.path.join(root, fname)
+                    item = (path, class_to_idx[target])
+                    images.append(item)
+
+    return images
+
+
+def make_dataset_with_ann(ann_file, img_prefix, extensions):
+    images = []
+    with open(ann_file, "r") as f:
+        contents = f.readlines()
+        for line_str in contents:
+            path_contents = [c for c in line_str.split('\t')]
+            im_file_name = path_contents[0]
+            class_index = int(path_contents[1])
+
+            assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions
+            item = (os.path.join(img_prefix, im_file_name), class_index)
+
+            images.append(item)
+
+    return images
+
+
+class DatasetFolder(data.Dataset):
+    """A generic data loader where the samples are arranged in this way: ::
+        root/class_x/xxx.ext
+        root/class_x/xxy.ext
+        root/class_x/xxz.ext
+        root/class_y/123.ext
+        root/class_y/nsdf3.ext
+        root/class_y/asd932_.ext
+    Args:
+        root (string): Root directory path.
+        loader (callable): A function to load a sample given its path.
+        extensions (list[string]): A list of allowed extensions.
+        transform (callable, optional): A function/transform that takes in
+            a sample and returns a transformed version.
+            E.g, ``transforms.RandomCrop`` for images.
+        target_transform (callable, optional): A function/transform that takes
+            in the target and transforms it.
+     Attributes:
+        samples (list): List of (sample path, class_index) tuples
+    """
+
+    def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None,
+                 cache_mode="no"):
+        # image folder mode
+        if ann_file == '':
+            _, class_to_idx = find_classes(root)
+            samples = make_dataset(root, class_to_idx, extensions)
+        # zip mode
+        else:
+            samples = make_dataset_with_ann(os.path.join(root, ann_file),
+                                            os.path.join(root, img_prefix),
+                                            extensions)
+
+        if len(samples) == 0:
+            raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" +
+                                "Supported extensions are: " + ",".join(extensions)))
+
+        self.root = root
+        self.loader = loader
+        self.extensions = extensions
+
+        self.samples = samples
+        self.labels = [y_1k for _, y_1k in samples]
+        self.classes = list(set(self.labels))
+
+        self.transform = transform
+        self.target_transform = target_transform
+
+        self.cache_mode = cache_mode
+        if self.cache_mode != "no":
+            self.init_cache()
+
+    def init_cache(self):
+        assert self.cache_mode in ["part", "full"]
+        n_sample = len(self.samples)
+        global_rank = dist.get_rank()
+        world_size = dist.get_world_size()
+
+        samples_bytes = [None for _ in range(n_sample)]
+        start_time = time.time()
+        for index in range(n_sample):
+            if index % (n_sample // 10) == 0:
+                t = time.time() - start_time
+                print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block')
+                start_time = time.time()
+            path, target = self.samples[index]
+            if self.cache_mode == "full":
+                samples_bytes[index] = (ZipReader.read(path), target)
+            elif self.cache_mode == "part" and index % world_size == global_rank:
+                samples_bytes[index] = (ZipReader.read(path), target)
+            else:
+                samples_bytes[index] = (path, target)
+        self.samples = samples_bytes
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (sample, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        if self.transform is not None:
+            sample = self.transform(sample)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return sample, target
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __repr__(self):
+        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
+        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
+        fmt_str += '    Root Location: {}\n'.format(self.root)
+        tmp = '    Transforms (if any): '
+        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        tmp = '    Target Transforms (if any): '
+        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
+        return fmt_str
+
+
+IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
+
+
+def pil_loader(path):
+    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    if isinstance(path, bytes):
+        img = Image.open(io.BytesIO(path))
+    elif is_zip_path(path):
+        data = ZipReader.read(path)
+        img = Image.open(io.BytesIO(data))
+    else:
+        with open(path, 'rb') as f:
+            img = Image.open(f)
+            return img.convert('RGB')
+    return img.convert('RGB')
+
+
+def accimage_loader(path):
+    import accimage
+    try:
+        return accimage.Image(path)
+    except IOError:
+        # Potentially a decoding problem, fall back to PIL.Image
+        return pil_loader(path)
+
+
+def default_img_loader(path):
+    from torchvision import get_image_backend
+    if get_image_backend() == 'accimage':
+        return accimage_loader(path)
+    else:
+        return pil_loader(path)
+
+
+class CachedImageFolder(DatasetFolder):
+    """A generic data loader where the images are arranged in this way: ::
+        root/dog/xxx.png
+        root/dog/xxy.png
+        root/dog/xxz.png
+        root/cat/123.png
+        root/cat/nsdf3.png
+        root/cat/asd932_.png
+    Args:
+        root (string): Root directory path.
+        transform (callable, optional): A function/transform that  takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+        loader (callable, optional): A function to load an image given its path.
+     Attributes:
+        imgs (list): List of (image path, class_index) tuples
+    """
+
+    def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None,
+                 loader=default_img_loader, cache_mode="no"):
+        super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
+                                                ann_file=ann_file, img_prefix=img_prefix,
+                                                transform=transform, target_transform=target_transform,
+                                                cache_mode=cache_mode)
+        self.imgs = self.samples
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        path, target = self.samples[index]
+        image = self.loader(path)
+        if self.transform is not None:
+            img = self.transform(image)
+        else:
+            img = image
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target

+ 112 - 0
lib/SwinTransformer/data/data_simmim_ft.py

@@ -0,0 +1,112 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Zhenda Xie
+# --------------------------------------------------------
+
+import os
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+from torchvision import datasets, transforms
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import Mixup
+from timm.data import create_transform
+from timm.data.transforms import _pil_interp
+
+
+def build_loader_finetune(config):
+    config.defrost()
+    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config)
+    config.freeze()
+    dataset_val, _ = build_dataset(is_train=False, config=config)
+
+    num_tasks = dist.get_world_size()
+    global_rank = dist.get_rank()
+    sampler_train = DistributedSampler(
+        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+    )
+    sampler_val = DistributedSampler(
+        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
+    )
+
+    data_loader_train = DataLoader(
+        dataset_train, sampler=sampler_train,
+        batch_size=config.DATA.BATCH_SIZE,
+        num_workers=config.DATA.NUM_WORKERS,
+        pin_memory=config.DATA.PIN_MEMORY,
+        drop_last=True,
+    )
+
+    data_loader_val = DataLoader(
+        dataset_val, sampler=sampler_val,
+        batch_size=config.DATA.BATCH_SIZE,
+        num_workers=config.DATA.NUM_WORKERS,
+        pin_memory=config.DATA.PIN_MEMORY,
+        drop_last=False,
+    )
+
+    # setup mixup / cutmix
+    mixup_fn = None
+    mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
+    if mixup_active:
+        mixup_fn = Mixup(
+            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
+            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
+            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
+
+    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
+
+
+def build_dataset(is_train, config):
+    transform = build_transform(is_train, config)
+    
+    if config.DATA.DATASET == 'imagenet':
+        prefix = 'train' if is_train else 'val'
+        root = os.path.join(config.DATA.DATA_PATH, prefix)
+        dataset = datasets.ImageFolder(root, transform=transform)
+        nb_classes = 1000
+    else:
+        raise NotImplementedError("We only support ImageNet Now.")
+
+    return dataset, nb_classes
+
+
+def build_transform(is_train, config):
+    resize_im = config.DATA.IMG_SIZE > 32
+    if is_train:
+        # this should always dispatch to transforms_imagenet_train
+        transform = create_transform(
+            input_size=config.DATA.IMG_SIZE,
+            is_training=True,
+            color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
+            auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
+            re_prob=config.AUG.REPROB,
+            re_mode=config.AUG.REMODE,
+            re_count=config.AUG.RECOUNT,
+            interpolation=config.DATA.INTERPOLATION,
+        )
+        if not resize_im:
+            # replace RandomResizedCropAndInterpolation with
+            # RandomCrop
+            transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
+        return transform
+
+    t = []
+    if resize_im:
+        if config.TEST.CROP:
+            size = int((256 / 224) * config.DATA.IMG_SIZE)
+            t.append(
+                transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
+                # to maintain same ratio w.r.t. 224 images
+            )
+            t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
+        else:
+            t.append(
+                transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
+                                  interpolation=_pil_interp(config.DATA.INTERPOLATION))
+            )
+
+    t.append(transforms.ToTensor())
+    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+    return transforms.Compose(t)

+ 99 - 0
lib/SwinTransformer/data/data_simmim_pt.py

@@ -0,0 +1,99 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Zhenda Xie
+# --------------------------------------------------------
+
+import math
+import random
+import numpy as np
+
+import torch
+import torch.distributed as dist
+import torchvision.transforms as T
+from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data._utils.collate import default_collate
+from torchvision.datasets import ImageFolder
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+class MaskGenerator:
+    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
+        self.input_size = input_size
+        self.mask_patch_size = mask_patch_size
+        self.model_patch_size = model_patch_size
+        self.mask_ratio = mask_ratio
+        
+        assert self.input_size % self.mask_patch_size == 0
+        assert self.mask_patch_size % self.model_patch_size == 0
+        
+        self.rand_size = self.input_size // self.mask_patch_size
+        self.scale = self.mask_patch_size // self.model_patch_size
+        
+        self.token_count = self.rand_size ** 2
+        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
+        
+    def __call__(self):
+        mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
+        mask = np.zeros(self.token_count, dtype=int)
+        mask[mask_idx] = 1
+        
+        mask = mask.reshape((self.rand_size, self.rand_size))
+        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
+        
+        return mask
+
+
+class SimMIMTransform:
+    def __init__(self, config):
+        self.transform_img = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
+            T.RandomHorizontalFlip(),
+            T.ToTensor(),
+            T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
+        ])
+ 
+        if config.MODEL.TYPE in ['swin', 'swinv2']:
+            model_patch_size=config.MODEL.SWIN.PATCH_SIZE
+        else:
+            raise NotImplementedError
+        
+        self.mask_generator = MaskGenerator(
+            input_size=config.DATA.IMG_SIZE,
+            mask_patch_size=config.DATA.MASK_PATCH_SIZE,
+            model_patch_size=model_patch_size,
+            mask_ratio=config.DATA.MASK_RATIO,
+        )
+    
+    def __call__(self, img):
+        img = self.transform_img(img)
+        mask = self.mask_generator()
+        
+        return img, mask
+
+
+def collate_fn(batch):
+    if not isinstance(batch[0][0], tuple):
+        return default_collate(batch)
+    else:
+        batch_num = len(batch)
+        ret = []
+        for item_idx in range(len(batch[0][0])):
+            if batch[0][0][item_idx] is None:
+                ret.append(None)
+            else:
+                ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
+        ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
+        return ret
+
+
+def build_loader_simmim(config):
+    transform = SimMIMTransform(config)
+    dataset = ImageFolder(config.DATA.DATA_PATH, transform)
+    
+    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
+    dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
+    
+    return dataloader

+ 55 - 0
lib/SwinTransformer/data/imagenet22k_dataset.py

@@ -0,0 +1,55 @@
+import os
+import json
+import torch.utils.data as data
+import numpy as np
+from PIL import Image
+
+import warnings
+
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+
+class IN22KDATASET(data.Dataset):
+    def __init__(self, root, ann_file='', transform=None, target_transform=None):
+        super(IN22KDATASET, self).__init__()
+
+        self.data_path = root
+        self.ann_path = os.path.join(self.data_path, ann_file)
+        self.transform = transform
+        self.target_transform = target_transform
+        # id & label: https://github.com/google-research/big_transfer/issues/7
+        # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
+        self.database = json.load(open(self.ann_path))
+
+    def _load_image(self, path):
+        try:
+            im = Image.open(path)
+        except:
+            print("ERROR IMG LOADED: ", path)
+            random_img = np.random.rand(224, 224, 3) * 255
+            im = Image.fromarray(np.uint8(random_img))
+        return im
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+        Returns:
+            tuple: (image, target) where target is class_index of the target class.
+        """
+        idb = self.database[index]
+
+        # images
+        images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
+        if self.transform is not None:
+            images = self.transform(images)
+
+        # target
+        target = int(idb[1])
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return images, target
+
+    def __len__(self):
+        return len(self.database)

+ 1000 - 0
lib/SwinTransformer/data/map22kto1k.txt

@@ -0,0 +1,1000 @@
+359
+368
+460
+475
+486
+492
+496
+514
+516
+525
+547
+548
+556
+563
+575
+641
+648
+723
+733
+765
+801
+826
+852
+858
+878
+896
+900
+905
+908
+910
+935
+946
+947
+994
+999
+1003
+1005
+1010
+1027
+1029
+1048
+1055
+1064
+1065
+1069
+1075
+1079
+1081
+1085
+1088
+1093
+1106
+1143
+1144
+1145
+1147
+1168
+1171
+1178
+1187
+1190
+1197
+1205
+1216
+1223
+1230
+1236
+1241
+1245
+1257
+1259
+1260
+1267
+1268
+1269
+1271
+1272
+1273
+1277
+1303
+1344
+1349
+1355
+1357
+1384
+1388
+1391
+1427
+1429
+1432
+1437
+1450
+1461
+1462
+1474
+1502
+1503
+1512
+1552
+1555
+1577
+1584
+1587
+1589
+1599
+1615
+1616
+1681
+1692
+1701
+1716
+1729
+1757
+1759
+1764
+1777
+1786
+1822
+1841
+1842
+1848
+1850
+1856
+1860
+1861
+1864
+1876
+1897
+1898
+1910
+1913
+1918
+1922
+1928
+1932
+1935
+1947
+1951
+1953
+1970
+1977
+1979
+2001
+2017
+2067
+2081
+2087
+2112
+2128
+2135
+2147
+2174
+2175
+2176
+2177
+2178
+2181
+2183
+2184
+2187
+2189
+2190
+2191
+2192
+2193
+2197
+2202
+2203
+2206
+2208
+2209
+2211
+2212
+2213
+2214
+2215
+2216
+2217
+2219
+2222
+2223
+2224
+2225
+2226
+2227
+2228
+2229
+2230
+2236
+2238
+2240
+2241
+2242
+2243
+2244
+2245
+2247
+2248
+2249
+2250
+2251
+2252
+2255
+2256
+2257
+2262
+2263
+2264
+2265
+2266
+2268
+2270
+2271
+2272
+2273
+2275
+2276
+2279
+2280
+2281
+2282
+2285
+2289
+2292
+2295
+2296
+2297
+2298
+2299
+2300
+2301
+2302
+2303
+2304
+2305
+2306
+2309
+2310
+2312
+2313
+2314
+2315
+2316
+2318
+2319
+2321
+2322
+2326
+2329
+2330
+2331
+2332
+2334
+2335
+2336
+2337
+2338
+2339
+2341
+2342
+2343
+2344
+2346
+2348
+2349
+2351
+2352
+2353
+2355
+2357
+2358
+2359
+2360
+2364
+2365
+2368
+2369
+2377
+2382
+2383
+2385
+2397
+2398
+2400
+2402
+2405
+2412
+2421
+2428
+2431
+2432
+2433
+2436
+2441
+2445
+2450
+2453
+2454
+2465
+2469
+2532
+2533
+2538
+2544
+2547
+2557
+2565
+2578
+2612
+2658
+2702
+2722
+2731
+2738
+2741
+2747
+2810
+2818
+2833
+2844
+2845
+2867
+2874
+2882
+2884
+2888
+2889
+3008
+3012
+3019
+3029
+3033
+3042
+3091
+3106
+3138
+3159
+3164
+3169
+3280
+3296
+3311
+3318
+3320
+3324
+3330
+3366
+3375
+3381
+3406
+3419
+3432
+3434
+3435
+3493
+3495
+3503
+3509
+3511
+3513
+3517
+3521
+3526
+3546
+3554
+3600
+3601
+3606
+3612
+3613
+3616
+3622
+3623
+3627
+3632
+3634
+3636
+3638
+3644
+3646
+3649
+3650
+3651
+3656
+3663
+3673
+3674
+3689
+3690
+3702
+3733
+3769
+3971
+3974
+4065
+4068
+4073
+4102
+4136
+4140
+4151
+4159
+4165
+4207
+4219
+4226
+4249
+4256
+4263
+4270
+4313
+4321
+4378
+4386
+4478
+4508
+4512
+4536
+4542
+4550
+4560
+4562
+4570
+4571
+4572
+4583
+4588
+4594
+4604
+4608
+4623
+4634
+4636
+4646
+4651
+4652
+4686
+4688
+4691
+4699
+4724
+4727
+4737
+4770
+4774
+4789
+4802
+4807
+4819
+4880
+4886
+4908
+4927
+4931
+4936
+4964
+4976
+4993
+5028
+5033
+5043
+5046
+5096
+5111
+5114
+5131
+5132
+5183
+5199
+5235
+5275
+5291
+5293
+5294
+5343
+5360
+5362
+5364
+5390
+5402
+5418
+5428
+5430
+5437
+5443
+5473
+5484
+5486
+5505
+5507
+5508
+5510
+5567
+5578
+5580
+5584
+5606
+5613
+5629
+5672
+5676
+5692
+5701
+5760
+5769
+5770
+5779
+5814
+5850
+5871
+5893
+5911
+5949
+5954
+6005
+6006
+6012
+6017
+6023
+6024
+6040
+6050
+6054
+6087
+6105
+6157
+6235
+6237
+6256
+6259
+6286
+6291
+6306
+6339
+6341
+6343
+6379
+6383
+6393
+6405
+6479
+6511
+6517
+6541
+6561
+6608
+6611
+6615
+6678
+6682
+6707
+6752
+6798
+6850
+6880
+6885
+6890
+6920
+6981
+7000
+7009
+7038
+7049
+7050
+7052
+7073
+7078
+7098
+7111
+7165
+7198
+7204
+7280
+7283
+7286
+7287
+7293
+7294
+7305
+7318
+7341
+7346
+7354
+7382
+7427
+7428
+7435
+7445
+7450
+7455
+7467
+7469
+7497
+7502
+7506
+7514
+7523
+7651
+7661
+7664
+7672
+7679
+7685
+7696
+7730
+7871
+7873
+7895
+7914
+7915
+7920
+7934
+7935
+7949
+8009
+8036
+8051
+8065
+8074
+8090
+8112
+8140
+8164
+8168
+8178
+8182
+8198
+8212
+8216
+8230
+8242
+8288
+8289
+8295
+8318
+8352
+8368
+8371
+8375
+8376
+8401
+8416
+8419
+8436
+8460
+8477
+8478
+8482
+8498
+8500
+8539
+8543
+8552
+8555
+8580
+8584
+8586
+8594
+8598
+8601
+8606
+8610
+8611
+8622
+8627
+8639
+8649
+8650
+8653
+8654
+8667
+8672
+8673
+8674
+8676
+8684
+8720
+8723
+8750
+8753
+8801
+8815
+8831
+8835
+8842
+8845
+8858
+8897
+8916
+8951
+8954
+8959
+8970
+8976
+8981
+8983
+8989
+8991
+8993
+9019
+9039
+9042
+9043
+9056
+9057
+9070
+9087
+9098
+9106
+9130
+9131
+9155
+9171
+9183
+9198
+9199
+9201
+9204
+9212
+9221
+9225
+9229
+9250
+9260
+9271
+9279
+9295
+9300
+9310
+9322
+9345
+9352
+9376
+9377
+9382
+9392
+9401
+9405
+9441
+9449
+9464
+9475
+9502
+9505
+9514
+9515
+9545
+9567
+9576
+9608
+9609
+9624
+9633
+9639
+9643
+9656
+9674
+9740
+9752
+9760
+9767
+9778
+9802
+9820
+9839
+9879
+9924
+9956
+9961
+9963
+9970
+9997
+10010
+10031
+10040
+10052
+10073
+10075
+10078
+10094
+10097
+10109
+10118
+10121
+10124
+10158
+10226
+10276
+10304
+10307
+10314
+10315
+10332
+10337
+10338
+10413
+10423
+10451
+10463
+10465
+10487
+10519
+10522
+10523
+10532
+10534
+10535
+10551
+10559
+10574
+10583
+10586
+10589
+10612
+10626
+10635
+10638
+10677
+10683
+10726
+10776
+10782
+10783
+10807
+10837
+10840
+10848
+10859
+10871
+10881
+10884
+10908
+10914
+10921
+10936
+10947
+10951
+10952
+10957
+10999
+11003
+11018
+11023
+11025
+11027
+11045
+11055
+11095
+11110
+11137
+5564
+11168
+11186
+11221
+11223
+11242
+11255
+11259
+11279
+11306
+11311
+11331
+11367
+11377
+11389
+11392
+11401
+11407
+11437
+11449
+11466
+11469
+11473
+11478
+11483
+11484
+11507
+11536
+11558
+11566
+11575
+11584
+11594
+11611
+11612
+11619
+11621
+11640
+11643
+11664
+11674
+11689
+11709
+11710
+11716
+11721
+11726
+11729
+11743
+11760
+11771
+11837
+11839
+11856
+11876
+11878
+11884
+11889
+11896
+11917
+11923
+11930
+11944
+11952
+11980
+11984
+12214
+12229
+12239
+12241
+12242
+12247
+12283
+12349
+12369
+12373
+12422
+12560
+12566
+12575
+12688
+12755
+12768
+12778
+12780
+12812
+12832
+12835
+12836
+12843
+12847
+12849
+12850
+12856
+12858
+12873
+12938
+12971
+13017
+13038
+13046
+13059
+13085
+13086
+13088
+13094
+13134
+13182
+13230
+13406
+13444
+13614
+13690
+13698
+13709
+13749
+13804
+13982
+14051
+14059
+14219
+14246
+14256
+14264
+14294
+14324
+14367
+14389
+14394
+14438
+14442
+14965
+15732
+16744
+18037
+18205
+18535
+18792
+19102
+20019
+20462
+21026
+21045
+21163
+21171
+21181
+21196
+21200
+21369
+21817

+ 29 - 0
lib/SwinTransformer/data/samplers.py

@@ -0,0 +1,29 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+
+
+class SubsetRandomSampler(torch.utils.data.Sampler):
+    r"""Samples elements randomly from a given list of indices, without replacement.
+
+    Arguments:
+        indices (sequence): a sequence of indices
+    """
+
+    def __init__(self, indices):
+        self.epoch = 0
+        self.indices = indices
+
+    def __iter__(self):
+        return (self.indices[i] for i in torch.randperm(len(self.indices)))
+
+    def __len__(self):
+        return len(self.indices)
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch

+ 103 - 0
lib/SwinTransformer/data/zipreader.py

@@ -0,0 +1,103 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import zipfile
+import io
+import numpy as np
+from PIL import Image
+from PIL import ImageFile
+
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+
+def is_zip_path(img_or_path):
+    """judge if this is a zip path"""
+    return '.zip@' in img_or_path
+
+
+class ZipReader(object):
+    """A class to read zipped files"""
+    zip_bank = dict()
+
+    def __init__(self):
+        super(ZipReader, self).__init__()
+
+    @staticmethod
+    def get_zipfile(path):
+        zip_bank = ZipReader.zip_bank
+        if path not in zip_bank:
+            zfile = zipfile.ZipFile(path, 'r')
+            zip_bank[path] = zfile
+        return zip_bank[path]
+
+    @staticmethod
+    def split_zip_style_path(path):
+        pos_at = path.index('@')
+        assert pos_at != -1, "character '@' is not found from the given path '%s'" % path
+
+        zip_path = path[0: pos_at]
+        folder_path = path[pos_at + 1:]
+        folder_path = str.strip(folder_path, '/')
+        return zip_path, folder_path
+
+    @staticmethod
+    def list_folder(path):
+        zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+        zfile = ZipReader.get_zipfile(zip_path)
+        folder_list = []
+        for file_foler_name in zfile.namelist():
+            file_foler_name = str.strip(file_foler_name, '/')
+            if file_foler_name.startswith(folder_path) and \
+                    len(os.path.splitext(file_foler_name)[-1]) == 0 and \
+                    file_foler_name != folder_path:
+                if len(folder_path) == 0:
+                    folder_list.append(file_foler_name)
+                else:
+                    folder_list.append(file_foler_name[len(folder_path) + 1:])
+
+        return folder_list
+
+    @staticmethod
+    def list_files(path, extension=None):
+        if extension is None:
+            extension = ['.*']
+        zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+        zfile = ZipReader.get_zipfile(zip_path)
+        file_lists = []
+        for file_foler_name in zfile.namelist():
+            file_foler_name = str.strip(file_foler_name, '/')
+            if file_foler_name.startswith(folder_path) and \
+                    str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
+                if len(folder_path) == 0:
+                    file_lists.append(file_foler_name)
+                else:
+                    file_lists.append(file_foler_name[len(folder_path) + 1:])
+
+        return file_lists
+
+    @staticmethod
+    def read(path):
+        zip_path, path_img = ZipReader.split_zip_style_path(path)
+        zfile = ZipReader.get_zipfile(zip_path)
+        data = zfile.read(path_img)
+        return data
+
+    @staticmethod
+    def imread(path):
+        zip_path, path_img = ZipReader.split_zip_style_path(path)
+        zfile = ZipReader.get_zipfile(zip_path)
+        data = zfile.read(path_img)
+        try:
+            im = Image.open(io.BytesIO(data))
+        except:
+            print("ERROR IMG LOADED: ", path_img)
+            random_img = np.random.rand(224, 224, 3) * 255
+            im = Image.fromarray(np.uint8(random_img))
+        return im

BIN
lib/SwinTransformer/figures/teaser.png


+ 310 - 0
lib/SwinTransformer/get_started.md

@@ -0,0 +1,310 @@
+# Swin Transformer for Image Classification
+
+This folder contains the implementation of the Swin Transformer for image classification.
+
+## Model Zoo
+
+Please refer to [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) for more pre-trained models.
+
+## Usage
+
+### Install
+
+We recommend using the pytorch docker `nvcr>=21.05` by
+nvidia: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
+
+- Clone this repo:
+
+```bash
+git clone https://github.com/microsoft/Swin-Transformer.git
+cd SwinTransformer
+```
+
+- Create a conda virtual environment and activate it:
+
+```bash
+conda create -n swin python=3.7 -y
+conda activate swin
+```
+
+- Install `CUDA>=10.2` with `cudnn>=7` following
+  the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html)
+- Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`:
+
+```bash
+conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch
+```
+
+- Install `timm==0.4.12`:
+
+```bash
+pip install timm==0.4.12
+```
+
+- Install other requirements:
+
+```bash
+pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy
+```
+
+- Install fused window process for acceleration, activated by passing `--fused_window_process` in the running script
+```bash
+cd kernels/window_process
+python setup.py install #--user
+```
+
+### Data preparation
+
+We use standard ImageNet dataset, you can download it from http://image-net.org/. We provide the following two ways to
+load data:
+
+- For standard folder dataset, move validation images to labeled sub-folders. The file structure should look like:
+  ```bash
+  $ tree data
+  imagenet
+  ├── train
+  │   ├── class1
+  │   │   ├── img1.jpeg
+  │   │   ├── img2.jpeg
+  │   │   └── ...
+  │   ├── class2
+  │   │   ├── img3.jpeg
+  │   │   └── ...
+  │   └── ...
+  └── val
+      ├── class1
+      │   ├── img4.jpeg
+      │   ├── img5.jpeg
+      │   └── ...
+      ├── class2
+      │   ├── img6.jpeg
+      │   └── ...
+      └── ...
+ 
+  ```
+- To boost the slow speed when reading images from massive small files, we also support zipped ImageNet, which includes
+  four files:
+    - `train.zip`, `val.zip`: which store the zipped folder for train and validate splits.
+    - `train_map.txt`, `val_map.txt`: which store the relative path in the corresponding zip file and ground truth
+      label. Make sure the data folder looks like this:
+
+  ```bash
+  $ tree data
+  data
+  └── ImageNet-Zip
+      ├── train_map.txt
+      ├── train.zip
+      ├── val_map.txt
+      └── val.zip
+  
+  $ head -n 5 data/ImageNet-Zip/val_map.txt
+  ILSVRC2012_val_00000001.JPEG	65
+  ILSVRC2012_val_00000002.JPEG	970
+  ILSVRC2012_val_00000003.JPEG	230
+  ILSVRC2012_val_00000004.JPEG	809
+  ILSVRC2012_val_00000005.JPEG	516
+  
+  $ head -n 5 data/ImageNet-Zip/train_map.txt
+  n01440764/n01440764_10026.JPEG	0
+  n01440764/n01440764_10027.JPEG	0
+  n01440764/n01440764_10029.JPEG	0
+  n01440764/n01440764_10040.JPEG	0
+  n01440764/n01440764_10042.JPEG	0
+  ```
+- For ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this
+  folder. Then download the train-val split
+  file ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt)
+  & [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt))
+  , and put them in the parent directory of `fall11_whole`. The file structure should look like:
+
+  ```bash
+    $ tree imagenet22k/
+    imagenet22k/
+    ├── ILSVRC2011fall_whole_map_train.txt
+    ├── ILSVRC2011fall_whole_map_val.txt
+    └── fall11_whole
+        ├── n00004475
+        ├── n00005787
+        ├── n00006024
+        ├── n00006484
+        └── ...
+  ```
+
+### Evaluation
+
+To evaluate a pre-trained `Swin Transformer` on ImageNet val, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
+--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 
+```
+
+For example, to evaluate the `Swin-B` with a single GPU:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
+--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
+```
+
+### Training from scratch on ImageNet-1K
+
+To train a `Swin Transformer` on ImageNet from scratch, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345  main.py \ 
+--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
+```
+
+**Notes**:
+
+- To use zipped ImageNet instead of folder dataset, add `--zip` to the parameters.
+    - To cache the dataset in the memory instead of reading from files every time, add `--cache-mode part`, which will
+      shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU.
+- When GPU memory is not enough, you can try the following suggestions:
+    - Use gradient accumulation by adding `--accumulation-steps <steps>`, set appropriate `<steps>` according to your need.
+    - Use gradient checkpointing by adding `--use-checkpoint`, e.g., it saves about 60% memory when training `Swin-B`.
+      Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details.
+    - We recommend using multi-node with more GPUs for training very large models, a tutorial can be found
+      in [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html).
+- To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g.,
+  `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5.
+- For additional options, see [config](config.py) and run `python main.py --help` to get detailed message.
+
+For example, to train `Swin Transformer` with 8 GPU on a single node for 300 epochs, run:
+
+`Swin-T`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
+```
+
+`Swin-S`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
+```
+
+`Swin-B`:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
+--accumulation-steps 2 [--use-checkpoint]
+```
+
+### Pre-training on ImageNet-22K
+
+For example, to pre-train a `Swin-B` model on ImageNet-22K:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \
+--accumulation-steps 8 [--use-checkpoint]
+```
+
+### Fine-tuning on higher resolution
+
+For example, to fine-tune a `Swin-B` model pre-trained on 224x224 resolution to 384x384 resolution:
+
+```bashs
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \
+--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]
+```
+
+### Fine-tuning from a ImageNet-22K(21K) pre-trained model
+
+For example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K(21K):
+
+```bashs
+python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
+--cfg configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \
+--data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint]
+```
+
+### Throughput
+
+To measure the throughput, run:
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345  main.py \
+--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --disable_amp
+```
+
+
+## Mixture-of-Experts Support
+
+### Install [Tutel](https://github.com/microsoft/tutel)
+```bash
+python3 -m pip uninstall tutel -y 
+python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main
+```
+
+### Training Swin-MoE 
+For example, to train a `Swin-MoE-S` model with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \
+--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345  main_moe.py \
+--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128
+```
+
+### Evaluating Swin-MoE
+
+To evaluate a `Swin-MoE-S` with 32 experts on ImageNet-22K with 32 GPUs (4 nodes):
+
+1. Download the zip file [swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip) which contains the pre-trained models for each rank, and unzip them to the folder "swin_moe_small_patch4_window12_192_32expert_32gpu_22k".
+2. Run the following evaluation command, note the checkpoint path should not contain the ".rank\<x\>" suffix.
+
+```bash
+python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \
+--node_rank=<node-rank> --master_addr=<master-ip> --master_port 12345  main_moe.py \
+--cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 \
+--resume swin_moe_small_patch4_window12_192_32expert_32gpu_22k/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.pth 
+```
+
+More Swin-MoE models can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models)
+
+## SimMIM Support
+
+### Evaluating provided models
+
+To evaluate a provided model on ImageNet validation set, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \
+--eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
+```
+
+For example, to evaluate the `Swin Base` model on a single GPU, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py \
+--eval --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path>
+```
+
+### Pre-training with SimMIM
+To pre-train models with `SimMIM`, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_pt.py \ 
+--cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
+```
+
+For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node 16 main_simmim_pt.py \ 
+--cfg configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]
+```
+
+### Fine-tuning pre-trained models
+To fine-tune models pre-trained by `SimMIM`, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim_ft.py \ 
+--cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
+```
+
+For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run:
+```bash
+python -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \ 
+--cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]
+```

+ 12 - 0
lib/SwinTransformer/kernels/window_process/setup.py

@@ -0,0 +1,12 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+setup(name='swin_window_process',
+    ext_modules=[
+        CUDAExtension('swin_window_process', [
+            'swin_window_process.cpp',
+            'swin_window_process_kernel.cu',
+        ])
+    ],
+    cmdclass={'build_ext': BuildExtension})

+ 132 - 0
lib/SwinTransformer/kernels/window_process/swin_window_process.cpp

@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/torch.h>
+#include <torch/extension.h>
+
+
+at::Tensor roll_and_window_partition_forward_cuda(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size);
+
+
+at::Tensor roll_and_window_partition_backward_cuda(
+    at::Tensor & grad_in, 
+    //at::Tensor & grad_out,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size);
+
+
+at::Tensor window_merge_and_roll_forward_cuda(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size);
+
+at::Tensor window_merge_and_roll_backward_cuda(
+    at::Tensor & grad_in, 
+    //at::Tensor & grad_out,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size);
+
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+
+
+at::Tensor roll_and_window_partition_forward(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    CHECK_INPUT(input);
+    return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size);
+}
+
+
+at::Tensor roll_and_window_partition_backward(
+    at::Tensor & grad_in, 
+    //at::Tensor & grad_out,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    CHECK_INPUT(grad_in);
+    return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);
+}
+
+
+at::Tensor window_merge_and_roll_forward(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    CHECK_INPUT(input);
+    return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size);
+}
+
+
+at::Tensor window_merge_and_roll_backward(
+    at::Tensor & grad_in, 
+    //at::Tensor & grad_out,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    CHECK_INPUT(grad_in);
+    return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size);
+}
+
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition.");
+    m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition.");
+    m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll.");
+    m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll.");
+}

+ 323 - 0
lib/SwinTransformer/kernels/window_process/swin_window_process_kernel.cu

@@ -0,0 +1,323 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <ATen/ATen.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <torch/extension.h>
+#include <stdio.h>
+
+int best_block_dim(int feat_dim){
+    int best_dim;
+    if (feat_dim < 384){
+        best_dim = 64;
+    }
+    else{
+        if (feat_dim < 1024){
+            best_dim = 128;
+        }
+        else{
+            best_dim = 256;
+        }
+    }
+    return best_dim;
+}
+
+
+template <typename T>
+__global__ void roll_and_window_partition_forward_cuda_kernel(
+    T* input, 
+    T* output, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size,
+    const int nH,
+    const int nW){
+    // start
+    //bool qual = threadIdx.x < C;
+    int index = threadIdx.x;
+    int offset;
+    for (int i = index; i < C; i += blockDim.x) {
+        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
+        int input_offset = blockIdx.z / (nH * nW) * H * W * C +
+            (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C + 
+            (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C +
+            i;
+        output[offset] = (T)(__ldg(input + input_offset));
+    }
+}
+
+
+template <typename T>
+__global__ void roll_and_window_partition_backward_cuda_kernel(
+    T* grad_in, 
+    T* grad_out, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size,
+    const int nH,
+    const int nW){
+    // start
+    int index = threadIdx.x;
+    int offset;
+    for (int i = index; i < C; i += blockDim.x) {
+        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
+        int input_offset = 
+        (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C +
+        (blockIdx.y + shift_size + H ) % H % window_size * window_size * C +
+        (blockIdx.x + shift_size + W ) % W % window_size * C +
+        i;
+        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
+    }
+}
+
+
+template <typename T>
+__global__ void window_merge_and_roll_forward_cuda_kernel(
+    T* input, 
+    T* output, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size,
+    const int nH,
+    const int nW){
+    // start
+    int index = threadIdx.x;
+    int offset;
+    for (int i = index; i < C; i += blockDim.x) {
+        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
+        int input_offset = 
+            (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nH + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C +
+            (blockIdx.y - shift_size + H) % window_size * window_size * C + 
+            (blockIdx.x - shift_size + W) % window_size * C +
+            i;
+        output[offset] = (T)(__ldg(input + input_offset));
+    }
+}
+
+
+
+template <typename T>
+__global__ void window_merge_and_roll_backward_cuda_kernel(
+    T* grad_in, 
+    T* grad_out, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size,
+    const int nH,
+    const int nW){
+    // start
+    int index = threadIdx.x;
+    int offset;
+    for (int i = index; i < C; i += blockDim.x) {
+        offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize
+        int input_offset = 
+        (blockIdx.z / (nH * nW)) * H * W * C +
+        (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +
+        (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +
+        i;
+        grad_out[offset] = (T)(__ldg(grad_in + input_offset));
+    }
+}
+
+// input: [B, H, W, C]
+// output: [B*nH*nW, window_size, window_size, C]
+at::Tensor roll_and_window_partition_forward_cuda(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    
+    int nH = H / window_size;
+    int nW = W / window_size;
+
+    dim3 grid(window_size, window_size, B * nH * nW);
+    //dim3 block((C + 31) / 32 * 32);
+    int blocknum = best_block_dim(C);
+    dim3 block(blocknum);
+
+    at::Tensor output;
+    if (input.scalar_type() == torch::kFloat16){
+        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
+    }
+    else{
+        output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
+    }
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
+        roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
+            input.data<scalar_t>(),
+            output.data<scalar_t>(),
+            B,
+            H,
+            W,
+            C,
+            shift_size,
+            window_size,
+            nH,
+            nW);
+    }));
+    return output;
+}
+
+
+// grad_in: [B*nH*nW, window_size, window_size, C]
+// grad_out: [B, H, W, C]
+at::Tensor roll_and_window_partition_backward_cuda(
+    at::Tensor & grad_in, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    
+    int nH = H / window_size;
+    int nW = W / window_size;
+
+    dim3 grid(W, H, B);
+    //dim3 block((C + 31) / 32 * 32);
+    int blocknum = best_block_dim(C);
+    dim3 block(blocknum);
+
+    at::Tensor grad_out;
+    if (grad_in.scalar_type() == torch::kFloat16){
+        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
+    }
+    else{
+        grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
+    }
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
+        roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
+            grad_in.data<scalar_t>(),
+            grad_out.data<scalar_t>(),
+            B,
+            H,
+            W,
+            C,
+            shift_size,
+            window_size,
+            nH,
+            nW);
+    }));
+    return grad_out;
+}
+
+
+// input: [B*nH*nW, window_size, window_size, C]
+// output: [B, H, W, C]
+at::Tensor window_merge_and_roll_forward_cuda(
+    at::Tensor & input, 
+    //at::Tensor & output,
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    
+    int nH = H / window_size;
+    int nW = W / window_size;
+
+    dim3 grid(W, H, B);
+    //dim3 block((C + 31) / 32 * 32);
+    int blocknum = best_block_dim(C);
+    dim3 block(blocknum);
+
+    //generate output tensor inside
+    at::Tensor output;
+    if (input.scalar_type() == torch::kFloat16){
+        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
+    }
+    else{
+        output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
+    }
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
+        window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
+            input.data<scalar_t>(),
+            output.data<scalar_t>(),
+            B,
+            H,
+            W,
+            C,
+            shift_size,
+            window_size,
+            nH,
+            nW);
+    }));
+    return output;
+}
+
+
+at::Tensor window_merge_and_roll_backward_cuda(
+    at::Tensor & grad_in, 
+    const int B,
+    const int H,
+    const int W,
+    const int C,
+    const int shift_size,
+    const int window_size){
+    
+    int nH = H / window_size;
+    int nW = W / window_size;
+
+    dim3 grid(window_size, window_size, B * nH * nW);
+    //dim3 block((C + 31) / 32 * 32);
+    int blocknum = best_block_dim(C);
+    dim3 block(blocknum);
+
+    at::Tensor grad_out;
+    if (grad_in.scalar_type() == torch::kFloat16){
+        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
+    }
+    else{
+        grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
+    }
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
+        window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block, 0>>>(
+            grad_in.data<scalar_t>(),
+            grad_out.data<scalar_t>(),
+            B,
+            H,
+            W,
+            C,
+            shift_size,
+            window_size,
+            nH,
+            nW);
+    }));
+    return grad_out;
+}

+ 250 - 0
lib/SwinTransformer/kernels/window_process/unit_test.py

@@ -0,0 +1,250 @@
+# --------------------------------------------------------
+# Fused kernel for window process for SwinTransformer
+# Copyright (c) 2022 Nvidia
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+import swin_window_process
+import random
+import time
+import unittest
+
+
+class WindowProcess(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, B, H, W, C, shift_size, window_size):
+        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
+
+        ctx.B = B
+        ctx.H = H
+        ctx.W = W 
+        ctx.C = C 
+        ctx.shift_size = shift_size
+        ctx.window_size = window_size
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        B = ctx.B
+        H = ctx.H
+        W = ctx.W 
+        C = ctx.C 
+        shift_size = ctx.shift_size
+        window_size = ctx.window_size
+
+        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
+        return grad_out, None, None, None, None, None, None, None
+
+
+class WindowProcessReverse(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, B, H, W, C, shift_size, window_size):
+        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
+
+        ctx.B = B
+        ctx.H = H
+        ctx.W = W 
+        ctx.C = C 
+        ctx.shift_size = shift_size
+        ctx.window_size = window_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        B = ctx.B
+        H = ctx.H
+        W = ctx.W 
+        C = ctx.C 
+        shift_size = ctx.shift_size
+        window_size = ctx.window_size
+
+        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
+        return grad_out, None, None, None, None, None, None, None
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+def pyt_forward(x, shift_size, window_size):
+    # x in shape(B, H, W, C)
+    # cyclic shift
+    if shift_size > 0:
+        shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
+    else:
+        shifted_x = x
+    # partition windows
+    x_windows = window_partition(shifted_x, window_size)
+    return x_windows
+
+
+def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W):
+    # x in shape(B*nH*nW, window_size, window_size, C)
+    shifted_x = window_reverse(attn_windows, window_size, H, W)
+    if shift_size > 0:
+        x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2))
+    else:
+        x = shifted_x
+    return x
+
+
+def copy_one_tensor(input, requires_grad=True):
+    input1 = input.clone().detach().requires_grad_(requires_grad).cuda()
+    return input1
+
+class Test_WindowProcess(unittest.TestCase):
+    def setUp(self):
+        self.B = 192
+        self.H = 56
+        self.W = 56
+        self.C = 96
+        self.shift_size = 2
+        self.window_size = 7
+        self.nH = self.H // self.window_size
+        self.nW = self.W // self.window_size
+    
+    def test_roll_and_window_partition_forward(self, dtype=torch.float32):
+        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
+        
+        input1 = copy_one_tensor(input, True)
+        input2 = copy_one_tensor(input, True)
+
+        with torch.no_grad():
+            # ori
+            expected = pyt_forward(input1, self.shift_size, self.window_size)
+            # fused kernel
+            fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
+        
+        self.assertTrue(torch.equal(expected, fused_output))
+        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+    
+    def test_roll_and_window_partition_backward(self, dtype=torch.float32):
+        input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
+        d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda()
+        
+        input1 = copy_one_tensor(input, True)
+        input2 = copy_one_tensor(input, True)
+
+        # ori
+        expected = pyt_forward(input1, self.shift_size, self.window_size)
+        expected.backward(d_loss_tensor)
+        # fused kernel
+        fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size)
+        fused_output.backward(d_loss_tensor)
+        
+        self.assertTrue(torch.equal(expected, fused_output))
+        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+
+    def test_window_merge_and_roll_forward(self, dtype=torch.float32):
+        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
+        
+        input1 = copy_one_tensor(input, True)
+        input2 = copy_one_tensor(input, True)
+
+        with torch.no_grad():
+            # ori
+            expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+            # fused kernel
+            fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
+        
+        self.assertTrue(torch.equal(expected, fused_output))
+        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+    
+
+    def test_window_merge_and_roll_backward(self, dtype=torch.float32):
+        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
+        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
+        
+        input1 = copy_one_tensor(input, True)
+        input2 = copy_one_tensor(input, True)
+
+        # ori
+        expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+        expected.backward(d_loss_tensor)
+        # fused kernel
+        fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
+        fused_output.backward(d_loss_tensor)
+        
+        self.assertTrue(torch.equal(expected, fused_output))
+        #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
+
+    def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
+        input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
+        d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
+        
+        input1 = copy_one_tensor(input, True)
+        input2 = copy_one_tensor(input, True)
+
+        # SwinTransformer official
+        def run_pyt(t=1000):
+            for _ in range(t):
+                expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
+                expected.backward(d_loss_tensor)
+
+        # my op
+        def run_fusedop(t=1000):
+            for _ in range(t):
+                fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
+                fused_output.backward(d_loss_tensor)
+        
+        torch.cuda.synchronize()
+        t1 = time.time()
+        run_pyt(t=times)
+        torch.cuda.synchronize()
+        t2 = time.time()
+        run_fusedop(t=times)
+        torch.cuda.synchronize()
+        t3 = time.time()
+        self.assertTrue((t3 - t2) < (t2 - t1))
+
+        print('Run {} times'.format(times))
+        print('Original time cost: {}'.format(t2 - t1))
+        print('Fused op time cost: {}'.format(t3 - t2))
+    
+    def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
+        self.test_roll_and_window_partition_forward(dtype=dtype)
+
+    def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
+        self.test_roll_and_window_partition_backward(dtype=dtype)
+
+    def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
+        self.test_window_merge_and_roll_forward(dtype=dtype)
+    
+    def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
+        self.test_window_merge_and_roll_backward(dtype=dtype)
+
+    def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
+        self.test_forward_backward_speed(dtype=dtype, times=times)
+
+
+if __name__ == '__main__':
+    print('Pass only two tensors are exactly the same (using torch.equal).\n')
+    torch.manual_seed(0)
+    unittest.main(verbosity=2)

+ 63 - 0
lib/SwinTransformer/kernels/window_process/window_process.py

@@ -0,0 +1,63 @@
+# --------------------------------------------------------
+# Fused kernel for window process for SwinTransformer
+# Copyright (c) 2022 Nvidia
+# Licensed under The MIT License [see LICENSE for details]
+# --------------------------------------------------------
+
+import torch
+import swin_window_process
+
+
+class WindowProcess(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, B, H, W, C, shift_size, window_size):
+        output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
+
+        ctx.B = B
+        ctx.H = H
+        ctx.W = W 
+        ctx.C = C 
+        ctx.shift_size = shift_size
+        ctx.window_size = window_size
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        B = ctx.B
+        H = ctx.H
+        W = ctx.W 
+        C = ctx.C 
+        shift_size = ctx.shift_size
+        window_size = ctx.window_size
+
+        grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size)
+        return grad_out, None, None, None, None, None, None, None
+
+
+class WindowProcessReverse(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, B, H, W, C, shift_size, window_size):
+        output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size)
+
+        ctx.B = B
+        ctx.H = H
+        ctx.W = W 
+        ctx.C = C 
+        ctx.shift_size = shift_size
+        ctx.window_size = window_size
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_in):
+        B = ctx.B
+        H = ctx.H
+        W = ctx.W 
+        C = ctx.C 
+        shift_size = ctx.shift_size
+        window_size = ctx.window_size
+
+        #grad_out = ctx.saved_tensors[0]
+        #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda()
+        grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size)
+        return grad_out, None, None, None, None, None, None, None

+ 41 - 0
lib/SwinTransformer/logger.py

@@ -0,0 +1,41 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import sys
+import logging
+import functools
+from termcolor import colored
+
+
+@functools.lru_cache()
+def create_logger(output_dir, dist_rank=0, name=''):
+    # create logger
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.propagate = False
+
+    # create formatter
+    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
+    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
+                colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
+
+    # create console handlers for master process
+    if dist_rank == 0:
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setLevel(logging.DEBUG)
+        console_handler.setFormatter(
+            logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+        logger.addHandler(console_handler)
+
+    # create file handlers
+    file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
+    file_handler.setLevel(logging.DEBUG)
+    file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
+    logger.addHandler(file_handler)
+
+    return logger

+ 152 - 0
lib/SwinTransformer/lr_scheduler.py

@@ -0,0 +1,152 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import bisect
+
+import torch
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from timm.scheduler.step_lr import StepLRScheduler
+from timm.scheduler.scheduler import Scheduler
+
+
+def build_scheduler(config, optimizer, n_iter_per_epoch):
+    num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
+    warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
+    decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
+    multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]
+
+    lr_scheduler = None
+    if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
+        lr_scheduler = CosineLRScheduler(
+            optimizer,
+            t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,
+            t_mul=1.,
+            lr_min=config.TRAIN.MIN_LR,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            cycle_limit=1,
+            t_in_epochs=False,
+            warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX,
+        )
+    elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
+        lr_scheduler = LinearLRScheduler(
+            optimizer,
+            t_initial=num_steps,
+            lr_min_rate=0.01,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            t_in_epochs=False,
+        )
+    elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
+        lr_scheduler = StepLRScheduler(
+            optimizer,
+            decay_t=decay_steps,
+            decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            t_in_epochs=False,
+        )
+    elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
+        lr_scheduler = MultiStepLRScheduler(
+            optimizer,
+            milestones=multi_steps,
+            gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
+            warmup_lr_init=config.TRAIN.WARMUP_LR,
+            warmup_t=warmup_steps,
+            t_in_epochs=False,
+        )
+
+    return lr_scheduler
+
+
+class LinearLRScheduler(Scheduler):
+    def __init__(self,
+                 optimizer: torch.optim.Optimizer,
+                 t_initial: int,
+                 lr_min_rate: float,
+                 warmup_t=0,
+                 warmup_lr_init=0.,
+                 t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
+                 initialize=True,
+                 ) -> None:
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
+
+        self.t_initial = t_initial
+        self.lr_min_rate = lr_min_rate
+        self.warmup_t = warmup_t
+        self.warmup_lr_init = warmup_lr_init
+        self.t_in_epochs = t_in_epochs
+        if self.warmup_t:
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+            super().update_groups(self.warmup_lr_init)
+        else:
+            self.warmup_steps = [1 for _ in self.base_values]
+
+    def _get_lr(self, t):
+        if t < self.warmup_t:
+            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+        else:
+            t = t - self.warmup_t
+            total_t = self.t_initial - self.warmup_t
+            lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
+        return lrs
+
+    def get_epoch_values(self, epoch: int):
+        if self.t_in_epochs:
+            return self._get_lr(epoch)
+        else:
+            return None
+
+    def get_update_values(self, num_updates: int):
+        if not self.t_in_epochs:
+            return self._get_lr(num_updates)
+        else:
+            return None
+
+
+class MultiStepLRScheduler(Scheduler):
+    def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
+        super().__init__(optimizer, param_group_field="lr")
+        
+        self.milestones = milestones
+        self.gamma = gamma
+        self.warmup_t = warmup_t
+        self.warmup_lr_init = warmup_lr_init
+        self.t_in_epochs = t_in_epochs
+        if self.warmup_t:
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+            super().update_groups(self.warmup_lr_init)
+        else:
+            self.warmup_steps = [1 for _ in self.base_values]
+        
+        assert self.warmup_t <= min(self.milestones)
+    
+    def _get_lr(self, t):
+        if t < self.warmup_t:
+            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+        else:
+            lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values]
+        return lrs
+
+    def get_epoch_values(self, epoch: int):
+        if self.t_in_epochs:
+            return self._get_lr(epoch)
+        else:
+            return None
+
+    def get_update_values(self, num_updates: int):
+        if not self.t_in_epochs:
+            return self._get_lr(num_updates)
+        else:
+            return None

+ 354 - 0
lib/SwinTransformer/main.py

@@ -0,0 +1,354 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import os
+import time
+import json
+import random
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
+    reduce_tensor
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
+    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+    parser.add_argument('--data-path', type=str, help='path to dataset')
+    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+                        help='no: no cache, '
+                             'full: cache all data, '
+                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
+    parser.add_argument('--pretrained',
+                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+    parser.add_argument('--use-checkpoint', action='store_true',
+                        help="whether to use gradient checkpointing to save memory")
+    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
+    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
+                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
+    parser.add_argument('--output', default='output', type=str, metavar='PATH',
+                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+
+    # distributed training
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+
+    # for acceleration
+    parser.add_argument('--fused_window_process', action='store_true',
+                        help='Fused window shift & window partition, similar for reversed part.')
+    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
+    ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
+    parser.add_argument('--optim', type=str,
+                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')
+
+    args, unparsed = parser.parse_known_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def main(config):
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
+
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+    model = build_model(config)
+    logger.info(str(model))
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params: {n_parameters}")
+    if hasattr(model, 'flops'):
+        flops = model.flops()
+        logger.info(f"number of GFLOPs: {flops / 1e9}")
+
+    model.cuda()
+    model_without_ddp = model
+
+    optimizer = build_optimizer(config, model)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    loss_scaler = NativeScalerWithGradNormCount()
+
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
+    else:
+        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+
+    if config.AUG.MIXUP > 0.:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif config.MODEL.LABEL_SMOOTHING > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    max_accuracy = 0.0
+
+    if config.TRAIN.AUTO_RESUME:
+        resume_file = auto_resume_helper(config.OUTPUT)
+        if resume_file:
+            if config.MODEL.RESUME:
+                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+            config.defrost()
+            config.MODEL.RESUME = resume_file
+            config.freeze()
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+
+    if config.MODEL.RESUME:
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.EVAL_MODE:
+            return
+
+    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
+        load_pretrained(config, model_without_ddp, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+
+    if config.THROUGHPUT_MODE:
+        throughput(data_loader_val, model, logger)
+        return
+
+    logger.info("Start training")
+    start_time = time.time()
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        data_loader_train.sampler.set_epoch(epoch)
+
+        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
+                        loss_scaler)
+        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
+            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
+                            logger)
+
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        max_accuracy = max(max_accuracy, acc1)
+        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+
+
+def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
+    model.train()
+    optimizer.zero_grad()
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    scaler_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, (samples, targets) in enumerate(data_loader):
+        samples = samples.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+            outputs = model(samples)
+        loss = criterion(outputs, targets)
+        loss = loss / config.TRAIN.ACCUMULATION_STEPS
+
+        # this attribute is added by timm on one optimizer (adahessian)
+        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
+                                parameters=model.parameters(), create_graph=is_second_order,
+                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
+        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+            optimizer.zero_grad()
+            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
+        loss_scale_value = loss_scaler.state_dict()["scale"]
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), targets.size(0))
+        if grad_norm is not None:  # loss_scaler return None if not update
+            norm_meter.update(grad_norm)
+        scaler_meter.update(loss_scale_value)
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[0]['lr']
+            wd = optimizer.param_groups[0]['weight_decay']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+    epoch_time = time.time() - start
+    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+
+
+@torch.no_grad()
+def validate(config, data_loader, model):
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    end = time.time()
+    for idx, (images, target) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True)
+
+        # compute output
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+            output = model(images)
+
+        # measure accuracy and record tools
+        loss = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(
+                f'Test: [{idx}/{len(data_loader)}]\t'
+                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                f'Mem {memory_used:.0f}MB')
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+@torch.no_grad()
+def throughput(data_loader, model, logger):
+    model.eval()
+
+    for idx, (images, _) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        batch_size = images.shape[0]
+        for i in range(50):
+            model(images)
+        torch.cuda.synchronize()
+        logger.info(f"throughput averaged with 30 times")
+        tic1 = time.time()
+        for i in range(30):
+            model(images)
+        torch.cuda.synchronize()
+        tic2 = time.time()
+        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
+        return
+
+
+if __name__ == '__main__':
+    args, config = parse_option()
+
+    if config.AMP_OPT_LEVEL:
+        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    # gradient accumulation also need to scale the learning rate
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+
+    # print config
+    logger.info(config.dump())
+    logger.info(json.dumps(vars(args)))
+
+    main(config)

+ 373 - 0
lib/SwinTransformer/main_moe.py

@@ -0,0 +1,373 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+from tutel import system
+
+import os
+import time
+import json
+import random
+import argparse
+import datetime
+import numpy as np
+from functools import partial
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils import NativeScalerWithGradNormCount, reduce_tensor
+from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad
+
+assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
+    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+    parser.add_argument('--data-path', type=str, help='path to dataset')
+    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
+    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
+                        help='no: no cache, '
+                             'full: cache all data, '
+                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
+    parser.add_argument('--pretrained',
+                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+    parser.add_argument('--use-checkpoint', action='store_true',
+                        help="whether to use gradient checkpointing to save memory")
+    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
+    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
+                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
+    parser.add_argument('--output', default='output', type=str, metavar='PATH',
+                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+
+    # distributed training
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args, unparsed = parser.parse_known_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def main(config):
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
+
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+    model = build_model(config)
+    logger.info(str(model))
+
+    # For Tutel MoE
+    for name, param in model.named_parameters():
+        if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True:
+            model.add_param_to_skip_allreduce(name)
+            param.register_hook(partial(hook_scale_grad, dist.get_world_size()))
+            logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad")
+
+    n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce')
+                              else p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params single: {n_parameters_single}")
+    n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce')
+                             else p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params whole: {n_parameters_whole}")
+    if hasattr(model, 'flops'):
+        flops = model.flops()
+        logger.info(f"number of GFLOPs: {flops / 1e9}")
+
+    model.cuda(config.LOCAL_RANK)
+    model_without_ddp = model
+
+    optimizer = build_optimizer(config, model)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    loss_scaler = NativeScalerWithGradNormCount()
+
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
+    else:
+        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+
+    if config.AUG.MIXUP > 0.:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif config.MODEL.LABEL_SMOOTHING > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    max_accuracy = 0.0
+
+    if config.TRAIN.AUTO_RESUME:
+        resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER)
+        if resume_file:
+            if config.MODEL.RESUME:
+                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+            config.defrost()
+            config.MODEL.RESUME = resume_file
+            config.freeze()
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+
+    if config.MODEL.RESUME:
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.EVAL_MODE:
+            return
+
+    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
+        load_pretrained(config, model_without_ddp, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.EVAL_MODE:
+            return
+
+    if config.THROUGHPUT_MODE:
+        throughput(data_loader_val, model, logger)
+        return
+
+    logger.info("Start training")
+    start_time = time.time()
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        data_loader_train.sampler.set_epoch(epoch)
+
+        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
+                        loss_scaler)
+        if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
+            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
+                            logger)
+
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        max_accuracy = max(max_accuracy, acc1)
+        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+    save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
+                    logger, zero_redundancy=True)
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+
+
+def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
+    model.train()
+    optimizer.zero_grad()
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    loss_aux_meter = AverageMeter()
+    loss_cls_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    scaler_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, (samples, targets) in enumerate(data_loader):
+        samples = samples.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+            outputs, l_aux = model(samples)
+        l_cls = criterion(outputs, targets)
+        loss = l_cls + l_aux
+        loss = loss / config.TRAIN.ACCUMULATION_STEPS
+
+        # this attribute is added by timm on one optimizer (adahessian)
+        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
+        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
+                                parameters=model.parameters(), create_graph=is_second_order,
+                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
+        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+            optimizer.zero_grad()
+            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
+        loss_scale_value = loss_scaler.state_dict()["scale"]
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), targets.size(0))
+        loss_cls_meter.update(l_cls.item(), targets.size(0))
+        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0))
+        if grad_norm is not None:  # loss_scaler return None if not update
+            norm_meter.update(grad_norm)
+        scaler_meter.update(loss_scale_value)
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[0]['lr']
+            wd = optimizer.param_groups[0]['weight_decay']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
+                f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+    epoch_time = time.time() - start
+    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+
+
+@torch.no_grad()
+def validate(config, data_loader, model):
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_cls_meter = AverageMeter()
+    loss_aux_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    end = time.time()
+    for idx, (images, target) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True)
+
+        # compute output
+        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
+            output, l_aux = model(images)
+
+        # measure accuracy and record tools
+        l_cls = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+
+        loss_cls_meter.update(l_cls.item(), target.size(0))
+        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(
+                f'Test: [{idx}/{len(data_loader)}]\t'
+                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
+                f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
+                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                f'Mem {memory_used:.0f}MB')
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg
+
+
+@torch.no_grad()
+def throughput(data_loader, model, logger):
+    model.eval()
+
+    for idx, (images, _) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        batch_size = images.shape[0]
+        for i in range(50):
+            model(images)
+        torch.cuda.synchronize()
+        logger.info(f"throughput averaged with 30 times")
+        tic1 = time.time()
+        for i in range(30):
+            model(images)
+        torch.cuda.synchronize()
+        tic2 = time.time()
+        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
+        return
+
+
+if __name__ == '__main__':
+    args, config = parse_option()
+
+    if config.AMP_OPT_LEVEL:
+        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    # gradient accumulation also need to scale the learning rate
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+
+    # print config
+    logger.info(config.dump())
+    logger.info(json.dumps(vars(args)))
+
+    main(config)

+ 342 - 0
lib/SwinTransformer/main_simmim_ft.py

@@ -0,0 +1,342 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# Modified by Zhenda Xie
+# --------------------------------------------------------
+
+import os
+import time
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.cuda.amp as amp
+
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.utils import accuracy, AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \
+    reduce_tensor
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False)
+    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+    parser.add_argument('--data-path', type=str, help='path to dataset')
+    parser.add_argument('--pretrained', type=str, help='path to pre-trained model')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+    parser.add_argument('--use-checkpoint', action='store_true',
+                        help="whether to use gradient checkpointing to save memory")
+    parser.add_argument('--enable-amp', action='store_true')
+    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
+    parser.set_defaults(enable_amp=True)
+    parser.add_argument('--output', default='output', type=str, metavar='PATH',
+                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument('--tag', help='tag of experiment')
+    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
+    parser.add_argument('--throughput', action='store_true', help='Test throughput only')
+
+    # distributed training
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args = parser.parse_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def main(config):
+    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
+                                                                                            is_pretrain=False)
+
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+    model = build_model(config, is_pretrain=False)
+    model.cuda()
+    logger.info(str(model))
+
+    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    model_without_ddp = model.module
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params: {n_parameters}")
+    if hasattr(model_without_ddp, 'flops'):
+        flops = model_without_ddp.flops()
+        logger.info(f"number of GFLOPs: {flops / 1e9}")
+
+    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+    scaler = amp.GradScaler()
+
+    if config.AUG.MIXUP > 0.:
+        # smoothing is handled with mixup label transform
+        criterion = SoftTargetCrossEntropy()
+    elif config.MODEL.LABEL_SMOOTHING > 0.:
+        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
+    else:
+        criterion = torch.nn.CrossEntropyLoss()
+
+    max_accuracy = 0.0
+
+    if config.TRAIN.AUTO_RESUME:
+        resume_file = auto_resume_helper(config.OUTPUT, logger)
+        if resume_file:
+            if config.MODEL.RESUME:
+                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+            config.defrost()
+            config.MODEL.RESUME = resume_file
+            config.freeze()
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+
+    if config.MODEL.RESUME:
+        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        if config.EVAL_MODE:
+            return
+
+    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
+        load_pretrained(config, model_without_ddp, logger)
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+
+    if config.THROUGHPUT_MODE:
+        throughput(data_loader_val, model, logger)
+        return
+
+    logger.info("Start training")
+    start_time = time.time()
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        data_loader_train.sampler.set_epoch(epoch)
+
+        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler)
+        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
+            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger)
+
+        acc1, acc5, loss = validate(config, data_loader_val, model)
+        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
+        max_accuracy = max(max_accuracy, acc1)
+        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+
+
+def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
+    model.train()
+    optimizer.zero_grad()
+
+    logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    loss_scale_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, (samples, targets) in enumerate(data_loader):
+        samples = samples.cuda(non_blocking=True)
+        targets = targets.cuda(non_blocking=True)
+
+        if mixup_fn is not None:
+            samples, targets = mixup_fn(samples, targets)
+
+        outputs = model(samples)
+
+        if config.TRAIN.ACCUMULATION_STEPS > 1:
+            loss = criterion(outputs, targets)
+            loss = loss / config.TRAIN.ACCUMULATION_STEPS
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+                scaler.step(optimizer)
+                optimizer.zero_grad()
+                scaler.update()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+        else:
+            loss = criterion(outputs, targets)
+            optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), targets.size(0))
+        norm_meter.update(grad_norm)
+        loss_scale_meter.update(scaler.get_scale())
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[-1]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+    epoch_time = time.time() - start
+    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+
+
+@torch.no_grad()
+def validate(config, data_loader, model):
+    criterion = torch.nn.CrossEntropyLoss()
+    model.eval()
+
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    acc1_meter = AverageMeter()
+    acc5_meter = AverageMeter()
+
+    end = time.time()
+    for idx, (images, target) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        target = target.cuda(non_blocking=True)
+
+        # compute output
+        output = model(images)
+
+        # measure accuracy and record tools
+        loss = criterion(output, target)
+        acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+        acc1 = reduce_tensor(acc1)
+        acc5 = reduce_tensor(acc5)
+        loss = reduce_tensor(loss)
+
+        loss_meter.update(loss.item(), target.size(0))
+        acc1_meter.update(acc1.item(), target.size(0))
+        acc5_meter.update(acc5.item(), target.size(0))
+
+        # measure elapsed time
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            logger.info(
+                f'Test: [{idx}/{len(data_loader)}]\t'
+                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
+                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
+                f'Mem {memory_used:.0f}MB')
+    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
+    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
+
+
+@torch.no_grad()
+def throughput(data_loader, model, logger):
+    model.eval()
+
+    for idx, (images, _) in enumerate(data_loader):
+        images = images.cuda(non_blocking=True)
+        batch_size = images.shape[0]
+        for i in range(50):
+            model(images)
+        torch.cuda.synchronize()
+        logger.info(f"throughput averaged with 30 times")
+        tic1 = time.time()
+        for i in range(30):
+            model(images)
+        torch.cuda.synchronize()
+        tic2 = time.time()
+        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
+        return
+
+
+if __name__ == '__main__':
+    _, config = parse_option()
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    cudnn.benchmark = True
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    # gradient accumulation also need to scale the learning rate
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+
+    # print config
+    logger.info(config.dump())
+
+    main(config)

+ 234 - 0
lib/SwinTransformer/main_simmim_pt.py

@@ -0,0 +1,234 @@
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# Modified by Zhenda Xie
+# --------------------------------------------------------
+
+import os
+import time
+import argparse
+import datetime
+import numpy as np
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.cuda.amp as amp
+from timm.utils import AverageMeter
+
+from config import get_config
+from models import build_model
+from data import build_loader
+from lr_scheduler import build_scheduler
+from optimizer import build_optimizer
+from logger import create_logger
+from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper
+
+# pytorch major version (1.x or 2.x)
+PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
+
+
+def parse_option():
+    parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
+    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options by adding 'KEY VALUE' pairs. ",
+        default=None,
+        nargs='+',
+    )
+
+    # easy config modification
+    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
+    parser.add_argument('--data-path', type=str, help='path to dataset')
+    parser.add_argument('--resume', help='resume from checkpoint')
+    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
+    parser.add_argument('--use-checkpoint', action='store_true',
+                        help="whether to use gradient checkpointing to save memory")
+    parser.add_argument('--enable-amp', action='store_true')
+    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
+    parser.set_defaults(enable_amp=True)
+    parser.add_argument('--output', default='output', type=str, metavar='PATH',
+                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
+    parser.add_argument('--tag', help='tag of experiment')
+
+    # distributed training
+    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
+    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
+    if PYTORCH_MAJOR_VERSION == 1:
+        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
+
+    args = parser.parse_args()
+
+    config = get_config(args)
+
+    return args, config
+
+
+def main(config):
+    data_loader_train = build_loader(config, simmim=True, is_pretrain=True)
+
+    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
+    model = build_model(config, is_pretrain=True)
+    model.cuda()
+    logger.info(str(model))
+
+    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True)
+    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
+    model_without_ddp = model.module
+
+    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    logger.info(f"number of params: {n_parameters}")
+    if hasattr(model_without_ddp, 'flops'):
+        flops = model_without_ddp.flops()
+        logger.info(f"number of GFLOPs: {flops / 1e9}")
+
+    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
+    scaler = amp.GradScaler()
+
+    if config.TRAIN.AUTO_RESUME:
+        resume_file = auto_resume_helper(config.OUTPUT, logger)
+        if resume_file:
+            if config.MODEL.RESUME:
+                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
+            config.defrost()
+            config.MODEL.RESUME = resume_file
+            config.freeze()
+            logger.info(f'auto resuming from {resume_file}')
+        else:
+            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
+
+    if config.MODEL.RESUME:
+        load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
+
+    logger.info("Start training")
+    start_time = time.time()
+    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
+        data_loader_train.sampler.set_epoch(epoch)
+
+        train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler)
+        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
+            save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    logger.info('Training time {}'.format(total_time_str))
+
+
+def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler):
+    model.train()
+    optimizer.zero_grad()
+
+    num_steps = len(data_loader)
+    batch_time = AverageMeter()
+    loss_meter = AverageMeter()
+    norm_meter = AverageMeter()
+    loss_scale_meter = AverageMeter()
+
+    start = time.time()
+    end = time.time()
+    for idx, (img, mask, _) in enumerate(data_loader):
+        img = img.cuda(non_blocking=True)
+        mask = mask.cuda(non_blocking=True)
+
+        with amp.autocast(enabled=config.ENABLE_AMP):
+            loss = model(img, mask)
+
+        if config.TRAIN.ACCUMULATION_STEPS > 1:
+            loss = loss / config.TRAIN.ACCUMULATION_STEPS
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
+                scaler.step(optimizer)
+                optimizer.zero_grad()
+                scaler.update()
+                lr_scheduler.step_update(epoch * num_steps + idx)
+        else:
+            optimizer.zero_grad()
+            scaler.scale(loss).backward()
+            if config.TRAIN.CLIP_GRAD:
+                scaler.unscale_(optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
+            else:
+                grad_norm = get_grad_norm(model.parameters())
+            scaler.step(optimizer)
+            scaler.update()
+            lr_scheduler.step_update(epoch * num_steps + idx)
+
+        torch.cuda.synchronize()
+
+        loss_meter.update(loss.item(), img.size(0))
+        norm_meter.update(grad_norm)
+        loss_scale_meter.update(scaler.get_scale())
+        batch_time.update(time.time() - end)
+        end = time.time()
+
+        if idx % config.PRINT_FREQ == 0:
+            lr = optimizer.param_groups[0]['lr']
+            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
+            etas = batch_time.avg * (num_steps - idx)
+            logger.info(
+                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
+                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
+                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
+                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
+                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
+                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
+                f'mem {memory_used:.0f}MB')
+    epoch_time = time.time() - start
+    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
+
+
+if __name__ == '__main__':
+    _, config = parse_option()
+
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        rank = int(os.environ["RANK"])
+        world_size = int(os.environ['WORLD_SIZE'])
+        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
+    else:
+        rank = -1
+        world_size = -1
+    torch.cuda.set_device(config.LOCAL_RANK)
+    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
+    torch.distributed.barrier()
+
+    seed = config.SEED + dist.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    cudnn.benchmark = True
+
+    # linear scale the learning rate according to total batch size, may not be optimal
+    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
+    # gradient accumulation also need to scale the learning rate
+    if config.TRAIN.ACCUMULATION_STEPS > 1:
+        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
+        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
+    config.defrost()
+    config.TRAIN.BASE_LR = linear_scaled_lr
+    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
+    config.TRAIN.MIN_LR = linear_scaled_min_lr
+    config.freeze()
+
+    os.makedirs(config.OUTPUT, exist_ok=True)
+    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
+
+    if dist.get_rank() == 0:
+        path = os.path.join(config.OUTPUT, "config.json")
+        with open(path, "w") as f:
+            f.write(config.dump())
+        logger.info(f"Full config saved to {path}")
+
+    # print config
+    logger.info(config.dump())
+
+    main(config)

+ 1 - 0
lib/SwinTransformer/models/__init__.py

@@ -0,0 +1 @@
+from .build import build_model

+ 121 - 0
lib/SwinTransformer/models/build.py

@@ -0,0 +1,121 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+from .swin_transformer import SwinTransformer
+from .swin_transformer_v2 import SwinTransformerV2
+from .swin_transformer_moe import SwinTransformerMoE
+from .swin_mlp import SwinMLP
+from .simmim import build_simmim
+
+
+def build_model(config, is_pretrain=False):
+    model_type = config.MODEL.TYPE
+
+    # accelerate layernorm
+    if config.FUSED_LAYERNORM:
+        try:
+            import apex as amp
+            layernorm = amp.normalization.FusedLayerNorm
+        except:
+            layernorm = None
+            print("To use FusedLayerNorm, please install apex.")
+    else:
+        import torch.nn as nn
+        layernorm = nn.LayerNorm
+
+    if is_pretrain:
+        model = build_simmim(config)
+        return model
+
+    if model_type == 'swin':
+        model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
+                                patch_size=config.MODEL.SWIN.PATCH_SIZE,
+                                in_chans=config.MODEL.SWIN.IN_CHANS,
+                                num_classes=config.MODEL.NUM_CLASSES,
+                                embed_dim=config.MODEL.SWIN.EMBED_DIM,
+                                depths=config.MODEL.SWIN.DEPTHS,
+                                num_heads=config.MODEL.SWIN.NUM_HEADS,
+                                window_size=config.MODEL.SWIN.WINDOW_SIZE,
+                                mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
+                                qkv_bias=config.MODEL.SWIN.QKV_BIAS,
+                                qk_scale=config.MODEL.SWIN.QK_SCALE,
+                                drop_rate=config.MODEL.DROP_RATE,
+                                drop_path_rate=config.MODEL.DROP_PATH_RATE,
+                                ape=config.MODEL.SWIN.APE,
+                                norm_layer=layernorm,
+                                patch_norm=config.MODEL.SWIN.PATCH_NORM,
+                                use_checkpoint=config.TRAIN.USE_CHECKPOINT,
+                                fused_window_process=config.FUSED_WINDOW_PROCESS)
+    elif model_type == 'swinv2':
+        model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,
+                                  patch_size=config.MODEL.SWINV2.PATCH_SIZE,
+                                  in_chans=config.MODEL.SWINV2.IN_CHANS,
+                                  num_classes=config.MODEL.NUM_CLASSES,
+                                  embed_dim=config.MODEL.SWINV2.EMBED_DIM,
+                                  depths=config.MODEL.SWINV2.DEPTHS,
+                                  num_heads=config.MODEL.SWINV2.NUM_HEADS,
+                                  window_size=config.MODEL.SWINV2.WINDOW_SIZE,
+                                  mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
+                                  qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
+                                  drop_rate=config.MODEL.DROP_RATE,
+                                  drop_path_rate=config.MODEL.DROP_PATH_RATE,
+                                  ape=config.MODEL.SWINV2.APE,
+                                  patch_norm=config.MODEL.SWINV2.PATCH_NORM,
+                                  use_checkpoint=config.TRAIN.USE_CHECKPOINT,
+                                  pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)
+    elif model_type == 'swin_moe':
+        model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,
+                                   patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,
+                                   in_chans=config.MODEL.SWIN_MOE.IN_CHANS,
+                                   num_classes=config.MODEL.NUM_CLASSES,
+                                   embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM,
+                                   depths=config.MODEL.SWIN_MOE.DEPTHS,
+                                   num_heads=config.MODEL.SWIN_MOE.NUM_HEADS,
+                                   window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE,
+                                   mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO,
+                                   qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS,
+                                   qk_scale=config.MODEL.SWIN_MOE.QK_SCALE,
+                                   drop_rate=config.MODEL.DROP_RATE,
+                                   drop_path_rate=config.MODEL.DROP_PATH_RATE,
+                                   ape=config.MODEL.SWIN_MOE.APE,
+                                   patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM,
+                                   mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS,
+                                   init_std=config.MODEL.SWIN_MOE.INIT_STD,
+                                   use_checkpoint=config.TRAIN.USE_CHECKPOINT,
+                                   pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES,
+                                   moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS,
+                                   num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS,
+                                   top_value=config.MODEL.SWIN_MOE.TOP_VALUE,
+                                   capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR,
+                                   cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER,
+                                   normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE,
+                                   use_bpr=config.MODEL.SWIN_MOE.USE_BPR,
+                                   is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS,
+                                   gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE,
+                                   cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM,
+                                   cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T,
+                                   moe_drop=config.MODEL.SWIN_MOE.MOE_DROP,
+                                   aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT)
+    elif model_type == 'swin_mlp':
+        model = SwinMLP(img_size=config.DATA.IMG_SIZE,
+                        patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE,
+                        in_chans=config.MODEL.SWIN_MLP.IN_CHANS,
+                        num_classes=config.MODEL.NUM_CLASSES,
+                        embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM,
+                        depths=config.MODEL.SWIN_MLP.DEPTHS,
+                        num_heads=config.MODEL.SWIN_MLP.NUM_HEADS,
+                        window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE,
+                        mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO,
+                        drop_rate=config.MODEL.DROP_RATE,
+                        drop_path_rate=config.MODEL.DROP_PATH_RATE,
+                        ape=config.MODEL.SWIN_MLP.APE,
+                        patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM,
+                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+    else:
+        raise NotImplementedError(f"Unkown model: {model_type}")
+
+    return model

+ 209 - 0
lib/SwinTransformer/models/simmim.py

@@ -0,0 +1,209 @@
+
+
+# --------------------------------------------------------
+# SimMIM
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Zhenda Xie
+# --------------------------------------------------------
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import trunc_normal_
+
+from .swin_transformer import SwinTransformer
+from .swin_transformer_v2 import SwinTransformerV2
+
+
+def norm_targets(targets, patch_size):
+    assert patch_size % 2 == 1
+    
+    targets_ = targets
+    targets_count = torch.ones_like(targets)
+
+    targets_square = targets ** 2.
+    
+    targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)
+    targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False)
+    targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2)
+    
+    targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))
+    targets_var = torch.clamp(targets_var, min=0.)
+    
+    targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5
+    
+    return targets_
+
+
+class SwinTransformerForSimMIM(SwinTransformer):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        assert self.num_classes == 0
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        trunc_normal_(self.mask_token, mean=0., std=.02)
+
+    def forward(self, x, mask):
+        x = self.patch_embed(x)
+
+        assert mask is not None
+        B, L, _ = x.shape
+
+        mask_tokens = self.mask_token.expand(B, L, -1)
+        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
+        x = x * (1. - w) + mask_tokens * w
+
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+        x = self.norm(x)
+
+        x = x.transpose(1, 2)
+        B, C, L = x.shape
+        H = W = int(L ** 0.5)
+        x = x.reshape(B, C, H, W)
+        return x
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return super().no_weight_decay() | {'mask_token'}
+
+
+class SwinTransformerV2ForSimMIM(SwinTransformerV2):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        assert self.num_classes == 0
+
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        trunc_normal_(self.mask_token, mean=0., std=.02)
+
+    def forward(self, x, mask):
+        x = self.patch_embed(x)
+
+        assert mask is not None
+        B, L, _ = x.shape
+
+        mask_tokens = self.mask_token.expand(B, L, -1)
+        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
+        x = x * (1. - w) + mask_tokens * w
+
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+        x = self.norm(x)
+
+        x = x.transpose(1, 2)
+        B, C, L = x.shape
+        H = W = int(L ** 0.5)
+        x = x.reshape(B, C, H, W)
+        return x
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return super().no_weight_decay() | {'mask_token'}
+
+
+class SimMIM(nn.Module):
+    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
+        super().__init__()
+        self.config = config
+        self.encoder = encoder
+        self.encoder_stride = encoder_stride
+
+        self.decoder = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.encoder.num_features,
+                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
+            nn.PixelShuffle(self.encoder_stride),
+        )
+
+        self.in_chans = in_chans
+        self.patch_size = patch_size
+
+    def forward(self, x, mask):
+        z = self.encoder(x, mask)
+        x_rec = self.decoder(z)
+
+        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
+        
+        # norm target as prompted
+        if self.config.NORM_TARGET.ENABLE:
+            x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE)
+        
+        loss_recon = F.l1_loss(x, x_rec, reduction='none')
+        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
+        return loss
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        if hasattr(self.encoder, 'no_weight_decay'):
+            return {'encoder.' + i for i in self.encoder.no_weight_decay()}
+        return {}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        if hasattr(self.encoder, 'no_weight_decay_keywords'):
+            return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()}
+        return {}
+
+
+def build_simmim(config):
+    model_type = config.MODEL.TYPE
+    if model_type == 'swin':
+        encoder = SwinTransformerForSimMIM(
+            img_size=config.DATA.IMG_SIZE,
+            patch_size=config.MODEL.SWIN.PATCH_SIZE,
+            in_chans=config.MODEL.SWIN.IN_CHANS,
+            num_classes=0,
+            embed_dim=config.MODEL.SWIN.EMBED_DIM,
+            depths=config.MODEL.SWIN.DEPTHS,
+            num_heads=config.MODEL.SWIN.NUM_HEADS,
+            window_size=config.MODEL.SWIN.WINDOW_SIZE,
+            mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
+            qkv_bias=config.MODEL.SWIN.QKV_BIAS,
+            qk_scale=config.MODEL.SWIN.QK_SCALE,
+            drop_rate=config.MODEL.DROP_RATE,
+            drop_path_rate=config.MODEL.DROP_PATH_RATE,
+            ape=config.MODEL.SWIN.APE,
+            patch_norm=config.MODEL.SWIN.PATCH_NORM,
+            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+        encoder_stride = 32
+        in_chans = config.MODEL.SWIN.IN_CHANS
+        patch_size = config.MODEL.SWIN.PATCH_SIZE
+    elif model_type == 'swinv2':
+        encoder = SwinTransformerV2ForSimMIM(
+            img_size=config.DATA.IMG_SIZE,
+            patch_size=config.MODEL.SWINV2.PATCH_SIZE,
+            in_chans=config.MODEL.SWINV2.IN_CHANS,
+            num_classes=0,
+            embed_dim=config.MODEL.SWINV2.EMBED_DIM,
+            depths=config.MODEL.SWINV2.DEPTHS,
+            num_heads=config.MODEL.SWINV2.NUM_HEADS,
+            window_size=config.MODEL.SWINV2.WINDOW_SIZE,
+            mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
+            qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
+            drop_rate=config.MODEL.DROP_RATE,
+            drop_path_rate=config.MODEL.DROP_PATH_RATE,
+            ape=config.MODEL.SWINV2.APE,
+            patch_norm=config.MODEL.SWINV2.PATCH_NORM,
+            use_checkpoint=config.TRAIN.USE_CHECKPOINT)
+        encoder_stride = 32
+        in_chans = config.MODEL.SWINV2.IN_CHANS
+        patch_size = config.MODEL.SWINV2.PATCH_SIZE
+    else:
+        raise NotImplementedError(f"Unknown pre-train model: {model_type}")
+
+    model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size)
+
+    return model

+ 468 - 0
lib/SwinTransformer/models/swin_mlp.py

@@ -0,0 +1,468 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class SwinMLPBlock(nn.Module):
+    r""" Swin MLP Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resulotion.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        drop (float, optional): Dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.padding = [self.window_size - self.shift_size, self.shift_size,
+                        self.window_size - self.shift_size, self.shift_size]  # P_l,P_r,P_t,P_b
+
+        self.norm1 = norm_layer(dim)
+        # use group convolution to implement multi-head MLP
+        self.spatial_mlp = nn.Conv1d(self.num_heads * self.window_size ** 2,
+                                     self.num_heads * self.window_size ** 2,
+                                     kernel_size=1,
+                                     groups=self.num_heads)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # shift
+        if self.shift_size > 0:
+            P_l, P_r, P_t, P_b = self.padding
+            shifted_x = F.pad(x, [0, 0, P_l, P_r, P_t, P_b], "constant", 0)
+        else:
+            shifted_x = x
+        _, _H, _W, _ = shifted_x.shape
+
+        # partition windows
+        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # Window/Shifted-Window Spatial MLP
+        x_windows_heads = x_windows.view(-1, self.window_size * self.window_size, self.num_heads, C // self.num_heads)
+        x_windows_heads = x_windows_heads.transpose(1, 2)  # nW*B, nH, window_size*window_size, C//nH
+        x_windows_heads = x_windows_heads.reshape(-1, self.num_heads * self.window_size * self.window_size,
+                                                  C // self.num_heads)
+        spatial_mlp_windows = self.spatial_mlp(x_windows_heads)  # nW*B, nH*window_size*window_size, C//nH
+        spatial_mlp_windows = spatial_mlp_windows.view(-1, self.num_heads, self.window_size * self.window_size,
+                                                       C // self.num_heads).transpose(1, 2)
+        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size * self.window_size, C)
+
+        # merge windows
+        spatial_mlp_windows = spatial_mlp_windows.reshape(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(spatial_mlp_windows, self.window_size, _H, _W)  # B H' W' C
+
+        # reverse shift
+        if self.shift_size > 0:
+            P_l, P_r, P_t, P_b = self.padding
+            x = shifted_x[:, P_t:-P_b, P_l:-P_r, :].contiguous()
+        else:
+            x = shifted_x
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+
+        # Window/Shifted-Window Spatial MLP
+        if self.shift_size > 0:
+            nW = (H / self.window_size + 1) * (W / self.window_size + 1)
+        else:
+            nW = H * W / self.window_size / self.window_size
+        flops += nW * self.dim * (self.window_size * self.window_size) * (self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+        x = x.view(B, H, W, C)
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+    def flops(self):
+        H, W = self.input_resolution
+        flops = H * W * self.dim
+        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        return flops
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin MLP layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        drop (float, optional): Dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., drop=0., drop_path=0.,
+                 norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinMLPBlock(dim=dim, input_resolution=input_resolution,
+                         num_heads=num_heads, window_size=window_size,
+                         shift_size=0 if (i % 2 == 0) else window_size // 2,
+                         mlp_ratio=mlp_ratio,
+                         drop=drop,
+                         drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                         norm_layer=norm_layer)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
+
+
+class PatchEmbed(nn.Module):
+    r""" Image to Patch Embedding
+
+    Args:
+        img_size (int): Image size.  Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patches_resolution = patches_resolution
+        self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
+
+    def flops(self):
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
+
+
+class SwinMLP(nn.Module):
+    r""" Swin MLP
+
+    Args:
+        img_size (int | tuple(int)): Input image size. Default 224
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin MLP layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        drop_rate (float): Dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+                 window_size=7, mlp_ratio=4., drop_rate=0., drop_path_rate=0.1,
+                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                 use_checkpoint=False, **kwargs):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+        self.mlp_ratio = mlp_ratio
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # absolute position embedding
+        if self.ape:
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                 patches_resolution[1] // (2 ** i_layer)),
+                               depth=depths[i_layer],
+                               num_heads=num_heads[i_layer],
+                               window_size=window_size,
+                               mlp_ratio=self.mlp_ratio,
+                               drop=drop_rate,
+                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                               norm_layer=norm_layer,
+                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                               use_checkpoint=use_checkpoint)
+            self.layers.append(layer)
+
+        self.norm = norm_layer(self.num_features)
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, (nn.Linear, nn.Conv1d)):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'absolute_pos_embed'}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'relative_position_bias_table'}
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.norm(x)  # B L C
+        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x = torch.flatten(x, 1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+    def flops(self):
+        flops = 0
+        flops += self.patch_embed.flops()
+        for i, layer in enumerate(self.layers):
+            flops += layer.flops()
+        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += self.num_features * self.num_classes
+        return flops

+ 614 - 0
lib/SwinTransformer/models/swin_transformer.py

@@ -0,0 +1,614 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+try:
+    import os, sys
+
+    kernel_path = os.path.abspath(os.path.join('..'))
+    sys.path.append(kernel_path)
+    from kernels.window_process.window_process import WindowProcess, WindowProcessReverse
+
+except:
+    WindowProcess = None
+    WindowProcessReverse = None
+    print("[Warning] Fused window process have not been installed. Please refer to get_started.md for installation.")
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = (q @ k.transpose(-2, -1))
+
+        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+    def flops(self, N):
+        # calculate flops for 1 window with token length of N
+        flops = 0
+        # qkv = self.qkv(x)
+        flops += N * self.dim * 3 * self.dim
+        # attn = (q @ k.transpose(-2, -1))
+        flops += self.num_heads * N * (self.dim // self.num_heads) * N
+        #  x = (attn @ v)
+        flops += self.num_heads * N * N * (self.dim // self.num_heads)
+        # x = self.proj(x)
+        flops += N * self.dim * self.dim
+        return flops
+
+
+class SwinTransformerBlock(nn.Module):
+    r""" Swin Transformer Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resulotion.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 fused_window_process=False):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        if min(self.input_resolution) <= self.window_size:
+            # if window size is larger than input resolution, we don't partition windows
+            self.shift_size = 0
+            self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+        if self.shift_size > 0:
+            # calculate attention mask for SW-MSA
+            H, W = self.input_resolution
+            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
+            h_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            w_slices = (slice(0, -self.window_size),
+                        slice(-self.window_size, -self.shift_size),
+                        slice(-self.shift_size, None))
+            cnt = 0
+            for h in h_slices:
+                for w in w_slices:
+                    img_mask[:, h, w, :] = cnt
+                    cnt += 1
+
+            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
+            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+        else:
+            attn_mask = None
+
+        self.register_buffer("attn_mask", attn_mask)
+        self.fused_window_process = fused_window_process
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # cyclic shift
+        if self.shift_size > 0:
+            if not self.fused_window_process:
+                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+                # partition windows
+                x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+            else:
+                x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
+        else:
+            shifted_x = x
+            # partition windows
+            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            if not self.fused_window_process:
+                shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+            else:
+                x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
+        else:
+            shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C
+            x = shifted_x
+        x = x.view(B, H * W, C)
+        x = shortcut + self.drop_path(x)
+
+        # FFN
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+    def flops(self):
+        flops = 0
+        H, W = self.input_resolution
+        # norm1
+        flops += self.dim * H * W
+        # W-MSA/SW-MSA
+        nW = H * W / self.window_size / self.window_size
+        flops += nW * self.attn.flops(self.window_size * self.window_size)
+        # mlp
+        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+        # norm2
+        flops += self.dim * H * W
+        return flops
+
+
+class PatchMerging(nn.Module):
+    r""" Patch Merging Layer.
+
+    Args:
+        input_resolution (tuple[int]): Resolution of input feature.
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x):
+        """
+        x: B, H*W, C
+        """
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+        x = x.view(B, H, W, C)
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+    def extra_repr(self) -> str:
+        return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+    def flops(self):
+        H, W = self.input_resolution
+        flops = H * W * self.dim
+        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+        return flops
+
+
+class BasicLayer(nn.Module):
+    """ A basic Swin Transformer layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
+                 fused_window_process=False):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
+                                 num_heads=num_heads, window_size=window_size,
+                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
+                                 mlp_ratio=mlp_ratio,
+                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
+                                 drop=drop, attn_drop=attn_drop,
+                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+                                 norm_layer=norm_layer,
+                                 fused_window_process=fused_window_process)
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+    def flops(self):
+        flops = 0
+        for blk in self.blocks:
+            flops += blk.flops()
+        if self.downsample is not None:
+            flops += self.downsample.flops()
+        return flops
+
+
+class PatchEmbed(nn.Module):
+    r""" Image to Patch Embedding
+
+    Args:
+        img_size (int): Image size.  Default: 224.
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.patches_resolution = patches_resolution
+        self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
+        if self.norm is not None:
+            x = self.norm(x)
+        return x
+
+    def flops(self):
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
+
+
+class SwinTransformer(nn.Module):
+    r""" Swin Transformer
+        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+
+    Args:
+        img_size (int | tuple(int)): Input image size. Default 224
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        in_chans (int): Number of input image channels. Default: 3
+        num_classes (int): Number of classes for classification head. Default: 1000
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.1
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
+    """
+
+    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
+                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                 use_checkpoint=False, fused_window_process=False, **kwargs):
+        super().__init__()
+
+        self.num_classes = num_classes
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+        self.mlp_ratio = mlp_ratio
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+        num_patches = self.patch_embed.num_patches
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # absolute position embedding
+        if self.ape:
+            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+            trunc_normal_(self.absolute_pos_embed, std=.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
+                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                                                 patches_resolution[1] // (2 ** i_layer)),
+                               depth=depths[i_layer],
+                               num_heads=num_heads[i_layer],
+                               window_size=window_size,
+                               mlp_ratio=self.mlp_ratio,
+                               qkv_bias=qkv_bias, qk_scale=qk_scale,
+                               drop=drop_rate, attn_drop=attn_drop_rate,
+                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                               norm_layer=norm_layer,
+                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                               use_checkpoint=use_checkpoint,
+                               fused_window_process=fused_window_process)
+            self.layers.append(layer)
+
+        self.norm = norm_layer(self.num_features)
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'absolute_pos_embed'}
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'relative_position_bias_table'}
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        if self.ape:
+            x = x + self.absolute_pos_embed
+        x = self.pos_drop(x)
+
+        for layer in self.layers:
+            x = layer(x)
+
+        x = self.norm(x)  # B L C
+        x = self.avgpool(x.transpose(1, 2))  # B C 1
+        x = torch.flatten(x, 1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+    def flops(self):
+        flops = 0
+        flops += self.patch_embed.flops()
+        for i, layer in enumerate(self.layers):
+            flops += layer.flops()
+        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+        flops += self.num_features * self.num_classes
+        return flops

Algunos archivos no se mostraron porque demasiados archivos cambiaron en este cambio