test_io.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import numpy as np
  2. import unittest
  3. import caffe
  4. class TestBlobProtoToArray(unittest.TestCase):
  5. def test_old_format(self):
  6. data = np.zeros((10,10))
  7. blob = caffe.proto.caffe_pb2.BlobProto()
  8. blob.data.extend(list(data.flatten()))
  9. shape = (1,1,10,10)
  10. blob.num, blob.channels, blob.height, blob.width = shape
  11. arr = caffe.io.blobproto_to_array(blob)
  12. self.assertEqual(arr.shape, shape)
  13. def test_new_format(self):
  14. data = np.zeros((10,10))
  15. blob = caffe.proto.caffe_pb2.BlobProto()
  16. blob.data.extend(list(data.flatten()))
  17. blob.shape.dim.extend(list(data.shape))
  18. arr = caffe.io.blobproto_to_array(blob)
  19. self.assertEqual(arr.shape, data.shape)
  20. def test_no_shape(self):
  21. data = np.zeros((10,10))
  22. blob = caffe.proto.caffe_pb2.BlobProto()
  23. blob.data.extend(list(data.flatten()))
  24. with self.assertRaises(ValueError):
  25. caffe.io.blobproto_to_array(blob)
  26. def test_scalar(self):
  27. data = np.ones((1)) * 123
  28. blob = caffe.proto.caffe_pb2.BlobProto()
  29. blob.data.extend(list(data.flatten()))
  30. arr = caffe.io.blobproto_to_array(blob)
  31. self.assertEqual(arr, 123)