12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- #!/usr/bin/env python
- """
- Draw a graph of the net architecture.
- """
- from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
- from google.protobuf import text_format
- import caffe
- import caffe.draw
- from caffe.proto import caffe_pb2
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Two positional arguments are required (the input network prototxt file
    and the output image file).  ``--rankdir`` defaults to ``LR`` and
    ``--phase`` defaults to ``ALL``.  Arguments are taken from ``sys.argv``.
    """
    rankdir_help = ('One of TB (top-bottom, i.e., vertical), '
                    'RL (right-left, i.e., horizontal), or another '
                    'valid dot option; see '
                    'http://www.graphviz.org/doc/info/'
                    'attrs.html#k:rankdir')
    phase_help = ('Which network phase to draw: can be TRAIN, '
                  'TEST, or ALL. If ALL, then all layers are drawn '
                  'regardless of phase.')

    # ArgumentDefaultsHelpFormatter appends each option's default to its
    # help text in the --help output.
    cli = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter,
                         description=__doc__)

    # Required positionals.
    cli.add_argument('input_net_proto_file',
                     help='Input network prototxt file')
    cli.add_argument('output_image_file', help='Output image file')

    # Optional layout / filtering switches.
    cli.add_argument('--rankdir', default='LR', help=rankdir_help)
    cli.add_argument('--phase', default="ALL", help=phase_help)

    return cli.parse_args()
def main():
    """Parse the command line and draw the network graph to an image file.

    Reads the prototxt file named on the command line into a
    ``caffe_pb2.NetParameter`` message, translates the ``--phase`` string
    into the matching caffe phase constant, and passes everything to
    ``caffe.draw.draw_net_to_file``.

    Raises:
        ValueError: if ``--phase`` is anything other than TRAIN, TEST or ALL.
    """
    args = parse_args()

    net = caffe_pb2.NetParameter()
    # Use a context manager so the prototxt handle is closed as soon as it
    # has been read, instead of being left for the garbage collector.
    with open(args.input_net_proto_file) as proto_file:
        text_format.Merge(proto_file.read(), net)

    print('Drawing net to %s' % args.output_image_file)

    # "ALL" leaves phase as None; per the --phase help text, all layers are
    # then drawn regardless of phase.
    phase = None
    if args.phase == "TRAIN":
        phase = caffe.TRAIN
    elif args.phase == "TEST":
        phase = caffe.TEST
    elif args.phase != "ALL":
        raise ValueError("Unknown phase: " + args.phase)

    caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
                                phase)
# Run the drawing tool only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
|