net_spec.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. """Python net specification.
  2. This module provides a way to write nets directly in Python, using a natural,
  3. functional style. See examples/pycaffe/caffenet.py for an example.
  4. Currently this works as a thin wrapper around the Python protobuf interface,
  5. with layers and parameters automatically generated for the "layers" and
  6. "params" pseudo-modules, which are actually objects using __getattr__ magic
  7. to generate protobuf messages.
  8. Note that when using to_proto or Top.to_proto, names of intermediate blobs will
  9. be automatically generated. To explicitly specify blob names, use the NetSpec
  10. class -- assign to its attributes directly to name layers, and call
  11. NetSpec.to_proto to serialize all assigned layers.
  12. This interface is expected to continue to evolve as Caffe gains new capabilities
  13. for specifying nets. In particular, the automatically generated layer names
  14. are not guaranteed to be forward-compatible.
  15. """
  16. from collections import OrderedDict, Counter
  17. from .proto import caffe_pb2
  18. from google import protobuf
  19. import six
  20. def param_name_dict():
  21. """Find out the correspondence between layer names and parameter names."""
  22. layer = caffe_pb2.LayerParameter()
  23. # get all parameter names (typically underscore case) and corresponding
  24. # type names (typically camel case), which contain the layer names
  25. # (note that not all parameters correspond to layers, but we'll ignore that)
  26. param_names = [f.name for f in layer.DESCRIPTOR.fields if f.name.endswith('_param')]
  27. param_type_names = [type(getattr(layer, s)).__name__ for s in param_names]
  28. # strip the final '_param' or 'Parameter'
  29. param_names = [s[:-len('_param')] for s in param_names]
  30. param_type_names = [s[:-len('Parameter')] for s in param_type_names]
  31. return dict(zip(param_type_names, param_names))
  32. def to_proto(*tops):
  33. """Generate a NetParameter that contains all layers needed to compute
  34. all arguments."""
  35. layers = OrderedDict()
  36. autonames = Counter()
  37. for top in tops:
  38. top.fn._to_proto(layers, {}, autonames)
  39. net = caffe_pb2.NetParameter()
  40. net.layer.extend(layers.values())
  41. return net
  42. def assign_proto(proto, name, val):
  43. """Assign a Python object to a protobuf message, based on the Python
  44. type (in recursive fashion). Lists become repeated fields/messages, dicts
  45. become messages, and other types are assigned directly. For convenience,
  46. repeated fields whose values are not lists are converted to single-element
  47. lists; e.g., `my_repeated_int_field=3` is converted to
  48. `my_repeated_int_field=[3]`."""
  49. is_repeated_field = hasattr(getattr(proto, name), 'extend')
  50. if is_repeated_field and not isinstance(val, list):
  51. val = [val]
  52. if isinstance(val, list):
  53. if isinstance(val[0], dict):
  54. for item in val:
  55. proto_item = getattr(proto, name).add()
  56. for k, v in six.iteritems(item):
  57. assign_proto(proto_item, k, v)
  58. else:
  59. getattr(proto, name).extend(val)
  60. elif isinstance(val, dict):
  61. for k, v in six.iteritems(val):
  62. assign_proto(getattr(proto, name), k, v)
  63. else:
  64. setattr(proto, name, val)
  65. class Top(object):
  66. """A Top specifies a single output blob (which could be one of several
  67. produced by a layer.)"""
  68. def __init__(self, fn, n):
  69. self.fn = fn
  70. self.n = n
  71. def to_proto(self):
  72. """Generate a NetParameter that contains all layers needed to compute
  73. this top."""
  74. return to_proto(self)
  75. def _to_proto(self, layers, names, autonames):
  76. return self.fn._to_proto(layers, names, autonames)
  77. class Function(object):
  78. """A Function specifies a layer, its parameters, and its inputs (which
  79. are Tops from other layers)."""
  80. def __init__(self, type_name, inputs, params):
  81. self.type_name = type_name
  82. self.inputs = inputs
  83. self.params = params
  84. self.ntop = self.params.get('ntop', 1)
  85. # use del to make sure kwargs are not double-processed as layer params
  86. if 'ntop' in self.params:
  87. del self.params['ntop']
  88. self.in_place = self.params.get('in_place', False)
  89. if 'in_place' in self.params:
  90. del self.params['in_place']
  91. self.tops = tuple(Top(self, n) for n in range(self.ntop))
  92. def _get_name(self, names, autonames):
  93. if self not in names and self.ntop > 0:
  94. names[self] = self._get_top_name(self.tops[0], names, autonames)
  95. elif self not in names:
  96. autonames[self.type_name] += 1
  97. names[self] = self.type_name + str(autonames[self.type_name])
  98. return names[self]
  99. def _get_top_name(self, top, names, autonames):
  100. if top not in names:
  101. autonames[top.fn.type_name] += 1
  102. names[top] = top.fn.type_name + str(autonames[top.fn.type_name])
  103. return names[top]
  104. def _to_proto(self, layers, names, autonames):
  105. if self in layers:
  106. return
  107. bottom_names = []
  108. for inp in self.inputs:
  109. inp._to_proto(layers, names, autonames)
  110. bottom_names.append(layers[inp.fn].top[inp.n])
  111. layer = caffe_pb2.LayerParameter()
  112. layer.type = self.type_name
  113. layer.bottom.extend(bottom_names)
  114. if self.in_place:
  115. layer.top.extend(layer.bottom)
  116. else:
  117. for top in self.tops:
  118. layer.top.append(self._get_top_name(top, names, autonames))
  119. layer.name = self._get_name(names, autonames)
  120. for k, v in six.iteritems(self.params):
  121. # special case to handle generic *params
  122. if k.endswith('param'):
  123. assign_proto(layer, k, v)
  124. else:
  125. try:
  126. assign_proto(getattr(layer,
  127. _param_names[self.type_name] + '_param'), k, v)
  128. except (AttributeError, KeyError):
  129. assign_proto(layer, k, v)
  130. layers[self] = layer
  131. class NetSpec(object):
  132. """A NetSpec contains a set of Tops (assigned directly as attributes).
  133. Calling NetSpec.to_proto generates a NetParameter containing all of the
  134. layers needed to produce all of the assigned Tops, using the assigned
  135. names."""
  136. def __init__(self):
  137. super(NetSpec, self).__setattr__('tops', OrderedDict())
  138. def __setattr__(self, name, value):
  139. self.tops[name] = value
  140. def __getattr__(self, name):
  141. return self.tops[name]
  142. def __setitem__(self, key, value):
  143. self.__setattr__(key, value)
  144. def __getitem__(self, item):
  145. return self.__getattr__(item)
  146. def to_proto(self):
  147. names = {v: k for k, v in six.iteritems(self.tops)}
  148. autonames = Counter()
  149. layers = OrderedDict()
  150. for name, top in six.iteritems(self.tops):
  151. top._to_proto(layers, names, autonames)
  152. net = caffe_pb2.NetParameter()
  153. net.layer.extend(layers.values())
  154. return net
  155. class Layers(object):
  156. """A Layers object is a pseudo-module which generates functions that specify
  157. layers; e.g., Layers().Convolution(bottom, kernel_size=3) will produce a Top
  158. specifying a 3x3 convolution applied to bottom."""
  159. def __getattr__(self, name):
  160. def layer_fn(*args, **kwargs):
  161. fn = Function(name, args, kwargs)
  162. if fn.ntop == 0:
  163. return fn
  164. elif fn.ntop == 1:
  165. return fn.tops[0]
  166. else:
  167. return fn.tops
  168. return layer_fn
  169. class Parameters(object):
  170. """A Parameters object is a pseudo-module which generates constants used
  171. in layer parameters; e.g., Parameters().Pooling.MAX is the value used
  172. to specify max pooling."""
  173. def __getattr__(self, name):
  174. class Param:
  175. def __getattr__(self, param_name):
  176. return getattr(getattr(caffe_pb2, name + 'Parameter'), param_name)
  177. return Param()
  178. _param_names = param_name_dict()
  179. layers = Layers()
  180. params = Parameters()