__init__.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. from functools import partial
  3. import torch
  4. from .vmamba import VSSM
  5. def build_vssm_model(config, **kwargs):
  6. model_type = config.MODEL.TYPE
  7. if model_type in ["vssm"]:
  8. model = VSSM(
  9. patch_size=config.MODEL.VSSM.PATCH_SIZE,
  10. in_chans=config.MODEL.VSSM.IN_CHANS,
  11. num_classes=config.MODEL.NUM_CLASSES,
  12. depths=config.MODEL.VSSM.DEPTHS,
  13. dims=config.MODEL.VSSM.EMBED_DIM,
  14. # ===================
  15. ssm_d_state=config.MODEL.VSSM.SSM_D_STATE,
  16. ssm_ratio=config.MODEL.VSSM.SSM_RATIO,
  17. ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO,
  18. ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)),
  19. ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER,
  20. ssm_conv=config.MODEL.VSSM.SSM_CONV,
  21. ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS,
  22. ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE,
  23. ssm_init=config.MODEL.VSSM.SSM_INIT,
  24. forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE,
  25. # ===================
  26. mlp_ratio=config.MODEL.VSSM.MLP_RATIO,
  27. mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER,
  28. mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE,
  29. # ===================
  30. drop_path_rate=config.MODEL.DROP_PATH_RATE,
  31. patch_norm=config.MODEL.VSSM.PATCH_NORM,
  32. norm_layer=config.MODEL.VSSM.NORM_LAYER,
  33. downsample_version=config.MODEL.VSSM.DOWNSAMPLE,
  34. patchembed_version=config.MODEL.VSSM.PATCHEMBED,
  35. gmlp=config.MODEL.VSSM.GMLP,
  36. use_checkpoint=config.TRAIN.USE_CHECKPOINT,
  37. # ===================
  38. posembed=config.MODEL.VSSM.POSEMBED,
  39. imgsize=config.DATA.IMG_SIZE,
  40. )
  41. return model
  42. return None
  43. def build_model(config, is_pretrain=False):
  44. model = None
  45. if model is None:
  46. model = build_vssm_model(config)
  47. if model is None:
  48. from .simvmamba import simple_build
  49. model = simple_build(config.MODEL.TYPE)
  50. return model