# pycaffe.py
  1. """
  2. Wrap the internal caffe C++ module (_caffe.so) with a clean, Pythonic
  3. interface.
  4. """
  5. from collections import OrderedDict
  6. try:
  7. from itertools import izip_longest
  8. except:
  9. from itertools import zip_longest as izip_longest
  10. import numpy as np
  11. from ._caffe import Net, SGDSolver, NesterovSolver, AdaGradSolver, \
  12. RMSPropSolver, AdaDeltaSolver, AdamSolver
  13. import caffe.io
  14. import six
  15. # We directly update methods from Net here (rather than using composition or
  16. # inheritance) so that nets created by caffe (e.g., by SGDSolver) will
  17. # automatically have the improved interface.
  18. @property
  19. def _Net_blobs(self):
  20. """
  21. An OrderedDict (bottom to top, i.e., input to output) of network
  22. blobs indexed by name
  23. """
  24. if not hasattr(self, '_blobs_dict'):
  25. self._blobs_dict = OrderedDict(zip(self._blob_names, self._blobs))
  26. return self._blobs_dict
  27. @property
  28. def _Net_blob_loss_weights(self):
  29. """
  30. An OrderedDict (bottom to top, i.e., input to output) of network
  31. blob loss weights indexed by name
  32. """
  33. if not hasattr(self, '_blobs_loss_weights_dict'):
  34. self._blob_loss_weights_dict = OrderedDict(zip(self._blob_names,
  35. self._blob_loss_weights))
  36. return self._blob_loss_weights_dict
  37. @property
  38. def _Net_params(self):
  39. """
  40. An OrderedDict (bottom to top, i.e., input to output) of network
  41. parameters indexed by name; each is a list of multiple blobs (e.g.,
  42. weights and biases)
  43. """
  44. if not hasattr(self, '_params_dict'):
  45. self._params_dict = OrderedDict([(name, lr.blobs)
  46. for name, lr in zip(
  47. self._layer_names, self.layers)
  48. if len(lr.blobs) > 0])
  49. return self._params_dict
  50. @property
  51. def _Net_inputs(self):
  52. if not hasattr(self, '_input_list'):
  53. keys = list(self.blobs.keys())
  54. self._input_list = [keys[i] for i in self._inputs]
  55. return self._input_list
  56. @property
  57. def _Net_outputs(self):
  58. if not hasattr(self, '_output_list'):
  59. keys = list(self.blobs.keys())
  60. self._output_list = [keys[i] for i in self._outputs]
  61. return self._output_list
  62. def _Net_forward(self, blobs=None, start=None, end=None, **kwargs):
  63. """
  64. Forward pass: prepare inputs and run the net forward.
  65. Parameters
  66. ----------
  67. blobs : list of blobs to return in addition to output blobs.
  68. kwargs : Keys are input blob names and values are blob ndarrays.
  69. For formatting inputs for Caffe, see Net.preprocess().
  70. If None, input is taken from data layers.
  71. start : optional name of layer at which to begin the forward pass
  72. end : optional name of layer at which to finish the forward pass
  73. (inclusive)
  74. Returns
  75. -------
  76. outs : {blob name: blob ndarray} dict.
  77. """
  78. if blobs is None:
  79. blobs = []
  80. if start is not None:
  81. start_ind = list(self._layer_names).index(start)
  82. else:
  83. start_ind = 0
  84. if end is not None:
  85. end_ind = list(self._layer_names).index(end)
  86. outputs = set([end] + blobs)
  87. else:
  88. end_ind = len(self.layers) - 1
  89. outputs = set(self.outputs + blobs)
  90. if kwargs:
  91. if set(kwargs.keys()) != set(self.inputs):
  92. raise Exception('Input blob arguments do not match net inputs.')
  93. # Set input according to defined shapes and make arrays single and
  94. # C-contiguous as Caffe expects.
  95. for in_, blob in six.iteritems(kwargs):
  96. if blob.shape[0] != self.blobs[in_].shape[0]:
  97. raise Exception('Input is not batch sized')
  98. self.blobs[in_].data[...] = blob
  99. self._forward(start_ind, end_ind)
  100. # Unpack blobs to extract
  101. return {out: self.blobs[out].data for out in outputs}
  102. def _Net_backward(self, diffs=None, start=None, end=None, **kwargs):
  103. """
  104. Backward pass: prepare diffs and run the net backward.
  105. Parameters
  106. ----------
  107. diffs : list of diffs to return in addition to bottom diffs.
  108. kwargs : Keys are output blob names and values are diff ndarrays.
  109. If None, top diffs are taken from forward loss.
  110. start : optional name of layer at which to begin the backward pass
  111. end : optional name of layer at which to finish the backward pass
  112. (inclusive)
  113. Returns
  114. -------
  115. outs: {blob name: diff ndarray} dict.
  116. """
  117. if diffs is None:
  118. diffs = []
  119. if start is not None:
  120. start_ind = list(self._layer_names).index(start)
  121. else:
  122. start_ind = len(self.layers) - 1
  123. if end is not None:
  124. end_ind = list(self._layer_names).index(end)
  125. outputs = set([end] + diffs)
  126. else:
  127. end_ind = 0
  128. outputs = set(self.inputs + diffs)
  129. if kwargs:
  130. if set(kwargs.keys()) != set(self.outputs):
  131. raise Exception('Top diff arguments do not match net outputs.')
  132. # Set top diffs according to defined shapes and make arrays single and
  133. # C-contiguous as Caffe expects.
  134. for top, diff in six.iteritems(kwargs):
  135. if diff.shape[0] != self.blobs[top].shape[0]:
  136. raise Exception('Diff is not batch sized')
  137. self.blobs[top].diff[...] = diff
  138. self._backward(start_ind, end_ind)
  139. # Unpack diffs to extract
  140. return {out: self.blobs[out].diff for out in outputs}
  141. def _Net_forward_all(self, blobs=None, **kwargs):
  142. """
  143. Run net forward in batches.
  144. Parameters
  145. ----------
  146. blobs : list of blobs to extract as in forward()
  147. kwargs : Keys are input blob names and values are blob ndarrays.
  148. Refer to forward().
  149. Returns
  150. -------
  151. all_outs : {blob name: list of blobs} dict.
  152. """
  153. # Collect outputs from batches
  154. all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
  155. for batch in self._batch(kwargs):
  156. outs = self.forward(blobs=blobs, **batch)
  157. for out, out_blob in six.iteritems(outs):
  158. all_outs[out].extend(out_blob.copy())
  159. # Package in ndarray.
  160. for out in all_outs:
  161. all_outs[out] = np.asarray(all_outs[out])
  162. # Discard padding.
  163. pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
  164. if pad:
  165. for out in all_outs:
  166. all_outs[out] = all_outs[out][:-pad]
  167. return all_outs
  168. def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs):
  169. """
  170. Run net forward + backward in batches.
  171. Parameters
  172. ----------
  173. blobs: list of blobs to extract as in forward()
  174. diffs: list of diffs to extract as in backward()
  175. kwargs: Keys are input (for forward) and output (for backward) blob names
  176. and values are ndarrays. Refer to forward() and backward().
  177. Prefilled variants are called for lack of input or output blobs.
  178. Returns
  179. -------
  180. all_blobs: {blob name: blob ndarray} dict.
  181. all_diffs: {blob name: diff ndarray} dict.
  182. """
  183. # Batch blobs and diffs.
  184. all_outs = {out: [] for out in set(self.outputs + (blobs or []))}
  185. all_diffs = {diff: [] for diff in set(self.inputs + (diffs or []))}
  186. forward_batches = self._batch({in_: kwargs[in_]
  187. for in_ in self.inputs if in_ in kwargs})
  188. backward_batches = self._batch({out: kwargs[out]
  189. for out in self.outputs if out in kwargs})
  190. # Collect outputs from batches (and heed lack of forward/backward batches).
  191. for fb, bb in izip_longest(forward_batches, backward_batches, fillvalue={}):
  192. batch_blobs = self.forward(blobs=blobs, **fb)
  193. batch_diffs = self.backward(diffs=diffs, **bb)
  194. for out, out_blobs in six.iteritems(batch_blobs):
  195. all_outs[out].extend(out_blobs.copy())
  196. for diff, out_diffs in six.iteritems(batch_diffs):
  197. all_diffs[diff].extend(out_diffs.copy())
  198. # Package in ndarray.
  199. for out, diff in zip(all_outs, all_diffs):
  200. all_outs[out] = np.asarray(all_outs[out])
  201. all_diffs[diff] = np.asarray(all_diffs[diff])
  202. # Discard padding at the end and package in ndarray.
  203. pad = len(six.next(six.itervalues(all_outs))) - len(six.next(six.itervalues(kwargs)))
  204. if pad:
  205. for out, diff in zip(all_outs, all_diffs):
  206. all_outs[out] = all_outs[out][:-pad]
  207. all_diffs[diff] = all_diffs[diff][:-pad]
  208. return all_outs, all_diffs
  209. def _Net_set_input_arrays(self, data, labels):
  210. """
  211. Set input arrays of the in-memory MemoryDataLayer.
  212. (Note: this is only for networks declared with the memory data layer.)
  213. """
  214. if labels.ndim == 1:
  215. labels = np.ascontiguousarray(labels[:, np.newaxis, np.newaxis,
  216. np.newaxis])
  217. return self._set_input_arrays(data, labels)
  218. def _Net_batch(self, blobs):
  219. """
  220. Batch blob lists according to net's batch size.
  221. Parameters
  222. ----------
  223. blobs: Keys blob names and values are lists of blobs (of any length).
  224. Naturally, all the lists should have the same length.
  225. Yields
  226. ------
  227. batch: {blob name: list of blobs} dict for a single batch.
  228. """
  229. num = len(six.next(six.itervalues(blobs)))
  230. batch_size = six.next(six.itervalues(self.blobs)).shape[0]
  231. remainder = num % batch_size
  232. num_batches = num // batch_size
  233. # Yield full batches.
  234. for b in range(num_batches):
  235. i = b * batch_size
  236. yield {name: blobs[name][i:i + batch_size] for name in blobs}
  237. # Yield last padded batch, if any.
  238. if remainder > 0:
  239. padded_batch = {}
  240. for name in blobs:
  241. padding = np.zeros((batch_size - remainder,)
  242. + blobs[name].shape[1:])
  243. padded_batch[name] = np.concatenate([blobs[name][-remainder:],
  244. padding])
  245. yield padded_batch
  246. def _Net_get_id_name(func, field):
  247. """
  248. Generic property that maps func to the layer names into an OrderedDict.
  249. Used for top_names and bottom_names.
  250. Parameters
  251. ----------
  252. func: function id -> [id]
  253. field: implementation field name (cache)
  254. Returns
  255. ------
  256. A one-parameter function that can be set as a property.
  257. """
  258. @property
  259. def get_id_name(self):
  260. if not hasattr(self, field):
  261. id_to_name = list(self.blobs)
  262. res = OrderedDict([(self._layer_names[i],
  263. [id_to_name[j] for j in func(self, i)])
  264. for i in range(len(self.layers))])
  265. setattr(self, field, res)
  266. return getattr(self, field)
  267. return get_id_name
# Attach methods to Net.
# Net is monkey-patched (rather than subclassed) so that nets constructed
# internally by caffe -- e.g. by SGDSolver -- automatically expose the same
# Pythonic interface defined above.
Net.blobs = _Net_blobs
Net.blob_loss_weights = _Net_blob_loss_weights
Net.params = _Net_params
Net.forward = _Net_forward
Net.backward = _Net_backward
Net.forward_all = _Net_forward_all
Net.forward_backward_all = _Net_forward_backward_all
Net.set_input_arrays = _Net_set_input_arrays
Net._batch = _Net_batch
Net.inputs = _Net_inputs
Net.outputs = _Net_outputs
# top_names / bottom_names map each layer name to the names of its top
# (output) and bottom (input) blobs, cached under the given field name.
Net.top_names = _Net_get_id_name(Net._top_ids, "_top_names")
Net.bottom_names = _Net_get_id_name(Net._bottom_ids, "_bottom_names")