# draw.py (8.6 KB)
"""
Caffe network visualization: draw the NetParameter protocol buffer.

.. note::

    This requires pydot>=1.0.2, which is not included in requirements.txt since
    it requires graphviz and other prerequisites outside the scope of
    Caffe.
"""
  8. from caffe.proto import caffe_pb2
  9. """
  10. pydot is not supported under python 3 and pydot2 doesn't work properly.
  11. pydotplus works nicely (pip install pydotplus)
  12. """
  13. try:
  14. # Try to load pydotplus
  15. import pydotplus as pydot
  16. except ImportError:
  17. import pydot
  18. # Internal layer and blob styles.
  19. LAYER_STYLE_DEFAULT = {'shape': 'record',
  20. 'fillcolor': '#6495ED',
  21. 'style': 'filled'}
  22. NEURON_LAYER_STYLE = {'shape': 'record',
  23. 'fillcolor': '#90EE90',
  24. 'style': 'filled'}
  25. BLOB_STYLE = {'shape': 'octagon',
  26. 'fillcolor': '#E0E0E0',
  27. 'style': 'filled'}
  28. def get_pooling_types_dict():
  29. """Get dictionary mapping pooling type number to type name
  30. """
  31. desc = caffe_pb2.PoolingParameter.PoolMethod.DESCRIPTOR
  32. d = {}
  33. for k, v in desc.values_by_name.items():
  34. d[v.number] = k
  35. return d
  36. def get_edge_label(layer):
  37. """Define edge label based on layer type.
  38. """
  39. if layer.type == 'Data':
  40. edge_label = 'Batch ' + str(layer.data_param.batch_size)
  41. elif layer.type == 'Convolution' or layer.type == 'Deconvolution':
  42. edge_label = str(layer.convolution_param.num_output)
  43. elif layer.type == 'InnerProduct':
  44. edge_label = str(layer.inner_product_param.num_output)
  45. else:
  46. edge_label = '""'
  47. return edge_label
  48. def get_layer_label(layer, rankdir):
  49. """Define node label based on layer type.
  50. Parameters
  51. ----------
  52. layer : ?
  53. rankdir : {'LR', 'TB', 'BT'}
  54. Direction of graph layout.
  55. Returns
  56. -------
  57. string :
  58. A label for the current layer
  59. """
  60. if rankdir in ('TB', 'BT'):
  61. # If graph orientation is vertical, horizontal space is free and
  62. # vertical space is not; separate words with spaces
  63. separator = ' '
  64. else:
  65. # If graph orientation is horizontal, vertical space is free and
  66. # horizontal space is not; separate words with newlines
  67. separator = '\\n'
  68. if layer.type == 'Convolution' or layer.type == 'Deconvolution':
  69. # Outer double quotes needed or else colon characters don't parse
  70. # properly
  71. node_label = '"%s%s(%s)%skernel size: %d%sstride: %d%spad: %d"' %\
  72. (layer.name,
  73. separator,
  74. layer.type,
  75. separator,
  76. layer.convolution_param.kernel_size[0] if len(layer.convolution_param.kernel_size._values) else 1,
  77. separator,
  78. layer.convolution_param.stride[0] if len(layer.convolution_param.stride._values) else 1,
  79. separator,
  80. layer.convolution_param.pad[0] if len(layer.convolution_param.pad._values) else 0)
  81. elif layer.type == 'Pooling':
  82. pooling_types_dict = get_pooling_types_dict()
  83. node_label = '"%s%s(%s %s)%skernel size: %d%sstride: %d%spad: %d"' %\
  84. (layer.name,
  85. separator,
  86. pooling_types_dict[layer.pooling_param.pool],
  87. layer.type,
  88. separator,
  89. layer.pooling_param.kernel_size,
  90. separator,
  91. layer.pooling_param.stride,
  92. separator,
  93. layer.pooling_param.pad)
  94. else:
  95. node_label = '"%s%s(%s)"' % (layer.name, separator, layer.type)
  96. return node_label
  97. def choose_color_by_layertype(layertype):
  98. """Define colors for nodes based on the layer type.
  99. """
  100. color = '#6495ED' # Default
  101. if layertype == 'Convolution' or layertype == 'Deconvolution':
  102. color = '#FF5050'
  103. elif layertype == 'Pooling':
  104. color = '#FF9900'
  105. elif layertype == 'InnerProduct':
  106. color = '#CC33FF'
  107. return color
  108. def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
  109. """Create a data structure which represents the `caffe_net`.
  110. Parameters
  111. ----------
  112. caffe_net : object
  113. rankdir : {'LR', 'TB', 'BT'}
  114. Direction of graph layout.
  115. label_edges : boolean, optional
  116. Label the edges (default is True).
  117. phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
  118. Include layers from this network phase. If None, include all layers.
  119. (the default is None)
  120. Returns
  121. -------
  122. pydot graph object
  123. """
  124. pydot_graph = pydot.Dot(caffe_net.name if caffe_net.name else 'Net',
  125. graph_type='digraph',
  126. rankdir=rankdir)
  127. pydot_nodes = {}
  128. pydot_edges = []
  129. for layer in caffe_net.layer:
  130. if phase is not None:
  131. included = False
  132. if len(layer.include) == 0:
  133. included = True
  134. if len(layer.include) > 0 and len(layer.exclude) > 0:
  135. raise ValueError('layer ' + layer.name + ' has both include '
  136. 'and exclude specified.')
  137. for layer_phase in layer.include:
  138. included = included or layer_phase.phase == phase
  139. for layer_phase in layer.exclude:
  140. included = included and not layer_phase.phase == phase
  141. if not included:
  142. continue
  143. node_label = get_layer_label(layer, rankdir)
  144. node_name = "%s_%s" % (layer.name, layer.type)
  145. if (len(layer.bottom) == 1 and len(layer.top) == 1 and
  146. layer.bottom[0] == layer.top[0]):
  147. # We have an in-place neuron layer.
  148. pydot_nodes[node_name] = pydot.Node(node_label,
  149. **NEURON_LAYER_STYLE)
  150. else:
  151. layer_style = LAYER_STYLE_DEFAULT
  152. layer_style['fillcolor'] = choose_color_by_layertype(layer.type)
  153. pydot_nodes[node_name] = pydot.Node(node_label, **layer_style)
  154. for bottom_blob in layer.bottom:
  155. pydot_nodes[bottom_blob + '_blob'] = pydot.Node('%s' % bottom_blob,
  156. **BLOB_STYLE)
  157. edge_label = '""'
  158. pydot_edges.append({'src': bottom_blob + '_blob',
  159. 'dst': node_name,
  160. 'label': edge_label})
  161. for top_blob in layer.top:
  162. pydot_nodes[top_blob + '_blob'] = pydot.Node('%s' % (top_blob))
  163. if label_edges:
  164. edge_label = get_edge_label(layer)
  165. else:
  166. edge_label = '""'
  167. pydot_edges.append({'src': node_name,
  168. 'dst': top_blob + '_blob',
  169. 'label': edge_label})
  170. # Now, add the nodes and edges to the graph.
  171. for node in pydot_nodes.values():
  172. pydot_graph.add_node(node)
  173. for edge in pydot_edges:
  174. pydot_graph.add_edge(
  175. pydot.Edge(pydot_nodes[edge['src']],
  176. pydot_nodes[edge['dst']],
  177. label=edge['label']))
  178. return pydot_graph
  179. def draw_net(caffe_net, rankdir, ext='png', phase=None):
  180. """Draws a caffe net and returns the image string encoded using the given
  181. extension.
  182. Parameters
  183. ----------
  184. caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
  185. ext : string, optional
  186. The image extension (the default is 'png').
  187. phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
  188. Include layers from this network phase. If None, include all layers.
  189. (the default is None)
  190. Returns
  191. -------
  192. string :
  193. Postscript representation of the graph.
  194. """
  195. return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)
  196. def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
  197. """Draws a caffe net, and saves it to file using the format given as the
  198. file extension. Use '.raw' to output raw text that you can manually feed
  199. to graphviz to draw graphs.
  200. Parameters
  201. ----------
  202. caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
  203. filename : string
  204. The path to a file where the networks visualization will be stored.
  205. rankdir : {'LR', 'TB', 'BT'}
  206. Direction of graph layout.
  207. phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
  208. Include layers from this network phase. If None, include all layers.
  209. (the default is None)
  210. """
  211. ext = filename[filename.rfind('.')+1:]
  212. with open(filename, 'wb') as fid:
  213. fid.write(draw_net(caffe_net, rankdir, ext, phase))