amg.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import cv2 # type: ignore
  6. from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
  7. import argparse
  8. import json
  9. import os
  10. from typing import Any, Dict, List
  11. parser = argparse.ArgumentParser(
  12. description=(
  13. "Runs automatic mask generation on an input image or directory of images, "
  14. "and outputs masks as either PNGs or COCO-style RLEs. Requires open-cv, "
  15. "as well as pycocotools if saving in RLE format."
  16. )
  17. )
  18. parser.add_argument(
  19. "--input",
  20. type=str,
  21. required=True,
  22. help="Path to either a single input image or folder of images.",
  23. )
  24. parser.add_argument(
  25. "--output",
  26. type=str,
  27. required=True,
  28. help=(
  29. "Path to the directory where masks will be output. Output will be either a folder "
  30. "of PNGs per image or a single json with COCO-style masks."
  31. ),
  32. )
  33. parser.add_argument(
  34. "--model-type",
  35. type=str,
  36. required=True,
  37. help="The type of model to load, in ['default', 'vit_h', 'vit_l', 'vit_b']",
  38. )
  39. parser.add_argument(
  40. "--checkpoint",
  41. type=str,
  42. required=True,
  43. help="The path to the SAM checkpoint to use for mask generation.",
  44. )
  45. parser.add_argument("--device", type=str, default="cuda", help="The device to run generation on.")
  46. parser.add_argument(
  47. "--convert-to-rle",
  48. action="store_true",
  49. help=(
  50. "Save masks as COCO RLEs in a single json instead of as a folder of PNGs. "
  51. "Requires pycocotools."
  52. ),
  53. )
  54. amg_settings = parser.add_argument_group("AMG Settings")
  55. amg_settings.add_argument(
  56. "--points-per-side",
  57. type=int,
  58. default=None,
  59. help="Generate masks by sampling a grid over the image with this many points to a side.",
  60. )
  61. amg_settings.add_argument(
  62. "--points-per-batch",
  63. type=int,
  64. default=None,
  65. help="How many input points to process simultaneously in one batch.",
  66. )
  67. amg_settings.add_argument(
  68. "--pred-iou-thresh",
  69. type=float,
  70. default=None,
  71. help="Exclude masks with a predicted score from the model that is lower than this threshold.",
  72. )
  73. amg_settings.add_argument(
  74. "--stability-score-thresh",
  75. type=float,
  76. default=None,
  77. help="Exclude masks with a stability score lower than this threshold.",
  78. )
  79. amg_settings.add_argument(
  80. "--stability-score-offset",
  81. type=float,
  82. default=None,
  83. help="Larger values perturb the mask more when measuring stability score.",
  84. )
  85. amg_settings.add_argument(
  86. "--box-nms-thresh",
  87. type=float,
  88. default=None,
  89. help="The overlap threshold for excluding a duplicate mask.",
  90. )
  91. amg_settings.add_argument(
  92. "--crop-n-layers",
  93. type=int,
  94. default=None,
  95. help=(
  96. "If >0, mask generation is run on smaller crops of the image to generate more masks. "
  97. "The value sets how many different scales to crop at."
  98. ),
  99. )
  100. amg_settings.add_argument(
  101. "--crop-nms-thresh",
  102. type=float,
  103. default=None,
  104. help="The overlap threshold for excluding duplicate masks across different crops.",
  105. )
  106. amg_settings.add_argument(
  107. "--crop-overlap-ratio",
  108. type=int,
  109. default=None,
  110. help="Larger numbers mean image crops will overlap more.",
  111. )
  112. amg_settings.add_argument(
  113. "--crop-n-points-downscale-factor",
  114. type=int,
  115. default=None,
  116. help="The number of points-per-side in each layer of crop is reduced by this factor.",
  117. )
  118. amg_settings.add_argument(
  119. "--min-mask-region-area",
  120. type=int,
  121. default=None,
  122. help=(
  123. "Disconnected mask regions or holes with area smaller than this value "
  124. "in pixels are removed by postprocessing."
  125. ),
  126. )
  127. def write_masks_to_folder(masks: List[Dict[str, Any]], path: str) -> None:
  128. header = "id,area,bbox_x0,bbox_y0,bbox_w,bbox_h,point_input_x,point_input_y,predicted_iou,stability_score,crop_box_x0,crop_box_y0,crop_box_w,crop_box_h" # noqa
  129. metadata = [header]
  130. for i, mask_data in enumerate(masks):
  131. mask = mask_data["segmentation"]
  132. filename = f"{i}.png"
  133. cv2.imwrite(os.path.join(path, filename), mask * 255)
  134. mask_metadata = [
  135. str(i),
  136. str(mask_data["area"]),
  137. *[str(x) for x in mask_data["bbox"]],
  138. *[str(x) for x in mask_data["point_coords"][0]],
  139. str(mask_data["predicted_iou"]),
  140. str(mask_data["stability_score"]),
  141. *[str(x) for x in mask_data["crop_box"]],
  142. ]
  143. row = ",".join(mask_metadata)
  144. metadata.append(row)
  145. metadata_path = os.path.join(path, "metadata.csv")
  146. with open(metadata_path, "w") as f:
  147. f.write("\n".join(metadata))
  148. return
  149. def get_amg_kwargs(args):
  150. amg_kwargs = {
  151. "points_per_side": args.points_per_side,
  152. "points_per_batch": args.points_per_batch,
  153. "pred_iou_thresh": args.pred_iou_thresh,
  154. "stability_score_thresh": args.stability_score_thresh,
  155. "stability_score_offset": args.stability_score_offset,
  156. "box_nms_thresh": args.box_nms_thresh,
  157. "crop_n_layers": args.crop_n_layers,
  158. "crop_nms_thresh": args.crop_nms_thresh,
  159. "crop_overlap_ratio": args.crop_overlap_ratio,
  160. "crop_n_points_downscale_factor": args.crop_n_points_downscale_factor,
  161. "min_mask_region_area": args.min_mask_region_area,
  162. }
  163. amg_kwargs = {k: v for k, v in amg_kwargs.items() if v is not None}
  164. return amg_kwargs
  165. def main(args: argparse.Namespace) -> None:
  166. print("Loading model...")
  167. sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
  168. _ = sam.to(device=args.device)
  169. output_mode = "coco_rle" if args.convert_to_rle else "binary_mask"
  170. amg_kwargs = get_amg_kwargs(args)
  171. generator = SamAutomaticMaskGenerator(sam, output_mode=output_mode, **amg_kwargs)
  172. if not os.path.isdir(args.input):
  173. targets = [args.input]
  174. else:
  175. targets = [
  176. f for f in os.listdir(args.input) if not os.path.isdir(os.path.join(args.input, f))
  177. ]
  178. targets = [os.path.join(args.input, f) for f in targets]
  179. os.makedirs(args.output, exist_ok=True)
  180. for t in targets:
  181. print(f"Processing '{t}'...")
  182. image = cv2.imread(t)
  183. if image is None:
  184. print(f"Could not load '{t}' as an image, skipping...")
  185. continue
  186. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  187. masks = generator.generate(image)
  188. base = os.path.basename(t)
  189. base = os.path.splitext(base)[0]
  190. save_base = os.path.join(args.output, base)
  191. if output_mode == "binary_mask":
  192. os.makedirs(save_base, exist_ok=False)
  193. write_masks_to_folder(masks, save_base)
  194. else:
  195. save_file = save_base + ".json"
  196. with open(save_file, "w") as f:
  197. json.dump(masks, f)
  198. print("Done!")
  199. if __name__ == "__main__":
  200. args = parser.parse_args()
  201. main(args)