# test_net.py (2.9 KB)
  1. import unittest
  2. import tempfile
  3. import os
  4. import numpy as np
  5. import six
  6. import caffe
  7. def simple_net_file(num_output):
  8. """Make a simple net prototxt, based on test_net.cpp, returning the name
  9. of the (temporary) file."""
  10. f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
  11. f.write("""name: 'testnet' force_backward: true
  12. layer { type: 'DummyData' name: 'data' top: 'data' top: 'label'
  13. dummy_data_param { num: 5 channels: 2 height: 3 width: 4
  14. num: 5 channels: 1 height: 1 width: 1
  15. data_filler { type: 'gaussian' std: 1 }
  16. data_filler { type: 'constant' } } }
  17. layer { type: 'Convolution' name: 'conv' bottom: 'data' top: 'conv'
  18. convolution_param { num_output: 11 kernel_size: 2 pad: 3
  19. weight_filler { type: 'gaussian' std: 1 }
  20. bias_filler { type: 'constant' value: 2 } }
  21. param { decay_mult: 1 } param { decay_mult: 0 }
  22. }
  23. layer { type: 'InnerProduct' name: 'ip' bottom: 'conv' top: 'ip'
  24. inner_product_param { num_output: """ + str(num_output) + """
  25. weight_filler { type: 'gaussian' std: 2.5 }
  26. bias_filler { type: 'constant' value: -3 } } }
  27. layer { type: 'SoftmaxWithLoss' name: 'loss' bottom: 'ip' bottom: 'label'
  28. top: 'loss' }""")
  29. f.close()
  30. return f.name
  31. class TestNet(unittest.TestCase):
  32. def setUp(self):
  33. self.num_output = 13
  34. net_file = simple_net_file(self.num_output)
  35. self.net = caffe.Net(net_file, caffe.TRAIN)
  36. # fill in valid labels
  37. self.net.blobs['label'].data[...] = \
  38. np.random.randint(self.num_output,
  39. size=self.net.blobs['label'].data.shape)
  40. os.remove(net_file)
  41. def test_memory(self):
  42. """Check that holding onto blob data beyond the life of a Net is OK"""
  43. params = sum(map(list, six.itervalues(self.net.params)), [])
  44. blobs = self.net.blobs.values()
  45. del self.net
  46. # now sum everything (forcing all memory to be read)
  47. total = 0
  48. for p in params:
  49. total += p.data.sum() + p.diff.sum()
  50. for bl in blobs:
  51. total += bl.data.sum() + bl.diff.sum()
  52. def test_forward_backward(self):
  53. self.net.forward()
  54. self.net.backward()
  55. def test_inputs_outputs(self):
  56. self.assertEqual(self.net.inputs, [])
  57. self.assertEqual(self.net.outputs, ['loss'])
  58. def test_save_and_read(self):
  59. f = tempfile.NamedTemporaryFile(mode='w+', delete=False)
  60. f.close()
  61. self.net.save(f.name)
  62. net_file = simple_net_file(self.num_output)
  63. net2 = caffe.Net(net_file, f.name, caffe.TRAIN)
  64. os.remove(net_file)
  65. os.remove(f.name)
  66. for name in self.net.params:
  67. for i in range(len(self.net.params[name])):
  68. self.assertEqual(abs(self.net.params[name][i].data
  69. - net2.params[name][i].data).sum(), 0)