tuner.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. import os
  2. import shutil
  3. import subprocess
  4. import sys
  5. # import tempfile
  6. # import importlib
  7. import random
  8. import string
  9. import json
  10. from functools import partial
  11. from multiprocessing import Pipe, Pool, Process
  12. from pathlib import Path
  13. from tqdm import tqdm
  14. import numpy as np
  15. def read_file(filename):
  16. """ return the contents of the file named filename or None if file not found """
  17. if os.path.isfile(filename):
  18. with open(filename, 'r') as f:
  19. return f.read()
  20. def write_file(filename, string):
  21. """dump the contents of string to a file called filename"""
  22. with open(filename, 'w', encoding="utf-8") as f:
  23. f.write(string)
  24. def prepare_kernel_string(kernel_string, params):
  25. for k, v in params.items():
  26. kernel_string = "#define " + k + " " + str(v) + "\n" + kernel_string
  27. return kernel_string
  28. def compile_extension(temp_dir, install=False, verbose=True):
  29. # Need to copy this process's environments, otherwise it can't find the compilers
  30. env = {**os.environ,
  31. 'TUNING_SOURCE_DIR': str(temp_dir),
  32. 'TUNING_EXTENSION_NAME': str(temp_dir.stem)}
  33. # https://stackoverflow.com/questions/53173314/how-to-change-distutils-output-directory
  34. # Need separate build directories for parallel compilation
  35. output = subprocess.run(
  36. # [sys.executable, "tuning_setup.py", 'build', f'--build-base={str(temp_dir)}',
  37. # f'--build-lib={str(temp_dir)}'],
  38. [sys.executable, "tuning_setup.py", 'build' if not install else 'develop'],
  39. cwd=temp_dir,
  40. env=env,
  41. capture_output=True,
  42. # check=True
  43. )
  44. if verbose:
  45. print(output)
  46. print('Done compiling' if not install else 'Done installing')
  47. def uninstall_extensions(tuning_extension_names, verbose=True):
  48. # Need to copy this process's environments, otherwise it can't find the compilers
  49. env = {**os.environ}
  50. output = subprocess.run(
  51. [sys.executable, '-m', 'pip', 'uninstall', '-y', *tuning_extension_names],
  52. env=env,
  53. capture_output=True,
  54. # check=True
  55. )
  56. if verbose:
  57. print(output)
  58. print('Done uninstalling')
  59. def benchmark_extension(benchmark_script, *benchmark_args, verbose=True):
  60. # Need to copy this process's environments, otherwise it can't find the compilers
  61. env = os.environ
  62. # https://stackoverflow.com/questions/53173314/how-to-change-distutils-output-directory
  63. # Need separate build directories for parallel compilation
  64. process = subprocess.run(
  65. [sys.executable, benchmark_script, *benchmark_args],
  66. env=os.environ,
  67. capture_output=True,
  68. # check=True
  69. )
  70. if verbose:
  71. print(process)
  72. print('Done benchmarking')
  73. return json.loads(process.stdout.decode(sys.stdout.encoding))
  74. # def benchmark(connection, temp_dir):
  75. # import torch
  76. # # module = importlib.import_module(tuning_extension_name)
  77. # torch.ops.load_library(temp_dir / 'torch_butterfly_tuning.so')
  78. # batch_size = 1024
  79. # n = 32
  80. # twiddle = torch.randn(1, 1, 5, n // 2, 2, 2, device='cuda')
  81. # input = torch.randn(batch_size, 1, n, device=twiddle.device)
  82. # output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True)
  83. # # https://medium.com/@auro_227/timing-your-pytorch-code-fragments-e1a556e81f2
  84. # res = []
  85. # for _ in range(32):
  86. # start = torch.cuda.Event(enable_timing=True)
  87. # end = torch.cuda.Event(enable_timing=True)
  88. # start.record()
  89. # output = torch.ops.torch_butterfly.butterfly_multiply_fw(twiddle, input, True)
  90. # end.record()
  91. # torch.cuda.synchronize()
  92. # res.append(start.elapsed_time(end))
  93. # print(output.shape)
  94. # res = np.array(res)
  95. # connection.send((np.mean(res), np.std(res)))
  96. def set_up_tuning_temp_dir(params: dict, source_files, extension_dir, verbose=True):
  97. if verbose:
  98. print('params: ', params)
  99. # TD [2021-10-22]: tempfile.mkdtemp sometimes create dir name with '_' in it, thus messing up
  100. # the extension name.
  101. # temp_dir = Path(tempfile.mkdtemp(prefix="temp_", dir=Path.cwd().parent)).absolute()
  102. tuning_extension_name = 'temp_' + ''.join(random.choices(string.ascii_uppercase + string.digits, k=10))
  103. temp_dir = (Path.cwd().parent / tuning_extension_name).absolute()
  104. if temp_dir.exists():
  105. shutil.rmtree(temp_dir) # shutil.copytree doesn't want directory that already exists
  106. shutil.copytree(extension_dir, temp_dir)
  107. sources = [temp_dir / name for name in source_files]
  108. for kernel_source in sources:
  109. ks = read_file(kernel_source)
  110. ks = prepare_kernel_string(ks, params)
  111. write_file(kernel_source, ks)
  112. return temp_dir
  113. class KernelTuner:
  114. def __init__(self, extension_dir, source_files, params_list, benchmark_script,
  115. benchmark_args, npool=8, verbose=True):
  116. self.extension_dir = extension_dir
  117. self.source_files = source_files
  118. self.params_list = params_list
  119. self.benchmark_script = benchmark_script
  120. self.benchmark_args = benchmark_args
  121. self.npool = npool
  122. self.verbose = verbose
  123. def tune(self):
  124. temp_dirs = [set_up_tuning_temp_dir(params, self.source_files, self.extension_dir,
  125. verbose=self.verbose)
  126. for params in self.params_list]
  127. # Compile in parallel (for speed), then install sequentially to ensure correctness
  128. with Pool(self.npool) as p:
  129. p.map(compile_extension, temp_dirs)
  130. # with Pool(1) as p:
  131. # p.map(partial(compile_extension, install=True), [temp_dirs])
  132. for temp_dir in tqdm(temp_dirs):
  133. try:
  134. compile_extension(temp_dir, install=True)
  135. except:
  136. pass
  137. # # We benchmark on a separate process so that they can import the extension that just got compiled.
  138. # for params, temp_dir in params_tempdir:
  139. # print('Benchmarking: ', params)
  140. # recv_conn, send_conn = Pipe(duplex=False)
  141. # benchmark_process = Process(target=benchmark_fwd, args=(send_conn, str(temp_dir.stem)))
  142. # benchmark_process.start()
  143. # result = recv_conn.recv()
  144. # benchmark_process.join()
  145. # print('result', result)
  146. results = []
  147. for params, temp_dir in tqdm(list(zip(self.params_list, temp_dirs))):
  148. try:
  149. results.append((params,
  150. benchmark_extension(self.benchmark_script,
  151. *['--name', temp_dir.stem] + self.benchmark_args)))
  152. except:
  153. pass
  154. print(results)
  155. uninstall_extensions([temp_dir.stem for temp_dir in temp_dirs])
  156. for temp_dir in temp_dirs:
  157. shutil.rmtree(temp_dir)
  158. return results