test_python_layer.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import unittest
  2. import tempfile
  3. import os
  4. import six
  5. import caffe
  6. class SimpleLayer(caffe.Layer):
  7. """A layer that just multiplies by ten"""
  8. def setup(self, bottom, top):
  9. pass
  10. def reshape(self, bottom, top):
  11. top[0].reshape(*bottom[0].data.shape)
  12. def forward(self, bottom, top):
  13. top[0].data[...] = 10 * bottom[0].data
  14. def backward(self, top, propagate_down, bottom):
  15. bottom[0].diff[...] = 10 * top[0].diff
  16. class ExceptionLayer(caffe.Layer):
  17. """A layer for checking exceptions from Python"""
  18. def setup(self, bottom, top):
  19. raise RuntimeError
  20. class ParameterLayer(caffe.Layer):
  21. """A layer that just multiplies by ten"""
  22. def setup(self, bottom, top):
  23. self.blobs.add_blob(1)
  24. self.blobs[0].data[0] = 0
  25. def reshape(self, bottom, top):
  26. top[0].reshape(*bottom[0].data.shape)
  27. def forward(self, bottom, top):
  28. pass
  29. def backward(self, top, propagate_down, bottom):
  30. self.blobs[0].diff[0] = 1
  31. def python_net_file():
  32. with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
  33. f.write("""name: 'pythonnet' force_backward: true
  34. input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
  35. layer { type: 'Python' name: 'one' bottom: 'data' top: 'one'
  36. python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }
  37. layer { type: 'Python' name: 'two' bottom: 'one' top: 'two'
  38. python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }
  39. layer { type: 'Python' name: 'three' bottom: 'two' top: 'three'
  40. python_param { module: 'test_python_layer' layer: 'SimpleLayer' } }""")
  41. return f.name
  42. def exception_net_file():
  43. with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
  44. f.write("""name: 'pythonnet' force_backward: true
  45. input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
  46. layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
  47. python_param { module: 'test_python_layer' layer: 'ExceptionLayer' } }
  48. """)
  49. return f.name
  50. def parameter_net_file():
  51. with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
  52. f.write("""name: 'pythonnet' force_backward: true
  53. input: 'data' input_shape { dim: 10 dim: 9 dim: 8 }
  54. layer { type: 'Python' name: 'layer' bottom: 'data' top: 'top'
  55. python_param { module: 'test_python_layer' layer: 'ParameterLayer' } }
  56. """)
  57. return f.name
  58. @unittest.skipIf('Python' not in caffe.layer_type_list(),
  59. 'Caffe built without Python layer support')
  60. class TestPythonLayer(unittest.TestCase):
  61. def setUp(self):
  62. net_file = python_net_file()
  63. self.net = caffe.Net(net_file, caffe.TRAIN)
  64. os.remove(net_file)
  65. def test_forward(self):
  66. x = 8
  67. self.net.blobs['data'].data[...] = x
  68. self.net.forward()
  69. for y in self.net.blobs['three'].data.flat:
  70. self.assertEqual(y, 10**3 * x)
  71. def test_backward(self):
  72. x = 7
  73. self.net.blobs['three'].diff[...] = x
  74. self.net.backward()
  75. for y in self.net.blobs['data'].diff.flat:
  76. self.assertEqual(y, 10**3 * x)
  77. def test_reshape(self):
  78. s = 4
  79. self.net.blobs['data'].reshape(s, s, s, s)
  80. self.net.forward()
  81. for blob in six.itervalues(self.net.blobs):
  82. for d in blob.data.shape:
  83. self.assertEqual(s, d)
  84. def test_exception(self):
  85. net_file = exception_net_file()
  86. self.assertRaises(RuntimeError, caffe.Net, net_file, caffe.TEST)
  87. os.remove(net_file)
  88. def test_parameter(self):
  89. net_file = parameter_net_file()
  90. net = caffe.Net(net_file, caffe.TRAIN)
  91. # Test forward and backward
  92. net.forward()
  93. net.backward()
  94. layer = net.layers[list(net._layer_names).index('layer')]
  95. self.assertEqual(layer.blobs[0].data[0], 0)
  96. self.assertEqual(layer.blobs[0].diff[0], 1)
  97. layer.blobs[0].data[0] += layer.blobs[0].diff[0]
  98. self.assertEqual(layer.blobs[0].data[0], 1)
  99. # Test saving and loading
  100. h, caffemodel_file = tempfile.mkstemp()
  101. net.save(caffemodel_file)
  102. layer.blobs[0].data[0] = -1
  103. self.assertEqual(layer.blobs[0].data[0], -1)
  104. net.copy_from(caffemodel_file)
  105. self.assertEqual(layer.blobs[0].data[0], 1)
  106. os.remove(caffemodel_file)
  107. # Test weight sharing
  108. net2 = caffe.Net(net_file, caffe.TRAIN)
  109. net2.share_with(net)
  110. layer = net.layers[list(net2._layer_names).index('layer')]
  111. self.assertEqual(layer.blobs[0].data[0], 1)
  112. os.remove(net_file)