coord_map.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. """
  2. Determine spatial relationships between layers to relate their coordinates.
  3. Coordinates are mapped from input-to-output (forward), but can
  4. be mapped output-to-input (backward) by the inverse mapping too.
  5. This helps crop and align feature maps among other uses.
  6. """
  7. from __future__ import division
  8. import numpy as np
  9. from caffe import layers as L
  10. PASS_THROUGH_LAYERS = ['AbsVal', 'BatchNorm', 'Bias', 'BNLL', 'Dropout',
  11. 'Eltwise', 'ELU', 'Log', 'LRN', 'Exp', 'MVN', 'Power',
  12. 'ReLU', 'PReLU', 'Scale', 'Sigmoid', 'Split', 'TanH',
  13. 'Threshold']
  14. def conv_params(fn):
  15. """
  16. Extract the spatial parameters that determine the coordinate mapping:
  17. kernel size, stride, padding, and dilation.
  18. Implementation detail: Convolution, Deconvolution, and Im2col layers
  19. define these in the convolution_param message, while Pooling has its
  20. own fields in pooling_param. This method deals with these details to
  21. extract canonical parameters.
  22. """
  23. params = fn.params.get('convolution_param', fn.params)
  24. axis = params.get('axis', 1)
  25. ks = np.array(params['kernel_size'], ndmin=1)
  26. dilation = np.array(params.get('dilation', 1), ndmin=1)
  27. assert len({'pad_h', 'pad_w', 'kernel_h', 'kernel_w', 'stride_h',
  28. 'stride_w'} & set(fn.params)) == 0, \
  29. 'cropping does not support legacy _h/_w params'
  30. return (axis, np.array(params.get('stride', 1), ndmin=1),
  31. (ks - 1) * dilation + 1,
  32. np.array(params.get('pad', 0), ndmin=1))
  33. def crop_params(fn):
  34. """
  35. Extract the crop layer parameters with defaults.
  36. """
  37. params = fn.params.get('crop_param', fn.params)
  38. axis = params.get('axis', 2) # default to spatial crop for N, C, H, W
  39. offset = np.array(params.get('offset', 0), ndmin=1)
  40. return (axis, offset)
  41. class UndefinedMapException(Exception):
  42. """
  43. Exception raised for layers that do not have a defined coordinate mapping.
  44. """
  45. pass
  46. def coord_map(fn):
  47. """
  48. Define the coordinate mapping by its
  49. - axis
  50. - scale: output coord[i * scale] <- input_coord[i]
  51. - shift: output coord[i] <- output_coord[i + shift]
  52. s.t. the identity mapping, as for pointwise layers like ReLu, is defined by
  53. (None, 1, 0) since it is independent of axis and does not transform coords.
  54. """
  55. if fn.type_name in ['Convolution', 'Pooling', 'Im2col']:
  56. axis, stride, ks, pad = conv_params(fn)
  57. return axis, 1 / stride, (pad - (ks - 1) / 2) / stride
  58. elif fn.type_name == 'Deconvolution':
  59. axis, stride, ks, pad = conv_params(fn)
  60. return axis, stride, (ks - 1) / 2 - pad
  61. elif fn.type_name in PASS_THROUGH_LAYERS:
  62. return None, 1, 0
  63. elif fn.type_name == 'Crop':
  64. axis, offset = crop_params(fn)
  65. axis -= 1 # -1 for last non-coordinate dim.
  66. return axis, 1, - offset
  67. else:
  68. raise UndefinedMapException
  69. class AxisMismatchException(Exception):
  70. """
  71. Exception raised for mappings with incompatible axes.
  72. """
  73. pass
  74. def compose(base_map, next_map):
  75. """
  76. Compose a base coord map with scale a1, shift b1 with a further coord map
  77. with scale a2, shift b2. The scales multiply and the further shift, b2,
  78. is scaled by base coord scale a1.
  79. """
  80. ax1, a1, b1 = base_map
  81. ax2, a2, b2 = next_map
  82. if ax1 is None:
  83. ax = ax2
  84. elif ax2 is None or ax1 == ax2:
  85. ax = ax1
  86. else:
  87. raise AxisMismatchException
  88. return ax, a1 * a2, a1 * b2 + b1
  89. def inverse(coord_map):
  90. """
  91. Invert a coord map by de-scaling and un-shifting;
  92. this gives the backward mapping for the gradient.
  93. """
  94. ax, a, b = coord_map
  95. return ax, 1 / a, -b / a
  96. def coord_map_from_to(top_from, top_to):
  97. """
  98. Determine the coordinate mapping betweeen a top (from) and a top (to).
  99. Walk the graph to find a common ancestor while composing the coord maps for
  100. from and to until they meet. As a last step the from map is inverted.
  101. """
  102. # We need to find a common ancestor of top_from and top_to.
  103. # We'll assume that all ancestors are equivalent here (otherwise the graph
  104. # is an inconsistent state (which we could improve this to check for)).
  105. # For now use a brute-force algorithm.
  106. def collect_bottoms(top):
  107. """
  108. Collect the bottoms to walk for the coordinate mapping.
  109. The general rule is that all the bottoms of a layer can be mapped, as
  110. most layers have the same coordinate mapping for each bottom.
  111. Crop layer is a notable exception. Only the first/cropped bottom is
  112. mappable; the second/dimensions bottom is excluded from the walk.
  113. """
  114. bottoms = top.fn.inputs
  115. if top.fn.type_name == 'Crop':
  116. bottoms = bottoms[:1]
  117. return bottoms
  118. # walk back from top_from, keeping the coord map as we go
  119. from_maps = {top_from: (None, 1, 0)}
  120. frontier = {top_from}
  121. while frontier:
  122. top = frontier.pop()
  123. try:
  124. bottoms = collect_bottoms(top)
  125. for bottom in bottoms:
  126. from_maps[bottom] = compose(from_maps[top], coord_map(top.fn))
  127. frontier.add(bottom)
  128. except UndefinedMapException:
  129. pass
  130. # now walk back from top_to until we hit a common blob
  131. to_maps = {top_to: (None, 1, 0)}
  132. frontier = {top_to}
  133. while frontier:
  134. top = frontier.pop()
  135. if top in from_maps:
  136. return compose(to_maps[top], inverse(from_maps[top]))
  137. try:
  138. bottoms = collect_bottoms(top)
  139. for bottom in bottoms:
  140. to_maps[bottom] = compose(to_maps[top], coord_map(top.fn))
  141. frontier.add(bottom)
  142. except UndefinedMapException:
  143. continue
  144. # if we got here, we did not find a blob in common
  145. raise RuntimeError('Could not compute map between tops; are they '
  146. 'connected by spatial layers?')
  147. def crop(top_from, top_to):
  148. """
  149. Define a Crop layer to crop a top (from) to another top (to) by
  150. determining the coordinate mapping between the two and net spec'ing
  151. the axis and shift parameters of the crop.
  152. """
  153. ax, a, b = coord_map_from_to(top_from, top_to)
  154. assert (a == 1).all(), 'scale mismatch on crop (a = {})'.format(a)
  155. assert (b <= 0).all(), 'cannot crop negative offset (b = {})'.format(b)
  156. assert (np.round(b) == b).all(), 'cannot crop noninteger offset ' \
  157. '(b = {})'.format(b)
  158. return L.Crop(top_from, top_to,
  159. crop_param=dict(axis=ax + 1, # +1 for first cropping dim.
  160. offset=list(-np.round(b).astype(int))))