# layernorm_gated.py (17 KB)
  1. # Copyright (c) 2024, Tri Dao.
  2. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
  3. # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
  4. # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
  5. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
  6. import math
  7. import torch
  8. import torch.nn.functional as F
  9. import triton
  10. import triton.language as tl
  11. from einops import rearrange
  12. def rms_norm_ref(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, upcast=True):
  13. dtype = x.dtype
  14. N = x.shape[-1]
  15. weight = weight.float()
  16. bias = bias.float() if bias is not None else None
  17. if upcast:
  18. x = x.float()
  19. z = z.float() if z is not None else z
  20. if z is not None and not norm_before_gate:
  21. x = x * F.silu(z)
  22. if group_size is None:
  23. rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
  24. out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
  25. else:
  26. x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
  27. rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
  28. out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
  29. if bias is not None:
  30. out = out + bias
  31. if z is not None and norm_before_gate:
  32. out *= F.silu(z)
  33. return out.to(dtype)
  34. @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
  35. @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
  36. @triton.jit
  37. def _layer_norm_fwd_1pass_kernel(
  38. X, # pointer to the input
  39. Y, # pointer to the output
  40. W, # pointer to the weights
  41. B, # pointer to the biases
  42. Z, # pointer to the other branch
  43. Mean, # pointer to the mean
  44. Rstd, # pointer to the 1/std
  45. stride_x_row, # how much to increase the pointer when moving by 1 row
  46. stride_y_row,
  47. stride_z_row,
  48. M, # number of rows in X
  49. N, # number of columns in X
  50. eps, # epsilon to avoid division by zero
  51. BLOCK_N: tl.constexpr,
  52. HAS_BIAS: tl.constexpr,
  53. HAS_Z: tl.constexpr,
  54. NORM_BEFORE_GATE: tl.constexpr,
  55. IS_RMS_NORM: tl.constexpr,
  56. ):
  57. # Map the program id to the row of X and Y it should compute.
  58. row = tl.program_id(0)
  59. group = tl.program_id(1)
  60. X += row * stride_x_row + group * N
  61. Y += row * stride_y_row + group * N
  62. if HAS_Z:
  63. Z += row * stride_z_row + group * N
  64. if not IS_RMS_NORM:
  65. Mean += group * M
  66. Rstd += group * M
  67. W += group * N
  68. if HAS_BIAS:
  69. B += group * N
  70. # Compute mean and variance
  71. cols = tl.arange(0, BLOCK_N)
  72. x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
  73. if HAS_Z and not NORM_BEFORE_GATE:
  74. z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
  75. x *= z * tl.sigmoid(z)
  76. if not IS_RMS_NORM:
  77. mean = tl.sum(x, axis=0) / N
  78. tl.store(Mean + row, mean)
  79. xbar = tl.where(cols < N, x - mean, 0.)
  80. var = tl.sum(xbar * xbar, axis=0) / N
  81. else:
  82. xbar = tl.where(cols < N, x, 0.)
  83. var = tl.sum(xbar * xbar, axis=0) / N
  84. rstd = 1 / tl.sqrt(var + eps)
  85. tl.store(Rstd + row, rstd)
  86. # Normalize and apply linear transformation
  87. mask = cols < N
  88. w = tl.load(W + cols, mask=mask).to(tl.float32)
  89. if HAS_BIAS:
  90. b = tl.load(B + cols, mask=mask).to(tl.float32)
  91. x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
  92. y = x_hat * w + b if HAS_BIAS else x_hat * w
  93. if HAS_Z and NORM_BEFORE_GATE:
  94. z = tl.load(Z + cols, mask=mask).to(tl.float32)
  95. y *= z * tl.sigmoid(z)
  96. # Write output
  97. tl.store(Y + cols, y, mask=mask)
  98. def _layer_norm_fwd(x, weight, bias, eps, z=None, out=None, group_size=None, norm_before_gate=True, is_rms_norm=False):
  99. M, N = x.shape
  100. if group_size is None:
  101. group_size = N
  102. assert N % group_size == 0
  103. ngroups = N // group_size
  104. assert x.stride(-1) == 1
  105. if z is not None:
  106. assert z.stride(-1) == 1
  107. assert z.shape == (M, N)
  108. assert weight.shape == (N,)
  109. assert weight.stride(-1) == 1
  110. if bias is not None:
  111. assert bias.stride(-1) == 1
  112. assert bias.shape == (N,)
  113. # allocate output
  114. if out is not None:
  115. assert out.shape == x.shape
  116. else:
  117. out = torch.empty_like(x)
  118. assert out.stride(-1) == 1
  119. mean = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device) if not is_rms_norm else None
  120. rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
  121. # Less than 64KB per feature: enqueue fused kernel
  122. MAX_FUSED_SIZE = 65536 // x.element_size()
  123. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
  124. if group_size > BLOCK_N:
  125. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  126. # heuristics for number of warps
  127. num_warps = min(max(BLOCK_N // 256, 1), 8)
  128. grid = (M, ngroups)
  129. with torch.cuda.device(x.device.index):
  130. _layer_norm_fwd_1pass_kernel[grid](x, out, weight, bias, z, mean, rstd,
  131. x.stride(0), out.stride(0), z.stride(0) if z is not None else 0,
  132. M, group_size, eps,
  133. BLOCK_N=BLOCK_N,
  134. NORM_BEFORE_GATE=norm_before_gate,
  135. IS_RMS_NORM=is_rms_norm,
  136. num_warps=num_warps)
  137. return out, mean, rstd
  138. @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
  139. @triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
  140. @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
  141. @triton.jit
  142. def _layer_norm_bwd_kernel(
  143. X, # pointer to the input
  144. W, # pointer to the weights
  145. B, # pointer to the biases
  146. Z, # pointer to the other branch
  147. Y, # pointer to the output to be recomputed
  148. DY, # pointer to the output gradient
  149. DX, # pointer to the input gradient
  150. DW, # pointer to the partial sum of weights gradient
  151. DB, # pointer to the partial sum of biases gradient
  152. DZ, # pointer to the other branch
  153. Mean, # pointer to the mean
  154. Rstd, # pointer to the 1/std
  155. stride_x_row, # how much to increase the pointer when moving by 1 row
  156. stride_z_row,
  157. stride_y_row,
  158. stride_dy_row,
  159. stride_dx_row,
  160. stride_dz_row,
  161. stride_dw_row,
  162. stride_db_row,
  163. M, # number of rows in X
  164. N, # number of columns in X
  165. eps, # epsilon to avoid division by zero
  166. rows_per_program,
  167. NORM_BEFORE_GATE: tl.constexpr,
  168. IS_RMS_NORM: tl.constexpr,
  169. HAS_BIAS: tl.constexpr,
  170. HAS_Z: tl.constexpr,
  171. RECOMPUTE_OUTPUT: tl.constexpr,
  172. BLOCK_N: tl.constexpr,
  173. ):
  174. # Map the program id to the elements of X, DX, and DY it should compute.
  175. row_block_id = tl.program_id(0)
  176. group = tl.program_id(1)
  177. row_start = row_block_id * rows_per_program
  178. cols = tl.arange(0, BLOCK_N)
  179. mask = cols < N
  180. X += row_start * stride_x_row + group * N
  181. if HAS_Z:
  182. Z += row_start * stride_z_row + group * N
  183. DZ += row_start * stride_dz_row + group * N
  184. DY += row_start * stride_dy_row + group * N
  185. DX += row_start * stride_dx_row + group * N
  186. if RECOMPUTE_OUTPUT:
  187. Y += row_start * stride_y_row + group * N
  188. if not IS_RMS_NORM:
  189. Mean += group * M
  190. Rstd += group * M
  191. W += group * N
  192. w = tl.load(W + cols, mask=mask).to(tl.float32)
  193. if (RECOMPUTE_OUTPUT or HAS_Z) and HAS_BIAS:
  194. B += group * N
  195. b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
  196. dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
  197. if HAS_BIAS:
  198. db = tl.zeros((BLOCK_N,), dtype=tl.float32)
  199. row_end = min((row_block_id + 1) * rows_per_program, M)
  200. for row in range(row_start, row_end):
  201. # Load data to SRAM
  202. x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
  203. dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
  204. if not IS_RMS_NORM:
  205. mean = tl.load(Mean + row)
  206. if HAS_Z and not NORM_BEFORE_GATE:
  207. z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
  208. x_og = x
  209. x = x_og * z * tl.sigmoid(z)
  210. rstd = tl.load(Rstd + row)
  211. # Compute dx
  212. xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
  213. xhat = tl.where(mask, xhat, 0.)
  214. if HAS_Z and NORM_BEFORE_GATE:
  215. z = tl.load(Z + cols, mask=mask, other=0.).to(tl.float32)
  216. z_sigmoid = tl.sigmoid(z)
  217. y = xhat * w + b if HAS_BIAS else xhat * w
  218. if RECOMPUTE_OUTPUT:
  219. tl.store(Y + cols, y * z * z_sigmoid, mask=mask)
  220. dz = dy * y * z_sigmoid * (1 + z * (1 - z_sigmoid))
  221. tl.store(DZ + cols, dz, mask=mask)
  222. dy *= z * z_sigmoid
  223. else:
  224. if RECOMPUTE_OUTPUT:
  225. y = xhat * w + b if HAS_BIAS else xhat * w
  226. tl.store(Y + cols, y, mask=mask)
  227. wdy = w * dy
  228. c1 = tl.sum(xhat * wdy, axis=0) / N
  229. if not IS_RMS_NORM:
  230. c2 = tl.sum(wdy, axis=0) / N
  231. dx = (wdy - (xhat * c1 + c2)) * rstd
  232. else:
  233. dx = (wdy - xhat * c1) * rstd
  234. dw += dy * xhat
  235. if HAS_BIAS:
  236. db += dy
  237. if HAS_Z and not NORM_BEFORE_GATE:
  238. z_sigmoid = tl.sigmoid(z)
  239. dz = dx * x_og * z_sigmoid * (1 + z * (1 - z_sigmoid))
  240. tl.store(DZ + cols, dz, mask=mask)
  241. dx *= z * z_sigmoid
  242. # Write dx
  243. tl.store(DX + cols, dx, mask=mask)
  244. X += stride_x_row
  245. if HAS_Z:
  246. Z += stride_z_row
  247. DZ += stride_dz_row
  248. if RECOMPUTE_OUTPUT:
  249. Y += stride_y_row
  250. DY += stride_dy_row
  251. DX += stride_dx_row
  252. tl.store(DW + row_block_id * stride_dw_row + group * N + cols, dw, mask=mask)
  253. if HAS_BIAS:
  254. tl.store(DB + row_block_id * stride_db_row + group * N + cols, db, mask=mask)
  255. def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, z=None, group_size=None,
  256. norm_before_gate=True, is_rms_norm=False, recompute_output=False, dz=None, out=None):
  257. M, N = x.shape
  258. if group_size is None:
  259. group_size = N
  260. assert N % group_size == 0
  261. ngroups = N // group_size
  262. assert x.stride(-1) == 1
  263. assert dy.stride(-1) == 1
  264. assert dy.shape == (M, N)
  265. if z is not None:
  266. assert z.stride(-1) == 1
  267. assert z.shape == (M, N)
  268. assert weight.shape == (N,)
  269. assert weight.stride(-1) == 1
  270. if bias is not None:
  271. assert bias.stride(-1) == 1
  272. assert bias.shape == (N,)
  273. # allocate output
  274. dx = torch.empty_like(x)
  275. if dz is not None:
  276. assert z is not None
  277. assert dz.shape == z.shape
  278. assert dz.stride(-1) == 1
  279. else:
  280. dz = torch.empty_like(z) if z is not None else None
  281. if recompute_output:
  282. if out is None:
  283. out = torch.empty_like(x)
  284. assert out.shape == x.shape
  285. # Less than 64KB per feature: enqueue fused kernel
  286. MAX_FUSED_SIZE = 65536 // x.element_size()
  287. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
  288. if group_size > BLOCK_N:
  289. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  290. # heuristics for number of warps
  291. num_warps = min(max(BLOCK_N // 256, 1), 8)
  292. sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
  293. # If group size is small (e.g., 64), we're only using 1 warp. So having just 108 programs
  294. # would limit the occupancy.
  295. nrow_groups = math.ceil(sm_count * math.ceil(4 / num_warps) / ngroups)
  296. _dw = torch.empty((nrow_groups, N), dtype=torch.float32, device=weight.device)
  297. _db = torch.empty((nrow_groups, N), dtype=torch.float32, device=bias.device) if bias is not None else None
  298. rows_per_program = math.ceil(M / nrow_groups)
  299. grid = (nrow_groups, ngroups)
  300. with torch.cuda.device(x.device.index):
  301. _layer_norm_bwd_kernel[grid](x, weight, bias, z, out if recompute_output else None,
  302. dy, dx, _dw, _db, dz, mean, rstd,
  303. x.stride(0),
  304. z.stride(0) if z is not None else 0,
  305. 0 if not recompute_output else out.stride(0),
  306. dy.stride(0), dx.stride(0),
  307. dz.stride(0) if dz is not None else 0,
  308. _dw.stride(0),
  309. _db.stride(0) if _db is not None else 0,
  310. M, group_size, eps,
  311. rows_per_program,
  312. BLOCK_N=BLOCK_N,
  313. NORM_BEFORE_GATE=norm_before_gate,
  314. IS_RMS_NORM=is_rms_norm,
  315. num_warps=num_warps)
  316. dw = _dw.sum(0).to(weight.dtype)
  317. db = _db.sum(0).to(bias.dtype) if bias is not None else None
  318. return (dx, dw, db, dz) if not recompute_output else (dx, dw, db, dz, out)
  319. class LayerNormFn(torch.autograd.Function):
  320. @staticmethod
  321. def forward(ctx, x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True,
  322. is_rms_norm=False):
  323. """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
  324. """
  325. x_shape_og = x.shape
  326. # reshape input data into 2D tensor
  327. x = x.reshape(-1, x.shape[-1])
  328. if x.stride(-1) != 1:
  329. x = x.contiguous()
  330. if z is not None:
  331. assert z.shape == x_shape_og
  332. z = z.reshape(-1, z.shape[-1])
  333. if z.stride(-1) != 1:
  334. z = z.contiguous()
  335. weight = weight.contiguous()
  336. if bias is not None:
  337. bias = bias.contiguous()
  338. y, mean, rstd = _layer_norm_fwd(x, weight, bias, eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm)
  339. ctx.save_for_backward(x, weight, bias, mean, rstd, z)
  340. ctx.x_shape_og = x_shape_og
  341. ctx.eps = eps
  342. ctx.group_size = group_size
  343. ctx.norm_before_gate = norm_before_gate
  344. ctx.is_rms_norm = is_rms_norm
  345. return y.reshape(x_shape_og)
  346. @staticmethod
  347. def backward(ctx, dy):
  348. x, weight, bias, mean, rstd, z = ctx.saved_tensors
  349. dy = dy.reshape(-1, dy.shape[-1])
  350. if dy.stride(-1) != 1:
  351. dy = dy.contiguous()
  352. assert dy.shape == x.shape
  353. dx, dw, db, dz = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, z, ctx.group_size,
  354. ctx.norm_before_gate, ctx.is_rms_norm)
  355. return dx.reshape(ctx.x_shape_og), dw, db, dz.reshape(ctx.x_shape_og) if dz is not None else None, None, None, None, None
  356. def layernorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True, is_rms_norm=False):
  357. return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm)
  358. def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
  359. return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)
  360. class LayerNorm(torch.nn.Module):
  361. def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
  362. """If group_size is not None, we do GroupNorm with each group having group_size elements.
  363. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
  364. """
  365. factory_kwargs = {"device": device, "dtype": dtype}
  366. super().__init__()
  367. self.eps = eps
  368. self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
  369. self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
  370. self.group_size = group_size
  371. self.norm_before_gate = norm_before_gate
  372. self.reset_parameters()
  373. def reset_parameters(self):
  374. torch.nn.init.ones_(self.weight)
  375. torch.nn.init.zeros_(self.bias)
  376. def forward(self, x, z=None):
  377. """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
  378. """
  379. return layernorm_fn(x, self.weight, self.bias, z=z, group_size=self.group_size, eps=self.eps,
  380. norm_before_gate=self.norm_before_gate)
  381. class RMSNorm(torch.nn.Module):
  382. def __init__(self, hidden_size, eps=1e-5, group_size=None, norm_before_gate=True, device=None, dtype=None):
  383. """If group_size is not None, we do GroupNorm with each group having group_size elements.
  384. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
  385. """
  386. factory_kwargs = {"device": device, "dtype": dtype}
  387. super().__init__()
  388. self.eps = eps
  389. self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
  390. self.register_parameter("bias", None)
  391. self.group_size = group_size
  392. self.norm_before_gate = norm_before_gate
  393. self.reset_parameters()
  394. def reset_parameters(self):
  395. torch.nn.init.ones_(self.weight)
  396. def forward(self, x, z=None):
  397. """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
  398. """
  399. return rmsnorm_fn(x, self.weight, self.bias, z=z, eps=self.eps, group_size=self.group_size,
  400. norm_before_gate=self.norm_before_gate)