blob.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------
"""Blob helper functions."""
  8. import numpy as np
  9. import cv2
  10. def im_list_to_blob(ims):
  11. """Convert a list of images into a network input.
  12. Assumes images are already prepared (means subtracted, BGR order, ...).
  13. """
  14. max_shape = np.array([im.shape for im in ims]).max(axis=0)
  15. num_images = len(ims)
  16. blob = np.zeros((num_images, max_shape[0], max_shape[1], 3),
  17. dtype=np.float32)
  18. for i in xrange(num_images):
  19. im = ims[i]
  20. blob[i, 0:im.shape[0], 0:im.shape[1], :] = im
  21. # Move channels (axis 3) to axis 1
  22. # Axis order will become: (batch elem, channel, height, width)
  23. channel_swap = (0, 3, 1, 2)
  24. blob = blob.transpose(channel_swap)
  25. return blob
  26. def prep_im_for_blob(im, pixel_means, target_size, max_size):
  27. """Mean subtract and scale an image for use in a blob."""
  28. im = im.astype(np.float32, copy=False)
  29. im -= pixel_means
  30. im_shape = im.shape
  31. im_size_min = np.min(im_shape[0:2])
  32. im_size_max = np.max(im_shape[0:2])
  33. im_scale = float(target_size) / float(im_size_min)
  34. # Prevent the biggest axis from being more than MAX_SIZE
  35. if np.round(im_scale * im_size_max) > max_size:
  36. im_scale = float(max_size) / float(im_size_max)
  37. im = cv2.resize(im, None, None, fx=im_scale, fy=im_scale,
  38. interpolation=cv2.INTER_LINEAR)
  39. return im, im_scale