test_solver.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import unittest
  2. import tempfile
  3. import os
  4. import numpy as np
  5. import six
  6. import caffe
  7. from test_net import simple_net_file
  class TestSolver(unittest.TestCase):
      """Tests for caffe's Python SGDSolver bindings on a small test net.

      The net definition file comes from ``test_net.simple_net_file``; the
      solver definition is written to a temporary file in ``setUp``.
      """

      def setUp(self):
          """Build an SGDSolver over the simple test net and fill valid labels.

          Both the net file and the solver file are temporary. They are
          removed as soon as the solver has been constructed, and also when
          construction fails, so a broken build does not leak files.
          """
          self.num_output = 13
          net_f = simple_net_file(self.num_output)
          solver_f = None
          try:
              # delete=False keeps the file on disk after it is closed so
              # caffe can read it by name; it is removed in the finally below.
              with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
                  solver_f = f.name
                  f.write("net: '" + net_f + "'\n"
                          "test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9\n"
                          "weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75\n"
                          "display: 100 max_iter: 100 snapshot_after_train: false\n"
                          'snapshot_prefix: "model" ')
              self.solver = caffe.SGDSolver(solver_f)
              # also make sure get_solver runs
              caffe.get_solver(solver_f)
          finally:
              os.remove(net_f)
              if solver_f is not None:
                  os.remove(solver_f)
          caffe.set_mode_cpu()
          # Fill in valid labels: integers in [0, num_output). Presumably
          # consumed by a classification loss in the test net -- see
          # simple_net_file.
          self.solver.net.blobs['label'].data[...] = \
              np.random.randint(self.num_output,
                                size=self.solver.net.blobs['label'].data.shape)
          self.solver.test_nets[0].blobs['label'].data[...] = \
              np.random.randint(
                  self.num_output,
                  size=self.solver.test_nets[0].blobs['label'].data.shape)

      def test_solve(self):
          """solve() should run from iteration 0 up to max_iter."""
          self.assertEqual(self.solver.iter, 0)
          self.solver.solve()
          # 100 is the max_iter written into the solver file in setUp.
          self.assertEqual(self.solver.iter, 100)

      def test_net_memory(self):
          """Check that nets survive after the solver is destroyed."""
          nets = [self.solver.net] + list(self.solver.test_nets)
          self.assertEqual(len(nets), 2)
          del self.solver
          # Read every param and blob buffer (data and diff) of the
          # surviving nets. The sum itself is never asserted: the test only
          # fails if accessing the buffers after the solver is gone raises
          # (or crashes the process).
          total = 0
          for net in nets:
              for ps in six.itervalues(net.params):
                  for p in ps:
                      total += p.data.sum() + p.diff.sum()
              for bl in six.itervalues(net.blobs):
                  total += bl.data.sum() + bl.diff.sum()

      @staticmethod
      def _remove_existing(paths):
          """Delete each path in *paths* that currently exists as a file."""
          for path in paths:
              if os.path.isfile(path):
                  os.remove(path)

      def test_snapshot(self):
          """snapshot() should write the weights and solver state files."""
          # Names follow the "model" snapshot_prefix set in setUp and the
          # solver's current iteration (0, nothing has been solved yet).
          files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
          # Clear leftovers from an earlier (e.g. interrupted) run so a stale
          # file cannot make the existence check pass spuriously.
          self._remove_existing(files)
          self.solver.snapshot()
          try:
              # assertTrue, not bare assert: asserts vanish under python -O.
              for fn in files:
                  self.assertTrue(os.path.isfile(fn), '%s was not written' % fn)
          finally:
              # Clean up even if a check failed, so no files are left behind.
              self._remove_existing(files)