#!/usr/bin/env python
"""
Draw a graph of the net architecture.
"""
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

from google.protobuf import text_format
import caffe
import caffe.draw
from caffe.proto import caffe_pb2
  10. def parse_args():
  11. """Parse input arguments
  12. """
  13. parser = ArgumentParser(description=__doc__,
  14. formatter_class=ArgumentDefaultsHelpFormatter)
  15. parser.add_argument('input_net_proto_file',
  16. help='Input network prototxt file')
  17. parser.add_argument('output_image_file',
  18. help='Output image file')
  19. parser.add_argument('--rankdir',
  20. help=('One of TB (top-bottom, i.e., vertical), '
  21. 'RL (right-left, i.e., horizontal), or another '
  22. 'valid dot option; see '
  23. 'http://www.graphviz.org/doc/info/'
  24. 'attrs.html#k:rankdir'),
  25. default='LR')
  26. parser.add_argument('--phase',
  27. help=('Which network phase to draw: can be TRAIN, '
  28. 'TEST, or ALL. If ALL, then all layers are drawn '
  29. 'regardless of phase.'),
  30. default="ALL")
  31. args = parser.parse_args()
  32. return args
  33. def main():
  34. args = parse_args()
  35. net = caffe_pb2.NetParameter()
  36. text_format.Merge(open(args.input_net_proto_file).read(), net)
  37. print('Drawing net to %s' % args.output_image_file)
  38. phase=None;
  39. if args.phase == "TRAIN":
  40. phase = caffe.TRAIN
  41. elif args.phase == "TEST":
  42. phase = caffe.TEST
  43. elif args.phase != "ALL":
  44. raise ValueError("Unknown phase: " + args.phase)
  45. caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
  46. phase)
  47. if __name__ == '__main__':
  48. main()