// _caffe.cpp — NOTE(review): the original paste carried a source-viewer
// header ("_caffe.cpp 14 KB") and a concatenated line-number gutter here;
// both were extraction artifacts, not code, and have been reduced to this
// comment.
#include <Python.h>  // NOLINT(build/include_alpha)

// Produce deprecation warnings (needs to come before arrayobject.h inclusion).
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include <boost/make_shared.hpp>
#include <boost/python.hpp>
#include <boost/python/raw_function.hpp>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
#include <boost/python/enum.hpp>
#include <numpy/arrayobject.h>

// these need to be included after boost on OS X
#include <string>  // NOLINT(build/include_order)
#include <vector>  // NOLINT(build/include_order)
#include <fstream>  // NOLINT
#include <stdexcept>  // NOLINT: std::runtime_error is thrown throughout

#include "caffe/caffe.hpp"
#include "caffe/layers/memory_data_layer.hpp"
#include "caffe/layers/python_layer.hpp"
#include "caffe/sgd_solvers.hpp"

// Temporary solution for numpy < 1.7 versions: old macro, no promises.
// You're strongly advised to upgrade to >= 1.7.
#ifndef NPY_ARRAY_C_CONTIGUOUS
#define NPY_ARRAY_C_CONTIGUOUS NPY_C_CONTIGUOUS
#define PyArray_SetBaseObject(arr, x) (PyArray_BASE(arr) = (x))
#endif
namespace bp = boost::python;

namespace caffe {

// For Python, for now, we'll just always use float as the type.
typedef float Dtype;
// numpy type code matching Dtype above; keep the two in sync.
const int NPY_DTYPE = NPY_FLOAT32;
// Selecting mode.
// Thin wrappers so Python can flip the global Caffe singleton between
// CPU and GPU execution.
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }
  32. // For convenience, check that input files can be opened, and raise an
  33. // exception that boost will send to Python if not (caffe could still crash
  34. // later if the input files are disturbed before they are actually used, but
  35. // this saves frustration in most cases).
  36. static void CheckFile(const string& filename) {
  37. std::ifstream f(filename.c_str());
  38. if (!f.good()) {
  39. f.close();
  40. throw std::runtime_error("Could not open file " + filename);
  41. }
  42. f.close();
  43. }
  44. void CheckContiguousArray(PyArrayObject* arr, string name,
  45. int channels, int height, int width) {
  46. if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) {
  47. throw std::runtime_error(name + " must be C contiguous");
  48. }
  49. if (PyArray_NDIM(arr) != 4) {
  50. throw std::runtime_error(name + " must be 4-d");
  51. }
  52. if (PyArray_TYPE(arr) != NPY_FLOAT32) {
  53. throw std::runtime_error(name + " must be float32");
  54. }
  55. if (PyArray_DIMS(arr)[1] != channels) {
  56. throw std::runtime_error(name + " has wrong number of channels");
  57. }
  58. if (PyArray_DIMS(arr)[2] != height) {
  59. throw std::runtime_error(name + " has wrong height");
  60. }
  61. if (PyArray_DIMS(arr)[3] != width) {
  62. throw std::runtime_error(name + " has wrong width");
  63. }
  64. }
  65. // Net constructor for passing phase as int
  66. shared_ptr<Net<Dtype> > Net_Init(
  67. string param_file, int phase) {
  68. CheckFile(param_file);
  69. shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
  70. static_cast<Phase>(phase)));
  71. return net;
  72. }
  73. // Net construct-and-load convenience constructor
  74. shared_ptr<Net<Dtype> > Net_Init_Load(
  75. string param_file, string pretrained_param_file, int phase) {
  76. CheckFile(param_file);
  77. CheckFile(pretrained_param_file);
  78. shared_ptr<Net<Dtype> > net(new Net<Dtype>(param_file,
  79. static_cast<Phase>(phase)));
  80. net->CopyTrainedLayersFrom(pretrained_param_file);
  81. return net;
  82. }
  83. void Net_Save(const Net<Dtype>& net, string filename) {
  84. NetParameter net_param;
  85. net.ToProto(&net_param, false);
  86. WriteProtoToBinaryFile(net_param, filename.c_str());
  87. }
// Wire numpy arrays directly into the net's input MemoryDataLayer.
// NOTE(review): the layer keeps raw pointers into the numpy buffers, so the
// arrays must outlive the net; the binding enforces this with
// with_custodian_and_ward at registration time.
void Net_SetInputArrays(Net<Dtype>* net, bp::object data_obj,
    bp::object labels_obj) {
  // check that this network has an input MemoryDataLayer
  shared_ptr<MemoryDataLayer<Dtype> > md_layer =
      boost::dynamic_pointer_cast<MemoryDataLayer<Dtype> >(net->layers()[0]);
  if (!md_layer) {
    throw std::runtime_error("set_input_arrays may only be called if the"
        " first layer is a MemoryDataLayer");
  }
  // check that we were passed appropriately-sized contiguous memory
  PyArrayObject* data_arr =
      reinterpret_cast<PyArrayObject*>(data_obj.ptr());
  PyArrayObject* labels_arr =
      reinterpret_cast<PyArrayObject*>(labels_obj.ptr());
  // data must match the layer's (channels, height, width); labels are
  // required to be N x 1 x 1 x 1 float32.
  CheckContiguousArray(data_arr, "data array", md_layer->channels(),
      md_layer->height(), md_layer->width());
  CheckContiguousArray(labels_arr, "labels array", 1, 1, 1);
  if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) {
    throw std::runtime_error("data and labels must have the same first"
        " dimension");
  }
  if (PyArray_DIMS(data_arr)[0] % md_layer->batch_size() != 0) {
    throw std::runtime_error("first dimensions of input arrays must be a"
        " multiple of batch size");
  }
  // Hand the raw buffers to the layer; the last argument is the number of
  // items provided (the leading dimension of both arrays).
  md_layer->Reset(static_cast<Dtype*>(PyArray_DATA(data_arr)),
      static_cast<Dtype*>(PyArray_DATA(labels_arr)),
      PyArray_DIMS(data_arr)[0]);
}
  117. Solver<Dtype>* GetSolverFromFile(const string& filename) {
  118. SolverParameter param;
  119. ReadSolverParamsFromTextFileOrDie(filename, &param);
  120. return SolverRegistry<Dtype>::CreateSolver(param);
  121. }
// Boost.Python result-converter generator used by NdarrayCallPolicies;
// only the Dtype* case is implemented (see the apply<Dtype*> specialization
// elsewhere in this file).
struct NdarrayConverterGenerator {
  template <typename T> struct apply;
};
// Converter for Dtype* return values: wrap the raw pointer in a 0-d numpy
// array for the moment; the real shape is attached afterwards in
// NdarrayCallPolicies::postcall.
template <>
struct NdarrayConverterGenerator::apply<Dtype*> {
  struct type {
    PyObject* operator() (Dtype* data) const {
      // Just store the data pointer, and add the shape information in postcall.
      return PyArray_SimpleNewFromData(0, NULL, NPY_DTYPE, data);
    }
    // Advertise the Python type produced (ndarray), for docstrings.
    const PyTypeObject* get_pytype() {
      return &PyArray_Type;
    }
  };
};
// Call policy for Blob.data / Blob.diff: after the wrapped accessor has
// returned a shapeless array holding just the data pointer, rebuild a numpy
// array carrying the blob's real shape, and tie its lifetime to the blob so
// the underlying buffer cannot be freed while Python still references it.
struct NdarrayCallPolicies : public bp::default_call_policies {
  typedef NdarrayConverterGenerator result_converter;
  PyObject* postcall(PyObject* pyargs, PyObject* result) {
    // pyargs[0] is `self`, i.e. the shared_ptr<Blob> the accessor was
    // invoked on.
    bp::object pyblob = bp::extract<bp::tuple>(pyargs)()[0];
    shared_ptr<Blob<Dtype> > blob =
        bp::extract<shared_ptr<Blob<Dtype> > >(pyblob);
    // Free the temporary pointer-holding array, and construct a new one with
    // the shape information from the blob.
    void* data = PyArray_DATA(reinterpret_cast<PyArrayObject*>(result));
    Py_DECREF(result);
    const int num_axes = blob->num_axes();
    vector<npy_intp> dims(blob->shape().begin(), blob->shape().end());
    PyObject *arr_obj = PyArray_SimpleNewFromData(num_axes, dims.data(),
        NPY_FLOAT32, data);
    // SetBaseObject steals a ref, so we need to INCREF.
    Py_INCREF(pyblob.ptr());
    PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr_obj),
        pyblob.ptr());
    return arr_obj;
  }
};
  158. bp::object Blob_Reshape(bp::tuple args, bp::dict kwargs) {
  159. if (bp::len(kwargs) > 0) {
  160. throw std::runtime_error("Blob.reshape takes no kwargs");
  161. }
  162. Blob<Dtype>* self = bp::extract<Blob<Dtype>*>(args[0]);
  163. vector<int> shape(bp::len(args) - 1);
  164. for (int i = 1; i < bp::len(args); ++i) {
  165. shape[i - 1] = bp::extract<int>(args[i]);
  166. }
  167. self->Reshape(shape);
  168. // We need to explicitly return None to use bp::raw_function.
  169. return bp::object();
  170. }
  171. bp::object BlobVec_add_blob(bp::tuple args, bp::dict kwargs) {
  172. if (bp::len(kwargs) > 0) {
  173. throw std::runtime_error("BlobVec.add_blob takes no kwargs");
  174. }
  175. typedef vector<shared_ptr<Blob<Dtype> > > BlobVec;
  176. BlobVec* self = bp::extract<BlobVec*>(args[0]);
  177. vector<int> shape(bp::len(args) - 1);
  178. for (int i = 1; i < bp::len(args); ++i) {
  179. shape[i - 1] = bp::extract<int>(args[i]);
  180. }
  181. self->push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
  182. // We need to explicitly return None to use bp::raw_function.
  183. return bp::object();
  184. }
// Generate overloads so Solver.solve() may be called with 0 or 1 arguments
// (the optional resume-state filename).
BOOST_PYTHON_MEMBER_FUNCTION_OVERLOADS(SolveOverloads, Solve, 0, 1);

// Module definition: everything exported to Python as `_caffe` is
// registered here.
BOOST_PYTHON_MODULE(_caffe) {
  // below, we prepend an underscore to methods that will be replaced
  // in Python
  bp::scope().attr("__version__") = AS_STRING(CAFFE_VERSION);

  // Caffe utility functions
  bp::def("set_mode_cpu", &set_mode_cpu);
  bp::def("set_mode_gpu", &set_mode_gpu);
  bp::def("set_device", &Caffe::SetDevice);
  bp::def("set_random_seed", &Caffe::set_random_seed);
  bp::def("layer_type_list", &LayerRegistry<Dtype>::LayerTypeList);

  bp::enum_<Phase>("Phase")
      .value("TRAIN", caffe::TRAIN)
      .value("TEST", caffe::TEST)
      .export_values();

  bp::class_<Net<Dtype>, shared_ptr<Net<Dtype> >, boost::noncopyable >("Net",
    bp::no_init)
    .def("__init__", bp::make_constructor(&Net_Init))
    .def("__init__", bp::make_constructor(&Net_Init_Load))
    .def("_forward", &Net<Dtype>::ForwardFromTo)
    .def("_backward", &Net<Dtype>::BackwardFromTo)
    .def("reshape", &Net<Dtype>::Reshape)
    // The cast is to select a particular overload.
    .def("copy_from", static_cast<void (Net<Dtype>::*)(const string)>(
        &Net<Dtype>::CopyTrainedLayersFrom))
    .def("share_with", &Net<Dtype>::ShareTrainedLayersWith)
    .add_property("_blob_loss_weights", bp::make_function(
        &Net<Dtype>::blob_loss_weights, bp::return_internal_reference<>()))
    .def("_bottom_ids", bp::make_function(&Net<Dtype>::bottom_ids,
        bp::return_value_policy<bp::copy_const_reference>()))
    .def("_top_ids", bp::make_function(&Net<Dtype>::top_ids,
        bp::return_value_policy<bp::copy_const_reference>()))
    .add_property("_blobs", bp::make_function(&Net<Dtype>::blobs,
        bp::return_internal_reference<>()))
    .add_property("layers", bp::make_function(&Net<Dtype>::layers,
        bp::return_internal_reference<>()))
    .add_property("_blob_names", bp::make_function(&Net<Dtype>::blob_names,
        bp::return_value_policy<bp::copy_const_reference>()))
    .add_property("_layer_names", bp::make_function(&Net<Dtype>::layer_names,
        bp::return_value_policy<bp::copy_const_reference>()))
    .add_property("_inputs", bp::make_function(&Net<Dtype>::input_blob_indices,
        bp::return_value_policy<bp::copy_const_reference>()))
    .add_property("_outputs",
        bp::make_function(&Net<Dtype>::output_blob_indices,
        bp::return_value_policy<bp::copy_const_reference>()))
    // Keep the numpy arrays (args 2 and 3) alive at least as long as the
    // net (arg 1): the MemoryDataLayer holds raw pointers into their
    // buffers after _set_input_arrays.
    .def("_set_input_arrays", &Net_SetInputArrays,
        bp::with_custodian_and_ward<1, 2, bp::with_custodian_and_ward<1, 3> >())
    .def("save", &Net_Save);

  bp::class_<Blob<Dtype>, shared_ptr<Blob<Dtype> >, boost::noncopyable>(
    "Blob", bp::no_init)
    .add_property("shape",
        bp::make_function(
            static_cast<const vector<int>& (Blob<Dtype>::*)() const>(
                &Blob<Dtype>::shape),
            bp::return_value_policy<bp::copy_const_reference>()))
    .add_property("num", &Blob<Dtype>::num)
    .add_property("channels", &Blob<Dtype>::channels)
    .add_property("height", &Blob<Dtype>::height)
    .add_property("width", &Blob<Dtype>::width)
    .add_property("count", static_cast<int (Blob<Dtype>::*)() const>(
        &Blob<Dtype>::count))
    .def("reshape", bp::raw_function(&Blob_Reshape))
    // data/diff return numpy arrays backed by the blob's CPU buffers; shape
    // and lifetime handling live in NdarrayCallPolicies.
    .add_property("data", bp::make_function(&Blob<Dtype>::mutable_cpu_data,
          NdarrayCallPolicies()))
    .add_property("diff", bp::make_function(&Blob<Dtype>::mutable_cpu_diff,
          NdarrayCallPolicies()));

  // NOTE(review): held via shared_ptr<PythonLayer<Dtype> >, presumably so
  // Python-defined layer subclasses can be passed back to C++ — confirm
  // against python_layer.hpp.
  bp::class_<Layer<Dtype>, shared_ptr<PythonLayer<Dtype> >,
    boost::noncopyable>("Layer", bp::init<const LayerParameter&>())
    .add_property("blobs", bp::make_function(&Layer<Dtype>::blobs,
          bp::return_internal_reference<>()))
    .def("setup", &Layer<Dtype>::LayerSetUp)
    .def("reshape", &Layer<Dtype>::Reshape)
    .add_property("phase", bp::make_function(&Layer<Dtype>::phase))
    .add_property("type", bp::make_function(&Layer<Dtype>::type));
  bp::register_ptr_to_python<shared_ptr<Layer<Dtype> > >();

  bp::class_<LayerParameter>("LayerParameter", bp::no_init);

  bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
    "Solver", bp::no_init)
    .add_property("net", &Solver<Dtype>::net)
    .add_property("test_nets", bp::make_function(&Solver<Dtype>::test_nets,
          bp::return_internal_reference<>()))
    .add_property("iter", &Solver<Dtype>::iter)
    // The cast selects the const char* overload; SolveOverloads makes the
    // argument optional.
    .def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
          &Solver<Dtype>::Solve), SolveOverloads())
    .def("step", &Solver<Dtype>::Step)
    .def("restore", &Solver<Dtype>::Restore)
    .def("snapshot", &Solver<Dtype>::Snapshot);

  // Concrete solver subclasses, each constructible from a prototxt path.
  bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
        "SGDSolver", bp::init<string>());
  bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
        "NesterovSolver", bp::init<string>());
  bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
        "AdaGradSolver", bp::init<string>());
  bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>(
        "RMSPropSolver", bp::init<string>());
  bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>(
        "AdaDeltaSolver", bp::init<string>());
  bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >,
    shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>(
        "AdamSolver", bp::init<string>());

  bp::def("get_solver", &GetSolverFromFile,
      bp::return_value_policy<bp::manage_new_object>());

  // vector wrappers for all the vector types we use
  bp::class_<vector<shared_ptr<Blob<Dtype> > > >("BlobVec")
    .def(bp::vector_indexing_suite<vector<shared_ptr<Blob<Dtype> > >, true>())
    .def("add_blob", bp::raw_function(&BlobVec_add_blob));
  bp::class_<vector<Blob<Dtype>*> >("RawBlobVec")
    .def(bp::vector_indexing_suite<vector<Blob<Dtype>*>, true>());
  bp::class_<vector<shared_ptr<Layer<Dtype> > > >("LayerVec")
    .def(bp::vector_indexing_suite<vector<shared_ptr<Layer<Dtype> > >, true>());
  bp::class_<vector<string> >("StringVec")
    .def(bp::vector_indexing_suite<vector<string> >());
  bp::class_<vector<int> >("IntVec")
    .def(bp::vector_indexing_suite<vector<int> >());
  bp::class_<vector<Dtype> >("DtypeVec")
    .def(bp::vector_indexing_suite<vector<Dtype> >());
  bp::class_<vector<shared_ptr<Net<Dtype> > > >("NetVec")
    .def(bp::vector_indexing_suite<vector<shared_ptr<Net<Dtype> > >, true>());
  bp::class_<vector<bool> >("BoolVec")
    .def(bp::vector_indexing_suite<vector<bool> >());

  // boost python expects a void (missing) return value, while import_array
  // returns NULL for python3. import_array1() forces a void return value.
  import_array1();
}
  314. } // namespace caffe