setup.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Modified by $@#Anonymous#@$ #20240123
  2. # Copyright (c) 2023, Albert Gu, Tri Dao.
  3. import sys
  4. import warnings
  5. import os
  6. import re
  7. import ast
  8. from pathlib import Path
  9. from packaging.version import parse, Version
  10. import platform
  11. import shutil
  12. from setuptools import setup, find_packages
  13. import subprocess
  14. from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
  15. import torch
  16. from torch.utils.cpp_extension import (
  17. BuildExtension,
  18. CppExtension,
  19. CUDAExtension,
  20. CUDA_HOME,
  21. )
  22. # ninja build does not work unless include_dirs are abs path
  23. this_dir = os.path.dirname(os.path.abspath(__file__))
  24. # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
  25. FORCE_CXX11_ABI = os.getenv("FORCE_CXX11_ABI", "FALSE") == "TRUE"
  26. def get_compute_capability():
  27. device = torch.device("cuda")
  28. capability = torch.cuda.get_device_capability(device)
  29. return int(str(capability[0]) + str(capability[1]))
  30. def get_cuda_bare_metal_version(cuda_dir):
  31. raw_output = subprocess.check_output(
  32. [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
  33. )
  34. output = raw_output.split()
  35. release_idx = output.index("release") + 1
  36. bare_metal_version = parse(output[release_idx].split(",")[0])
  37. return raw_output, bare_metal_version
  38. MODES = ["oflex"]
  39. # MODES = ["core", "ndstate", "oflex"]
  40. # MODES = ["core", "ndstate", "oflex", "nrow"]
  41. def get_ext():
  42. cc_flag = []
  43. print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
  44. print("\n\nCUDA_HOME = {}\n\n".format(CUDA_HOME))
  45. # Check if card has compute capability 8.0 or higher for BFloat16 operations
  46. if get_compute_capability() < 80:
  47. warnings.warn("This code uses BFloat16 date type, which is only supported on GPU architectures with compute capability 8.0 or higher")
  48. multi_threads = True
  49. if CUDA_HOME is not None:
  50. _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
  51. print("CUDA version: ", bare_metal_version, flush=True)
  52. if bare_metal_version < Version("11.6"):
  53. warnings.warn("CUDA version ealier than 11.6 may leads to performance mismatch.")
  54. if bare_metal_version < Version("11.2"):
  55. multi_threads = False
  56. cc_flag.append(f"-arch=sm_{get_compute_capability()}")
  57. if multi_threads:
  58. cc_flag.extend(["--threads", "4"])
  59. # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
  60. # torch._C._GLIBCXX_USE_CXX11_ABI
  61. # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
  62. if FORCE_CXX11_ABI:
  63. torch._C._GLIBCXX_USE_CXX11_ABI = True
  64. sources = dict(
  65. core=[
  66. "csrc/selective_scan/cus/selective_scan.cpp",
  67. "csrc/selective_scan/cus/selective_scan_core_fwd.cu",
  68. "csrc/selective_scan/cus/selective_scan_core_bwd.cu",
  69. ],
  70. nrow=[
  71. "csrc/selective_scan/cusnrow/selective_scan_nrow.cpp",
  72. "csrc/selective_scan/cusnrow/selective_scan_core_fwd.cu",
  73. "csrc/selective_scan/cusnrow/selective_scan_core_fwd2.cu",
  74. "csrc/selective_scan/cusnrow/selective_scan_core_fwd3.cu",
  75. "csrc/selective_scan/cusnrow/selective_scan_core_fwd4.cu",
  76. "csrc/selective_scan/cusnrow/selective_scan_core_bwd.cu",
  77. "csrc/selective_scan/cusnrow/selective_scan_core_bwd2.cu",
  78. "csrc/selective_scan/cusnrow/selective_scan_core_bwd3.cu",
  79. "csrc/selective_scan/cusnrow/selective_scan_core_bwd4.cu",
  80. ],
  81. ndstate=[
  82. "csrc/selective_scan/cusndstate/selective_scan_ndstate.cpp",
  83. "csrc/selective_scan/cusndstate/selective_scan_core_fwd.cu",
  84. "csrc/selective_scan/cusndstate/selective_scan_core_bwd.cu",
  85. ],
  86. oflex=[
  87. "csrc/selective_scan/cusoflex/selective_scan_oflex.cpp",
  88. "csrc/selective_scan/cusoflex/selective_scan_core_fwd.cu",
  89. "csrc/selective_scan/cusoflex/selective_scan_core_bwd.cu",
  90. ],
  91. )
  92. names = dict(
  93. core="selective_scan_cuda_core",
  94. nrow="selective_scan_cuda_nrow",
  95. ndstate="selective_scan_cuda_ndstate",
  96. oflex="selective_scan_cuda_oflex",
  97. )
  98. ext_modules = [
  99. CUDAExtension(
  100. name=names.get(MODE, None),
  101. sources=sources.get(MODE, None),
  102. extra_compile_args={
  103. "cxx": ["-O3", "-std=c++17"],
  104. "nvcc": [
  105. "-O3",
  106. "-std=c++17",
  107. "-U__CUDA_NO_HALF_OPERATORS__",
  108. "-U__CUDA_NO_HALF_CONVERSIONS__",
  109. "-U__CUDA_NO_BFLOAT16_OPERATORS__",
  110. "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
  111. "-U__CUDA_NO_BFLOAT162_OPERATORS__",
  112. "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
  113. "--expt-relaxed-constexpr",
  114. "--expt-extended-lambda",
  115. "--use_fast_math",
  116. "--ptxas-options=-v",
  117. "-lineinfo",
  118. ]
  119. + cc_flag
  120. },
  121. include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
  122. )
  123. for MODE in MODES
  124. ]
  125. return ext_modules
  126. ext_modules = get_ext()
  127. setup(
  128. name="selective_scan",
  129. version="0.0.2",
  130. packages=[],
  131. author="Tri Dao, Albert Gu, $@#Anonymous#@$ ",
  132. author_email="tri@tridao.me, agu@cs.cmu.edu, $@#Anonymous#EMAIL@$",
  133. description="selective scan",
  134. long_description="",
  135. long_description_content_type="text/markdown",
  136. url="https://github.com/state-spaces/mamba",
  137. classifiers=[
  138. "Programming Language :: Python :: 3",
  139. "License :: OSI Approved :: BSD License",
  140. "Operating System :: Unix",
  141. ],
  142. ext_modules=ext_modules,
  143. cmdclass={"bdist_wheel": _bdist_wheel, "build_ext": BuildExtension} if ext_modules else {"bdist_wheel": _bdist_wheel,},
  144. python_requires=">=3.7",
  145. install_requires=[
  146. "torch",
  147. "packaging",
  148. "ninja",
  149. "einops",
  150. ],
  151. )