model.py 960 B

1234567891011121314151617181920212223242526272829303132333435
  1. import os
  2. from functools import partial
  3. from typing import Callable
  4. import torch
  5. from torch import nn
  6. from torch.utils import checkpoint
  7. from mmengine.model import BaseModule
  8. from mmdet.registry import MODELS as MODELS_MMDET
  9. from mmseg.registry import MODELS as MODELS_MMSEG
  10. def import_abspy(name="models", path="classification/"):
  11. import sys
  12. import importlib
  13. path = os.path.abspath(path)
  14. assert os.path.isdir(path)
  15. sys.path.insert(0, path)
  16. module = importlib.import_module(name)
  17. sys.path.pop(0)
  18. return module
  19. build = import_abspy(
  20. "models",
  21. os.path.join(os.path.dirname(os.path.abspath(__file__)), "../classification/"),
  22. )
  23. Backbone_VSSM: nn.Module = build.vmamba.Backbone_VSSM
@MODELS_MMSEG.register_module()
@MODELS_MMDET.register_module()
class MM_VSSM(BaseModule, Backbone_VSSM):
    """Expose ``Backbone_VSSM`` as an mmengine ``BaseModule``.

    The two decorators register this class in both the mmseg and the mmdet
    ``MODELS`` registries, so it can be built from either framework's
    registry. All constructor arguments are forwarded unchanged to
    ``Backbone_VSSM``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # The two base initializers are called explicitly, in this order,
        # instead of through one ``super().__init__``. ``BaseModule`` gets no
        # arguments; ``Backbone_VSSM`` gets all of them.
        # Because ``BaseModule.__init__`` is called with no arguments, an
        # ``init_cfg`` passed in ``kwargs`` goes to ``Backbone_VSSM``, not to
        # ``BaseModule``.
        # NOTE(review): the order presumably matters. BaseModule state is set
        # up first, then Backbone_VSSM builds the network. Confirm against
        # both base classes before reordering.
        # NOTE(review): if both bases derive from nn.Module, the MRO places
        # Backbone_VSSM right after BaseModule. A ``super().__init__()`` inside
        # ``BaseModule.__init__`` would then run ``Backbone_VSSM.__init__``
        # with default arguments before the explicit call below. Verify this
        # does not build the backbone twice.
        BaseModule.__init__(self)
        Backbone_VSSM.__init__(self, *args, **kwargs)