utils.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from torchvision import transforms
  2. import numpy as np
  3. from . import TRANSFORMS
  4. def get_transforms(cfg, train, cfg_transforms):
  5. transform_list = []
  6. for t in cfg_transforms:
  7. t = {k: v for k, v in t.items()}
  8. t_type = t.pop('type')
  9. t_tran = TRANSFORMS.get_module(t_type)(**t)
  10. transform_list.extend(t_tran) if isinstance(t_tran, list) else transform_list.append(t_tran)
  11. transform_out = TRANSFORMS.get_module('Compose')(transform_list)
  12. if train:
  13. if cfg.size <= 32:
  14. transform_out[0] = transforms.RandomCrop(cfg.size, padding=4)
  15. return transform_out
  16. def make_divisible(v, divisor=8, min_value=None):
  17. if min_value is None:
  18. min_value = divisor
  19. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  20. # Make sure that round down does not go down by more than 10%.
  21. if new_v < 0.9 * v:
  22. new_v += divisor
  23. return new_v
  24. def get_scales(n_scale, base_h, base_w, min_h, max_h, min_w, max_w, check_scale_div_factor=32):
  25. hs = list(np.linspace(min_h, max_h, n_scale))
  26. if base_h not in hs:
  27. hs.append(base_h)
  28. ws = list(np.linspace(min_w, max_w, n_scale))
  29. if base_w not in ws:
  30. ws.append(base_w)
  31. scales = set()
  32. for h, w in zip(hs, ws):
  33. h = make_divisible(h, check_scale_div_factor)
  34. w = make_divisible(w, check_scale_div_factor)
  35. scales.add((h, w))
  36. scales = list(scales)
  37. return scales