csm_triton.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. import torch
  2. import warnings
  3. WITH_TRITON = True
  4. # WITH_TRITON = False
  5. try:
  6. import triton
  7. import triton.language as tl
  8. except:
  9. WITH_TRITON = False
  10. warnings.warn("Triton not installed, fall back to pytorch implements.")
  11. # to make sure cached_property can be loaded for triton
  12. if WITH_TRITON:
  13. try:
  14. from functools import cached_property
  15. except:
  16. warnings.warn("if you are using py37, add this line to functools.py: "
  17. "cached_property = lambda func: property(lru_cache()(func))")
  18. # torch implementation ========================================
  19. def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  20. if in_channel_first:
  21. B, C, H, W = x.shape
  22. if scans == 0:
  23. y = x.new_empty((B, 4, C, H * W))
  24. y[:, 0, :, :] = x.flatten(2, 3)
  25. y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)
  26. y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])
  27. elif scans == 1:
  28. y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
  29. elif scans == 2:
  30. y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
  31. y = torch.cat([y, y.flip(dims=[-1])], dim=1)
  32. else:
  33. B, H, W, C = x.shape
  34. if scans == 0:
  35. y = x.new_empty((B, H * W, 4, C))
  36. y[:, :, 0, :] = x.flatten(1, 2)
  37. y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
  38. y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
  39. elif scans == 1:
  40. y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
  41. elif scans == 2:
  42. y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
  43. y = torch.cat([y, y.flip(dims=[1])], dim=2)
  44. if in_channel_first and (not out_channel_first):
  45. y = y.permute(0, 3, 1, 2).contiguous()
  46. elif (not in_channel_first) and out_channel_first:
  47. y = y.permute(0, 2, 3, 1).contiguous()
  48. return y
  49. def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  50. if out_channel_first:
  51. B, K, D, H, W = y.shape
  52. y = y.view(B, K, D, -1)
  53. if scans == 0:
  54. y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  55. y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
  56. elif scans == 1:
  57. y = y.sum(1)
  58. elif scans == 2:
  59. y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  60. y = y.sum(1)
  61. else:
  62. B, H, W, K, D = y.shape
  63. y = y.view(B, -1, K, D)
  64. if scans == 0:
  65. y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
  66. y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
  67. elif scans == 1:
  68. y = y.sum(2)
  69. elif scans == 2:
  70. y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
  71. y = y.sum(2)
  72. if in_channel_first and (not out_channel_first):
  73. y = y.permute(0, 2, 1).contiguous()
  74. elif (not in_channel_first) and out_channel_first:
  75. y = y.permute(0, 2, 1).contiguous()
  76. return y
  77. def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  78. if in_channel_first:
  79. B, _, C, H, W = x.shape
  80. if scans == 0:
  81. y = torch.stack([
  82. x[:, 0].flatten(2, 3),
  83. x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
  84. torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
  85. torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
  86. ], dim=1)
  87. elif scans == 1:
  88. y = x.flatten(2, 3)
  89. elif scans == 2:
  90. y = torch.stack([
  91. x[:, 0].flatten(2, 3),
  92. x[:, 1].flatten(2, 3),
  93. torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
  94. torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
  95. ], dim=1)
  96. else:
  97. B, H, W, _, C = x.shape
  98. if scans == 0:
  99. y = torch.stack([
  100. x[:, :, :, 0].flatten(1, 2),
  101. x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
  102. torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
  103. torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
  104. ], dim=2)
  105. elif scans == 1:
  106. y = x.flatten(1, 2)
  107. elif scans == 2:
  108. y = torch.stack([
  109. x[:, 0].flatten(1, 2),
  110. x[:, 1].flatten(1, 2),
  111. torch.flip(x[:, 2].flatten(1, 2), dims=[-1]),
  112. torch.flip(x[:, 3].flatten(1, 2), dims=[-1]),
  113. ], dim=2)
  114. if in_channel_first and (not out_channel_first):
  115. y = y.permute(0, 3, 1, 2).contiguous()
  116. elif (not in_channel_first) and out_channel_first:
  117. y = y.permute(0, 2, 3, 1).contiguous()
  118. return y
  119. def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
  120. if out_channel_first:
  121. B, K, D, H, W = y.shape
  122. y = y.view(B, K, D, -1)
  123. if scans == 0:
  124. y = torch.stack([
  125. y[:, 0],
  126. y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
  127. torch.flip(y[:, 2], dims=[-1]),
  128. torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
  129. ], dim=1)
  130. elif scans == 1:
  131. y = y
  132. elif scans == 2:
  133. y = torch.stack([
  134. y[:, 0],
  135. y[:, 1],
  136. torch.flip(y[:, 2], dims=[-1]),
  137. torch.flip(y[:, 3], dims=[-1]),
  138. ], dim=1)
  139. else:
  140. B, H, W, K, D = y.shape
  141. y = y.view(B, -1, K, D)
  142. if scans == 0:
  143. y = torch.stack([
  144. y[:, :, 0],
  145. y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
  146. torch.flip(y[:, :, 2], dims=[1]),
  147. torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
  148. ], dim=2)
  149. elif scans == 1:
  150. y = y
  151. elif scans == 2:
  152. y = torch.stack([
  153. y[:, :, 0],
  154. y[:, :, 1],
  155. torch.flip(y[:, :, 2], dims=[1]),
  156. torch.flip(y[:, :, 3], dims=[1]),
  157. ], dim=2)
  158. if out_channel_first and (not in_channel_first):
  159. y = y.permute(0, 3, 1, 2).contiguous()
  160. elif (not out_channel_first) and in_channel_first:
  161. y = y.permute(0, 2, 3, 1).contiguous()
  162. return y
  163. class CrossScanF(torch.autograd.Function):
  164. @staticmethod
  165. def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  166. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  167. # y: (B, 4, C, H * W) | (B, H * W, 4, C)
  168. ctx.in_channel_first = in_channel_first
  169. ctx.out_channel_first = out_channel_first
  170. ctx.one_by_one = one_by_one
  171. ctx.scans = scans
  172. if one_by_one:
  173. B, K, C, H, W = x.shape
  174. if not in_channel_first:
  175. B, H, W, K, C = x.shape
  176. else:
  177. B, C, H, W = x.shape
  178. if not in_channel_first:
  179. B, H, W, C = x.shape
  180. ctx.shape = (B, C, H, W)
  181. _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
  182. y = _fn(x, in_channel_first, out_channel_first, scans)
  183. return y
  184. @staticmethod
  185. def backward(ctx, ys: torch.Tensor):
  186. # out: (b, k, d, l)
  187. in_channel_first = ctx.in_channel_first
  188. out_channel_first = ctx.out_channel_first
  189. one_by_one = ctx.one_by_one
  190. scans = ctx.scans
  191. B, C, H, W = ctx.shape
  192. ys = ys.view(B, -1, C, H, W) if out_channel_first else ys.view(B, H, W, -1, C)
  193. _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
  194. y = _fn(ys, in_channel_first, out_channel_first, scans)
  195. if one_by_one:
  196. y = y.view(B, 4, -1, H, W) if in_channel_first else y.view(B, H, W, 4, -1)
  197. else:
  198. y = y.view(B, -1, H, W) if in_channel_first else y.view(B, H, W, -1)
  199. return y, None, None, None, None
  200. class CrossMergeF(torch.autograd.Function):
  201. @staticmethod
  202. def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  203. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  204. # y: (B, 4, C, H * W) | (B, H * W, 4, C)
  205. ctx.in_channel_first = in_channel_first
  206. ctx.out_channel_first = out_channel_first
  207. ctx.one_by_one = one_by_one
  208. ctx.scans = scans
  209. B, K, C, H, W = ys.shape
  210. if not out_channel_first:
  211. B, H, W, K, C = ys.shape
  212. ctx.shape = (B, C, H, W)
  213. _fn = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
  214. y = _fn(ys, in_channel_first, out_channel_first, scans)
  215. return y
  216. @staticmethod
  217. def backward(ctx, x: torch.Tensor):
  218. # B, D, L = x.shape
  219. # out: (b, k, d, h, w)
  220. in_channel_first = ctx.in_channel_first
  221. out_channel_first = ctx.out_channel_first
  222. one_by_one = ctx.one_by_one
  223. scans = ctx.scans
  224. B, C, H, W = ctx.shape
  225. if not one_by_one:
  226. if in_channel_first:
  227. x = x.view(B, C, H, W)
  228. else:
  229. x = x.view(B, H, W, C)
  230. else:
  231. if in_channel_first:
  232. x = x.view(B, 4, C, H, W)
  233. else:
  234. x = x.view(B, H, W, 4, C)
  235. _fn = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
  236. x = _fn(x, in_channel_first, out_channel_first, scans)
  237. x = x.view(B, 4, C, H, W) if out_channel_first else x.view(B, H, W, 4, C)
  238. return x, None, None, None, None
  239. # triton implements ========================================
  240. @triton.jit
  241. def triton_cross_scan_flex(
  242. x: tl.tensor, # (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  243. y: tl.tensor, # (B, 4, C, H, W) | (B, H, W, 4, C)
  244. x_layout: tl.constexpr,
  245. y_layout: tl.constexpr,
  246. operation: tl.constexpr,
  247. onebyone: tl.constexpr,
  248. scans: tl.constexpr,
  249. BC: tl.constexpr,
  250. BH: tl.constexpr,
  251. BW: tl.constexpr,
  252. DC: tl.constexpr,
  253. DH: tl.constexpr,
  254. DW: tl.constexpr,
  255. NH: tl.constexpr,
  256. NW: tl.constexpr,
  257. ):
  258. # x_layout = 0
  259. # y_layout = 1 # 0 BCHW, 1 BHWC
  260. # operation = 0 # 0 scan, 1 merge
  261. # onebyone = 0 # 0 false, 1 true
  262. # scans = 0 # 0 cross scan, 1 unidirectional, 2 bidirectional
  263. i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
  264. i_h, i_w = (i_hw // NW), (i_hw % NW)
  265. _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
  266. _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
  267. _mask_hw = _mask_h[:, None] & _mask_w[None, :]
  268. _for_C = min(DC - i_c * BC, BC)
  269. pos_h = (i_h * BH + tl.arange(0, BH)[:, None])
  270. pos_w = (i_w * BW + tl.arange(0, BW)[None, :])
  271. neg_h = (DH - i_h * BH - 1 - tl.arange(0, BH)[:, None])
  272. neg_w = (DW - i_w * BW - 1 - tl.arange(0, BW)[None, :])
  273. if scans == 0:
  274. # none; trans; flip; trans + flip;
  275. HWRoute0 = pos_h * DW + pos_w
  276. HWRoute1 = pos_w * DH + pos_h # trans
  277. HWRoute2 = neg_h * DW + neg_w # flip
  278. HWRoute3 = neg_w * DH + neg_h # trans + flip
  279. elif scans == 1:
  280. # none; none; none; none;
  281. HWRoute0 = pos_h * DW + pos_w
  282. HWRoute1 = HWRoute0
  283. HWRoute2 = HWRoute0
  284. HWRoute3 = HWRoute0
  285. elif scans == 2:
  286. # none; none; flip; flip;
  287. HWRoute0 = pos_h * DW + pos_w
  288. HWRoute1 = HWRoute0
  289. HWRoute2 = neg_h * DW + neg_w # flip
  290. HWRoute3 = HWRoute2
  291. _tmp1 = DC * DH * DW
  292. y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
  293. if y_layout == 0:
  294. p_y1 = y_ptr_base + HWRoute0
  295. p_y2 = y_ptr_base + _tmp1 + HWRoute1
  296. p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
  297. p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
  298. else:
  299. p_y1 = y_ptr_base + HWRoute0 * 4 * DC
  300. p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
  301. p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
  302. p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC
  303. if onebyone == 0:
  304. x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
  305. if x_layout == 0:
  306. p_x = x_ptr_base + HWRoute0
  307. else:
  308. p_x = x_ptr_base + HWRoute0 * DC
  309. if operation == 0:
  310. for idxc in range(_for_C):
  311. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  312. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  313. _x = tl.load(p_x + _idx_x, mask=_mask_hw)
  314. tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
  315. tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
  316. tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
  317. tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
  318. elif operation == 1:
  319. for idxc in range(_for_C):
  320. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  321. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  322. _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
  323. _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
  324. _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
  325. _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
  326. tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)
  327. else:
  328. x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
  329. if x_layout == 0:
  330. p_x1 = x_ptr_base + HWRoute0
  331. p_x2 = p_x1 + _tmp1
  332. p_x3 = p_x2 + _tmp1
  333. p_x4 = p_x3 + _tmp1
  334. else:
  335. p_x1 = x_ptr_base + HWRoute0 * 4 * DC
  336. p_x2 = p_x1 + DC
  337. p_x3 = p_x2 + DC
  338. p_x4 = p_x3 + DC
  339. if operation == 0:
  340. for idxc in range(_for_C):
  341. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  342. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  343. tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  344. tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  345. tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  346. tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
  347. else:
  348. for idxc in range(_for_C):
  349. _idx_x = idxc * DH * DW if x_layout == 0 else idxc
  350. _idx_y = idxc * DH * DW if y_layout == 0 else idxc
  351. tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
  352. tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
  353. tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
  354. tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
  355. class CrossScanTritonF(torch.autograd.Function):
  356. @staticmethod
  357. def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  358. if one_by_one:
  359. if in_channel_first:
  360. B, _, C, H, W = x.shape
  361. else:
  362. B, H, W, _, C = x.shape
  363. else:
  364. if in_channel_first:
  365. B, C, H, W = x.shape
  366. else:
  367. B, H, W, C = x.shape
  368. B, C, H, W = int(B), int(C), int(H), int(W)
  369. BC, BH, BW = 1, 32, 32
  370. NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
  371. ctx.in_channel_first = in_channel_first
  372. ctx.out_channel_first = out_channel_first
  373. ctx.one_by_one = one_by_one
  374. ctx.scans = scans
  375. ctx.shape = (B, C, H, W)
  376. ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
  377. y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
  378. triton_cross_scan_flex[(NH * NW, NC, B)](
  379. x.contiguous(), y,
  380. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
  381. BC, BH, BW, C, H, W, NH, NW
  382. )
  383. return y
  384. @staticmethod
  385. def backward(ctx, y: torch.Tensor):
  386. in_channel_first = ctx.in_channel_first
  387. out_channel_first = ctx.out_channel_first
  388. one_by_one = ctx.one_by_one
  389. scans = ctx.scans
  390. B, C, H, W = ctx.shape
  391. BC, BH, BW, NC, NH, NW = ctx.triton_shape
  392. if one_by_one:
  393. x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
  394. else:
  395. x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))
  396. triton_cross_scan_flex[(NH * NW, NC, B)](
  397. x, y.contiguous(),
  398. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
  399. BC, BH, BW, C, H, W, NH, NW
  400. )
  401. return x, None, None, None, None
  402. class CrossMergeTritonF(torch.autograd.Function):
  403. @staticmethod
  404. def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
  405. if out_channel_first:
  406. B, _, C, H, W = y.shape
  407. else:
  408. B, H, W, _, C = y.shape
  409. B, C, H, W = int(B), int(C), int(H), int(W)
  410. BC, BH, BW = 1, 32, 32
  411. NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
  412. ctx.in_channel_first = in_channel_first
  413. ctx.out_channel_first = out_channel_first
  414. ctx.one_by_one = one_by_one
  415. ctx.scans = scans
  416. ctx.shape = (B, C, H, W)
  417. ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
  418. if one_by_one:
  419. x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
  420. else:
  421. x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
  422. triton_cross_scan_flex[(NH * NW, NC, B)](
  423. x, y.contiguous(),
  424. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
  425. BC, BH, BW, C, H, W, NH, NW
  426. )
  427. return x
  428. @staticmethod
  429. def backward(ctx, x: torch.Tensor):
  430. in_channel_first = ctx.in_channel_first
  431. out_channel_first = ctx.out_channel_first
  432. one_by_one = ctx.one_by_one
  433. scans = ctx.scans
  434. B, C, H, W = ctx.shape
  435. BC, BH, BW, NC, NH, NW = ctx.triton_shape
  436. y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
  437. triton_cross_scan_flex[(NH * NW, NC, B)](
  438. x.contiguous(), y,
  439. (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
  440. BC, BH, BW, C, H, W, NH, NW
  441. )
  442. return y, None, None, None, None, None
  443. # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  444. def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
  445. # x: (B, C, H, W) | (B, H, W, C) | (B, 4, C, H, W) | (B, H, W, 4, C)
  446. # y: (B, 4, C, L) | (B, L, 4, C)
  447. # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
  448. CSF = CrossScanTritonF if WITH_TRITON and x.is_cuda and (not force_torch) else CrossScanF
  449. with torch.cuda.device(x.device):
  450. return CSF.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
  451. # @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)
  452. def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
  453. # y: (B, 4, C, L) | (B, L, 4, C)
  454. # x: (B, C, H * W) | (B, H * W, C) | (B, 4, C, H * W) | (B, H * W, 4, C)
  455. # scans: 0: cross scan; 1 unidirectional; 2: bidirectional;
  456. CMF = CrossMergeTritonF if WITH_TRITON and y.is_cuda and (not force_torch) else CrossMergeF
  457. with torch.cuda.device(y.device):
  458. return CMF.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
  459. # checks =================================================================
  460. class CHECK:
  461. def check_csm_triton():
  462. B, C, H, W = 256, 192, 56, 57
  463. dtype=torch.float16
  464. dtype=torch.float32
  465. x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
  466. y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
  467. x1 = x.clone().detach().requires_grad_(True)
  468. y1 = y.clone().detach().requires_grad_(True)
  469. def cross_scan(x: torch.Tensor):
  470. B, C, H, W = x.shape
  471. L = H * W
  472. xs = torch.stack([
  473. x.view(B, C, L),
  474. torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
  475. torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
  476. torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
  477. ], dim=1).view(B, 4, C, L)
  478. return xs
  479. def cross_merge(out_y: torch.Tensor):
  480. B, K, D, H, W = out_y.shape
  481. L = H * W
  482. out_y = out_y.view(B, K, D, L)
  483. inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  484. wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  485. invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  486. y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
  487. return y
  488. def cross_scan_1b1(x: torch.Tensor):
  489. B, K, C, H, W = x.shape
  490. L = H * W
  491. xs = torch.stack([
  492. x[:, 0].view(B, C, L),
  493. torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
  494. torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
  495. torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
  496. ], dim=1).view(B, 4, C, L)
  497. return xs
  498. def unidi_scan(x):
  499. B, C, H, W = x.shape
  500. x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
  501. return x
  502. def unidi_merge(ys):
  503. B, K, C, H, W = ys.shape
  504. return ys.view(B, 4, -1, H * W).sum(1)
  505. def bidi_scan(x):
  506. B, C, H, W = x.shape
  507. x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
  508. x = torch.cat([x, x.flip(dims=[-1])], dim=1)
  509. return x
  510. def bidi_merge(ys):
  511. B, K, D, H, W = ys.shape
  512. ys = ys.view(B, K, D, -1)
  513. ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
  514. return ys.contiguous().sum(1)
  515. if True:
  516. res0 = triton.testing.do_bench(lambda :cross_scan(x))
  517. res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
  518. # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x))
  519. res3 = triton.testing.do_bench(lambda :cross_merge(y))
  520. res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
  521. # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y))
  522. # print(res0, res1, res2, res3, res4, res5)
  523. print(res0, res1, res3, res4)
  524. res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
  525. res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
  526. # res2 = triton.testing.do_bench(lambda :CrossScanTriton.apply(x).sum().backward())
  527. res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
  528. res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
  529. # res5 = triton.testing.do_bench(lambda :CrossMergeTriton.apply(y).sum().backward())
  530. # print(res0, res1, res2, res3, res4, res5)
  531. print(res0, res1, res3, res4)
  532. print("test cross scan")
  533. for (cs0, cm0, cs1, cm1) in [
  534. # channel_first -> channel_first
  535. (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
  536. (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
  537. (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
  538. # flex: BLC->BCL; BCL->BLC; BLC->BLC;
  539. (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
  540. (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
  541. (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
  542. # previous
  543. # (cross_scan, cross_merge, lambda x: CrossScanTriton.apply(x), lambda x: CrossMergeTriton.apply(x)),
  544. # (unidi_scan, unidi_merge, lambda x: getCSM(1)[0].apply(x), lambda x: getCSM(1)[1].apply(x)),
  545. # (bidi_scan, bidi_merge, lambda x: getCSM(2)[0].apply(x), lambda x: getCSM(2)[1].apply(x)),
  546. ]:
  547. x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  548. o0 = cs0(x)
  549. o1 = cs1(x1)
  550. o0.backward(y.view(B, 4, C, H * W))
  551. o1.backward(y.view(B, 4, C, H * W))
  552. print((o0 - o1).abs().max())
  553. print((x.grad - x1.grad).abs().max())
  554. o0 = cm0(y)
  555. o1 = cm1(y1)
  556. o0.backward(x.view(B, C, H * W))
  557. o1.backward(x.view(B, C, H * W))
  558. print((o0 - o1).abs().max())
  559. print((y.grad - y1.grad).abs().max())
  560. x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  561. print("===============", flush=True)
  562. print("test cross scan one by one")
  563. for (cs0, cs1) in [
  564. (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
  565. # (cross_scan_1b1, lambda x: CrossScanTriton1b1.apply(x)),
  566. ]:
  567. o0 = cs0(y)
  568. o1 = cs1(y1)
  569. o0.backward(y.view(B, 4, C, H * W))
  570. o1.backward(y.view(B, 4, C, H * W))
  571. print((o0 - o1).abs().max())
  572. print((y.grad - y1.grad).abs().max())
  573. x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
  574. print("===============", flush=True)
  575. def check_csm_scan3():
  576. if False:
  577. x = torch.arange(0, 16).view(1, 1, 4, 4).cuda()
  578. out1 = cross_scan_fn(x, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  579. out2 = cross_merge_fn(out1, scans=3, force_torch=True).view(1, 1, 4, 4)
  580. out4 = cross_merge_fn(out1, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  581. out3 = cross_scan_fn(out4, one_by_one=True, scans=3, force_torch=True).view(1, 4, 1, 4, 4)
  582. out5 = cross_scan_fn(x.view(1, 4, 4, 1), in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  583. out6 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, scans=3, force_torch=True).view(1, 4, 4, 1)
  584. out8 = cross_merge_fn(out5, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  585. out7 = cross_scan_fn(out8, in_channel_first=False, out_channel_first=False, one_by_one=True, scans=3, force_torch=True).view(1, 4, 4, 4, 1)
  586. print(out1.view(4, -1))
  587. print(out2.view(-1))
  588. print(out3.view(4, -1))
  589. print(out4.view(4, -1))
  590. print(out5.view(-1, 4).t())
  591. print(out6.view(-1))
  592. print(out7.view(-1, 4).t())
  593. print(out8.view(-1, 4).t())
  594. B, C, H, W = 27, 253, 57, 58
  595. x = torch.randn((B, C, H, W)).cuda()
  596. for scans in [0, 1, 2, 3]:
  597. o1 = cross_scan_fn(x, scans=scans, force_torch=True).view(B, 4, C, H, W)
  598. print((cross_scan_fn(x, scans=scans) == cross_scan_fn(x, scans=scans, force_torch=True)).all())
  599. print((cross_merge_fn(o1, scans=scans) == cross_merge_fn(o1, scans=scans, force_torch=True)).all())
  600. kwargs = dict(in_channel_first=False, out_channel_first=False)
  601. x2 = x.permute(0, 2, 3, 1).contiguous()
  602. o2 = o1.permute(0, 3, 4, 1, 2).contiguous()
  603. print((cross_scan_fn(x, scans=scans, **kwargs) == cross_scan_fn(x, scans=scans, force_torch=True, **kwargs)).all())
  604. print((cross_merge_fn(o2, scans=scans, **kwargs) == cross_merge_fn(o2, scans=scans, force_torch=True, **kwargs)).all())
  605. breakpoint()
# Run both self-checks when this file is executed as a script.
if __name__ == "__main__":
    CHECK.check_csm_scan3()
    CHECK.check_csm_triton()