classifier.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. #!/usr/bin/env python
  2. """
  3. Classifier is an image classifier specialization of Net.
  4. """
  5. import numpy as np
  6. import caffe
  7. class Classifier(caffe.Net):
  8. """
  9. Classifier extends Net for image class prediction
  10. by scaling, center cropping, or oversampling.
  11. Parameters
  12. ----------
  13. image_dims : dimensions to scale input for cropping/sampling.
  14. Default is to scale to net input size for whole-image crop.
  15. mean, input_scale, raw_scale, channel_swap: params for
  16. preprocessing options.
  17. """
  18. def __init__(self, model_file, pretrained_file, image_dims=None,
  19. mean=None, input_scale=None, raw_scale=None,
  20. channel_swap=None):
  21. caffe.Net.__init__(self, model_file, pretrained_file, caffe.TEST)
  22. # configure pre-processing
  23. in_ = self.inputs[0]
  24. self.transformer = caffe.io.Transformer(
  25. {in_: self.blobs[in_].data.shape})
  26. self.transformer.set_transpose(in_, (2, 0, 1))
  27. if mean is not None:
  28. self.transformer.set_mean(in_, mean)
  29. if input_scale is not None:
  30. self.transformer.set_input_scale(in_, input_scale)
  31. if raw_scale is not None:
  32. self.transformer.set_raw_scale(in_, raw_scale)
  33. if channel_swap is not None:
  34. self.transformer.set_channel_swap(in_, channel_swap)
  35. self.crop_dims = np.array(self.blobs[in_].data.shape[2:])
  36. if not image_dims:
  37. image_dims = self.crop_dims
  38. self.image_dims = image_dims
  39. def predict(self, inputs, oversample=True):
  40. """
  41. Predict classification probabilities of inputs.
  42. Parameters
  43. ----------
  44. inputs : iterable of (H x W x K) input ndarrays.
  45. oversample : boolean
  46. average predictions across center, corners, and mirrors
  47. when True (default). Center-only prediction when False.
  48. Returns
  49. -------
  50. predictions: (N x C) ndarray of class probabilities for N images and C
  51. classes.
  52. """
  53. # Scale to standardize input dimensions.
  54. input_ = np.zeros((len(inputs),
  55. self.image_dims[0],
  56. self.image_dims[1],
  57. inputs[0].shape[2]),
  58. dtype=np.float32)
  59. for ix, in_ in enumerate(inputs):
  60. input_[ix] = caffe.io.resize_image(in_, self.image_dims)
  61. if oversample:
  62. # Generate center, corner, and mirrored crops.
  63. input_ = caffe.io.oversample(input_, self.crop_dims)
  64. else:
  65. # Take center crop.
  66. center = np.array(self.image_dims) / 2.0
  67. crop = np.tile(center, (1, 2))[0] + np.concatenate([
  68. -self.crop_dims / 2.0,
  69. self.crop_dims / 2.0
  70. ])
  71. crop = crop.astype(int)
  72. input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]
  73. # Classify
  74. caffe_in = np.zeros(np.array(input_.shape)[[0, 3, 1, 2]],
  75. dtype=np.float32)
  76. for ix, in_ in enumerate(input_):
  77. caffe_in[ix] = self.transformer.preprocess(self.inputs[0], in_)
  78. out = self.forward_all(**{self.inputs[0]: caffe_in})
  79. predictions = out[self.outputs[0]]
  80. # For oversampling, average predictions across crops.
  81. if oversample:
  82. predictions = predictions.reshape((len(predictions) / 10, 10, -1))
  83. predictions = predictions.mean(1)
  84. return predictions