config.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # --------------------------------------------------------
  2. # Modified by $@#Anonymous#@$
  3. # --------------------------------------------------------
  4. # Swin Transformer
  5. # Copyright (c) 2021 Microsoft
  6. # Licensed under The MIT License [see LICENSE for details]
  7. # Written by Ze Liu
  8. # --------------------------------------------------------
  9. import os
  10. import yaml
  11. from yacs.config import CfgNode as CN
  12. _C = CN()
  13. # Base config files
  14. _C.BASE = ['']
  15. # -----------------------------------------------------------------------------
  16. # Data settings
  17. # -----------------------------------------------------------------------------
  18. _C.DATA = CN()
  19. # Batch size for a single GPU, could be overwritten by command line argument
  20. _C.DATA.BATCH_SIZE = 128
  21. # Path to dataset, could be overwritten by command line argument
  22. _C.DATA.DATA_PATH = ''
  23. # Dataset name
  24. _C.DATA.DATASET = 'imagenet'
  25. # Input image size
  26. _C.DATA.IMG_SIZE = 224
  27. # Interpolation to resize image (random, bilinear, bicubic)
  28. _C.DATA.INTERPOLATION = 'bicubic'
  29. # Use zipped dataset instead of folder dataset
  30. # could be overwritten by command line argument
  31. _C.DATA.ZIP_MODE = False
  32. # Cache Data in Memory, could be overwritten by command line argument
  33. _C.DATA.CACHE_MODE = 'part'
  34. # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
  35. _C.DATA.PIN_MEMORY = True
  36. # Number of data loading threads
  37. _C.DATA.NUM_WORKERS = 8
  38. # [SimMIM] Mask patch size for MaskGenerator
  39. _C.DATA.MASK_PATCH_SIZE = 32
  40. # [SimMIM] Mask ratio for MaskGenerator
  41. _C.DATA.MASK_RATIO = 0.6
  42. # -----------------------------------------------------------------------------
  43. # Model settings
  44. # -----------------------------------------------------------------------------
  45. _C.MODEL = CN()
  46. # Model type
  47. _C.MODEL.TYPE = 'vssm'
  48. # Model name
  49. _C.MODEL.NAME = 'vssm_tiny_224'
  50. # Pretrained weight from checkpoint, could be imagenet22k pretrained weight
  51. # could be overwritten by command line argument
  52. _C.MODEL.PRETRAINED = ''
  53. # Checkpoint to resume, could be overwritten by command line argument
  54. _C.MODEL.RESUME = ''
  55. # Number of classes, overwritten in data preparation
  56. _C.MODEL.NUM_CLASSES = 1000
  57. # Dropout rate
  58. _C.MODEL.DROP_RATE = 0.0
  59. # Drop path rate
  60. _C.MODEL.DROP_PATH_RATE = 0.1
  61. # Label Smoothing
  62. _C.MODEL.LABEL_SMOOTHING = 0.1
  63. # MMpretrain models for test
  64. _C.MODEL.MMCKPT = False
  65. # VSSM parameters
  66. _C.MODEL.VSSM = CN()
  67. _C.MODEL.VSSM.PATCH_SIZE = 4
  68. _C.MODEL.VSSM.IN_CHANS = 3
  69. _C.MODEL.VSSM.DEPTHS = [2, 2, 9, 2]
  70. _C.MODEL.VSSM.EMBED_DIM = 96
  71. _C.MODEL.VSSM.SSM_D_STATE = 16
  72. _C.MODEL.VSSM.SSM_RATIO = 2.0
  73. _C.MODEL.VSSM.SSM_RANK_RATIO = 2.0
  74. _C.MODEL.VSSM.SSM_DT_RANK = "auto"
  75. _C.MODEL.VSSM.SSM_ACT_LAYER = "silu"
  76. _C.MODEL.VSSM.SSM_CONV = 3
  77. _C.MODEL.VSSM.SSM_CONV_BIAS = True
  78. _C.MODEL.VSSM.SSM_DROP_RATE = 0.0
  79. _C.MODEL.VSSM.SSM_INIT = "v0"
  80. _C.MODEL.VSSM.SSM_FORWARDTYPE = "v2"
  81. _C.MODEL.VSSM.MLP_RATIO = 4.0
  82. _C.MODEL.VSSM.MLP_ACT_LAYER = "gelu"
  83. _C.MODEL.VSSM.MLP_DROP_RATE = 0.0
  84. _C.MODEL.VSSM.PATCH_NORM = True
  85. _C.MODEL.VSSM.NORM_LAYER = "ln"
  86. _C.MODEL.VSSM.DOWNSAMPLE = "v2"
  87. _C.MODEL.VSSM.PATCHEMBED = "v2"
  88. _C.MODEL.VSSM.POSEMBED = False
  89. _C.MODEL.VSSM.GMLP = False
  90. # -----------------------------------------------------------------------------
  91. # Training settings
  92. # -----------------------------------------------------------------------------
  93. _C.TRAIN = CN()
  94. _C.TRAIN.START_EPOCH = 0
  95. _C.TRAIN.EPOCHS = 300
  96. _C.TRAIN.WARMUP_EPOCHS = 20
  97. _C.TRAIN.WEIGHT_DECAY = 0.05
  98. _C.TRAIN.BASE_LR = 5e-4
  99. _C.TRAIN.WARMUP_LR = 5e-7
  100. _C.TRAIN.MIN_LR = 5e-6
  101. # Clip gradient norm
  102. _C.TRAIN.CLIP_GRAD = 5.0
  103. # Auto resume from latest checkpoint
  104. _C.TRAIN.AUTO_RESUME = True
  105. # Gradient accumulation steps
  106. # could be overwritten by command line argument
  107. _C.TRAIN.ACCUMULATION_STEPS = 1
  108. # Whether to use gradient checkpointing to save memory
  109. # could be overwritten by command line argument
  110. _C.TRAIN.USE_CHECKPOINT = False
  111. # LR scheduler
  112. _C.TRAIN.LR_SCHEDULER = CN()
  113. _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
  114. # Epoch interval to decay LR, used in StepLRScheduler
  115. _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
  116. # LR decay rate, used in StepLRScheduler
  117. _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
  118. # warmup_prefix used in CosineLRScheduler
  119. _C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True
  120. # [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler
  121. _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
  122. _C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
  123. # Optimizer
  124. _C.TRAIN.OPTIMIZER = CN()
  125. _C.TRAIN.OPTIMIZER.NAME = 'adamw'
  126. # Optimizer Epsilon
  127. _C.TRAIN.OPTIMIZER.EPS = 1e-8
  128. # Optimizer Betas
  129. _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
  130. # SGD momentum
  131. _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
  132. # [SimMIM] Layer decay for fine-tuning
  133. _C.TRAIN.LAYER_DECAY = 1.0
  134. # MoE
  135. _C.TRAIN.MOE = CN()
  136. # Only save model on master device
  137. _C.TRAIN.MOE.SAVE_MASTER = False
  138. # -----------------------------------------------------------------------------
  139. # Augmentation settings
  140. # -----------------------------------------------------------------------------
  141. _C.AUG = CN()
  142. # Color jitter factor
  143. _C.AUG.COLOR_JITTER = 0.4
  144. # Use AutoAugment policy. "v0" or "original"
  145. _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
  146. # Random erase prob
  147. _C.AUG.REPROB = 0.25
  148. # Random erase mode
  149. _C.AUG.REMODE = 'pixel'
  150. # Random erase count
  151. _C.AUG.RECOUNT = 1
  152. # Mixup alpha, mixup enabled if > 0
  153. _C.AUG.MIXUP = 0.8
  154. # Cutmix alpha, cutmix enabled if > 0
  155. _C.AUG.CUTMIX = 1.0
  156. # Cutmix min/max ratio, overrides alpha and enables cutmix if set
  157. _C.AUG.CUTMIX_MINMAX = None
  158. # Probability of performing mixup or cutmix when either/both is enabled
  159. _C.AUG.MIXUP_PROB = 1.0
  160. # Probability of switching to cutmix when both mixup and cutmix enabled
  161. _C.AUG.MIXUP_SWITCH_PROB = 0.5
  162. # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
  163. _C.AUG.MIXUP_MODE = 'batch'
  164. # -----------------------------------------------------------------------------
  165. # Testing settings
  166. # -----------------------------------------------------------------------------
  167. _C.TEST = CN()
  168. # Whether to use center crop when testing
  169. _C.TEST.CROP = True
  170. # Whether to use SequentialSampler as validation sampler
  171. _C.TEST.SEQUENTIAL = False
  172. _C.TEST.SHUFFLE = False
  173. # -----------------------------------------------------------------------------
  174. # Misc
  175. # -----------------------------------------------------------------------------
  176. # [SimMIM] Whether to enable pytorch amp, overwritten by command line argument
  177. _C.ENABLE_AMP = False
  178. # Enable Pytorch automatic mixed precision (amp).
  179. _C.AMP_ENABLE = True
  180. # [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2')
  181. _C.AMP_OPT_LEVEL = ''
  182. # Path to output folder, overwritten by command line argument
  183. _C.OUTPUT = ''
  184. # Tag of experiment, overwritten by command line argument
  185. _C.TAG = 'default'
  186. # Frequency to save checkpoint
  187. _C.SAVE_FREQ = 1
  188. # Frequency to logging info
  189. _C.PRINT_FREQ = 10
  190. # Fixed random seed
  191. _C.SEED = 0
  192. # Perform evaluation only, overwritten by command line argument
  193. _C.EVAL_MODE = False
  194. # Test throughput only, overwritten by command line argument
  195. _C.THROUGHPUT_MODE = False
  196. # Test traincost only, overwritten by command line argument
  197. _C.TRAINCOST_MODE = False
  198. # for acceleration
  199. _C.FUSED_LAYERNORM = False
  200. def _update_config_from_file(config, cfg_file):
  201. config.defrost()
  202. with open(cfg_file, 'r') as f:
  203. yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
  204. for cfg in yaml_cfg.setdefault('BASE', ['']):
  205. if cfg:
  206. _update_config_from_file(
  207. config, os.path.join(os.path.dirname(cfg_file), cfg)
  208. )
  209. print('=> merge config from {}'.format(cfg_file))
  210. config.merge_from_file(cfg_file)
  211. config.freeze()
  212. def update_config(config, args):
  213. if args.cfg != "":
  214. _update_config_from_file(config, args.cfg)
  215. config.defrost()
  216. if args.opts:
  217. config.merge_from_list(args.opts)
  218. def _check_args(name):
  219. if hasattr(args, name) and eval(f'args.{name}'):
  220. return True
  221. return False
  222. # merge from specific arguments
  223. if _check_args('batch_size'):
  224. config.DATA.BATCH_SIZE = args.batch_size
  225. if _check_args('data_path'):
  226. config.DATA.DATA_PATH = args.data_path
  227. if _check_args('zip'):
  228. config.DATA.ZIP_MODE = True
  229. if _check_args('cache_mode'):
  230. config.DATA.CACHE_MODE = args.cache_mode
  231. if _check_args('pretrained'):
  232. config.MODEL.PRETRAINED = args.pretrained
  233. if _check_args('resume'):
  234. config.MODEL.RESUME = args.resume
  235. if _check_args('accumulation_steps'):
  236. config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
  237. if _check_args('use_checkpoint'):
  238. config.TRAIN.USE_CHECKPOINT = True
  239. if _check_args('disable_amp'):
  240. config.AMP_ENABLE = False
  241. if _check_args('output'):
  242. config.OUTPUT = args.output
  243. if _check_args('tag'):
  244. config.TAG = args.tag
  245. if _check_args('eval'):
  246. config.EVAL_MODE = True
  247. if _check_args('throughput'):
  248. config.THROUGHPUT_MODE = True
  249. if _check_args('traincost'):
  250. config.TRAINCOST_MODE = True
  251. # [SimMIM]
  252. if _check_args('enable_amp'):
  253. config.ENABLE_AMP = args.enable_amp
  254. # for acceleration
  255. if _check_args('fused_layernorm'):
  256. config.FUSED_LAYERNORM = True
  257. ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb]
  258. if _check_args('optim'):
  259. config.TRAIN.OPTIMIZER.NAME = args.optim
  260. # output folder
  261. config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
  262. config.freeze()
  263. def get_config(args):
  264. """Get a yacs CfgNode object with default values."""
  265. # Return a clone so that the defaults will not be altered
  266. # This is for the "local variable" use pattern
  267. config = _C.clone()
  268. update_config(config, args)
  269. return config