cpu_nms.pyx 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. # --------------------------------------------------------
  2. # Fast R-CNN
  3. # Copyright (c) 2015 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ross Girshick
  6. # --------------------------------------------------------
  7. import numpy as np
  8. cimport numpy as np
  cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
      # Larger of two float32 values. Only ``a >= b`` is tested, so when that
      # comparison is false (including when either operand is NaN) ``b`` wins.
      if a >= b:
          return a
      return b
  cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
      # Smaller of two float32 values. Only ``a <= b`` is tested, so when that
      # comparison is false (including when either operand is NaN) ``b`` wins.
      if a <= b:
          return a
      return b
  def cpu_nms(np.ndarray[np.float32_t, ndim=2] dets, double thresh):
      """Greedy non-maximum suppression over axis-aligned boxes.

      Boxes are visited in descending score order. Each box that has not
      been suppressed is kept, and it suppresses every later (lower-scoring)
      box whose overlap ratio with it is >= ``thresh``. The overlap ratio is
      intersection area / union area.

      Parameters
      ----------
      dets : np.ndarray[float32], 2-D
          One row per box, read as ``[x1, y1, x2, y2, score, ...]``; only
          columns 0-4 are used. Widths and heights use the inclusive
          convention ``x2 - x1 + 1`` / ``y2 - y1 + 1``.
      thresh : float
          Overlap ratio at or above which a lower-scoring box is suppressed.
          Typed as a C double so the inner-loop comparison stays in C.

      Returns
      -------
      list of int
          Row indices into ``dets`` of the kept boxes, highest score first.
      """
      cdef np.ndarray[np.float32_t, ndim=1] x1 = dets[:, 0]
      cdef np.ndarray[np.float32_t, ndim=1] y1 = dets[:, 1]
      cdef np.ndarray[np.float32_t, ndim=1] x2 = dets[:, 2]
      cdef np.ndarray[np.float32_t, ndim=1] y2 = dets[:, 3]
      cdef np.ndarray[np.float32_t, ndim=1] scores = dets[:, 4]
      cdef np.ndarray[np.float32_t, ndim=1] areas = (x2 - x1 + 1) * (y2 - y1 + 1)
      # argsort returns np.intp indices. Typing the buffer as np.intp_t
      # (not np.int_t) keeps this valid on platforms where C long != intp.
      cdef np.ndarray[np.intp_t, ndim=1] order = scores.argsort()[::-1]
      cdef int ndets = dets.shape[0]
      # Per-box 0/1 suppression flags. An explicit uint8 dtype is used
      # because the ``np.int`` alias no longer exists in NumPy >= 1.24.
      cdef np.ndarray[np.uint8_t, ndim=1] suppressed = \
          np.zeros(ndets, dtype=np.uint8)
      # positions in the score-sorted order
      cdef int _i, _j
      # row indices into dets
      cdef int i, j
      # coordinates and area of box i (the box currently kept)
      cdef np.float32_t ix1, iy1, ix2, iy2, iarea
      # intersection of box i with box j (a lower-scoring box)
      cdef np.float32_t xx1, yy1, xx2, yy2
      cdef np.float32_t w, h
      cdef np.float32_t inter, ovr
      keep = []
      for _i in range(ndets):
          i = order[_i]
          if suppressed[i] == 1:
              continue
          keep.append(i)
          ix1 = x1[i]
          iy1 = y1[i]
          ix2 = x2[i]
          iy2 = y2[i]
          iarea = areas[i]
          for _j in range(_i + 1, ndets):
              j = order[_j]
              if suppressed[j] == 1:
                  continue
              xx1 = max(ix1, x1[j])
              yy1 = max(iy1, y1[j])
              xx2 = min(ix2, x2[j])
              yy2 = min(iy2, y2[j])
              # Clamp at 0 so disjoint boxes contribute no intersection.
              w = max(0.0, xx2 - xx1 + 1)
              h = max(0.0, yy2 - yy1 + 1)
              inter = w * h
              ovr = inter / (iarea + areas[j] - inter)
              if ovr >= thresh:
                  suppressed[j] = 1
      return keep