123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
- import unittest
- import tempfile
- import os
- import six
- import caffe
- class SimpleLayer(caffe.Layer):
- """A layer that just multiplies by ten"""
- def setup(self, bottom, top):
- pass
- def reshape(self, bottom, top):
- top[0].reshape(*bottom[0].data.shape)
- def forward(self, bottom, top):
- top[0].data[...] = 10 * bottom[0].data
- def backward(self, top, propagate_down, bottom):
- bottom[0].diff[...] = 10 * top[0].diff
- class ExceptionLayer(caffe.Layer):
- """A layer for checking exceptions from Python"""
- def setup(self, bottom, top):
- raise RuntimeError
- class ParameterLayer(caffe.Layer):
- """A layer that just multiplies by ten"""
- def setup(self, bottom, top):
- self.blobs.add_blob(1)
- self.blobs[0].data[0] = 0
- def reshape(self, bottom, top):
- top[0].reshape(*bottom[0].data.shape)
- def forward(self, bottom, top):
- pass
- def backward(self, top, propagate_down, bottom):
- self.blobs[0].diff[0] = 1
- def python_net_file():
- with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
- f.write("""name: 'pythonnet' force_backward: true
- input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
- layer { type: 'Python' name: 'one' bottom: 'data' top: 'one'
- python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }
- layer { type: 'Python' name: 'two' bottom: 'one' top: 'two'
- python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }
- layer { type: 'Python' name: 'three' bottom: 'two' top: 'three'
- python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }""")
- return f.name
- def exception_net_file():
- with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
- f.write("""name: 'pythonnet' force_backward: true
- input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
- layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
- python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
- """)
- return f.name
- def parameter_net_file():
- with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
- f.write("""name: 'pythonnet' force_backward: true
- input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
- layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
- python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
- """)
- return f.name
- @unittest.skipIf('Python' not in caffe.layer_type_list(),
- 'Caffe built without Python layer support')
- class TestPythonLayer(unittest.TestCase):
- def setUp(self):
- net_file = python_net_file()
- self.net = caffe.Net(net_file, caffe.TRAIN)
- os.remove(net_file)
- def test_forward(self):
- x = 8
- self.net.blobs['data'].data[...] = x
- self.net.forward()
- for y in self.net.blobs['three'].data.flat:
- self.assertEqual(y, 10**3 * x)
- def test_backward(self):
- x = 7
- self.net.blobs['three'].diff[...] = x
- self.net.backward()
- for y in self.net.blobs['data'].diff.flat:
- self.assertEqual(y, 10**3 * x)
- def test_reshape(self):
- s = 4
- self.net.blobs['data'].reshape(s, s, s, s)
- self.net.forward()
- for blob in six.itervalues(self.net.blobs):
- for d in blob.data.shape:
- self.assertEqual(s, d)
- def test_exception(self):
- net_file = exception_net_file()
- self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
- os.remove(net_file)
- def test_parameter(self):
- net_file = parameter_net_file()
- net = caffe.Net(net_file, caffe.TRAIN)
- # Test forward and backward
- net.forward()
- net.backward()
- layer = net.layers[list(net._layer_names).index('layer')]
- self.assertEqual(layer.blobs[0].data[0], 0)
- self.assertEqual(layer.blobs[0].diff[0], 1)
- layer.blobs[0].data[0] += layer.blobs[0].diff[0]
- self.assertEqual(layer.blobs[0].data[0], 1)
- # Test saving and loading
- h, caffemodel_file = tempfile.mkstemp()
- net.save(caffemodel_file)
- layer.blobs[0].data[0] = -1
- self.assertEqual(layer.blobs[0].data[0], -1)
- net.copy_from(caffemodel_file)
- self.assertEqual(layer.blobs[0].data[0], 1)
- os.remove(caffemodel_file)
-
- # Test weight sharing
- net2 = caffe.Net(net_file, caffe.TRAIN)
- net2.share_with(net)
- layer = net.layers[list(net2._layer_names).index('layer')]
- self.assertEqual(layer.blobs[0].data[0], 1)
- os.remove(net_file)
|