Skip to content

Commit

Permalink
Merge pull request #2456 from longjon/python-layer-object
Browse files Browse the repository at this point in the history
Use bp::object instead of PyObject* for self in Python layer
  • Loading branch information
jeffdonahue committed May 14, 2015
2 parents 9ad14c2 + 7cf8b83 commit d1abf9d
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions include/caffe/python_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ template <typename Dtype>
class PythonLayer : public Layer<Dtype> {
public:
PythonLayer(PyObject* self, const LayerParameter& param)
: Layer<Dtype>(param), self_(self) { }
: Layer<Dtype>(param), self_(bp::handle<>(bp::borrowed(self))) { }

virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
bp::call_method<bp::object>(self_, "setup", bottom, top);
self_.attr("setup")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
Expand All @@ -29,7 +29,7 @@ class PythonLayer : public Layer<Dtype> {
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
bp::call_method<bp::object>(self_, "reshape", bottom, top);
self_.attr("reshape")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
Expand All @@ -42,7 +42,7 @@ class PythonLayer : public Layer<Dtype> {
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
try {
bp::call_method<bp::object>(self_, "forward", bottom, top);
self_.attr("forward")(bottom, top);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
Expand All @@ -51,16 +51,15 @@ class PythonLayer : public Layer<Dtype> {
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
try {
bp::call_method<bp::object>(self_, "backward", top, propagate_down,
bottom);
self_.attr("backward")(top, propagate_down, bottom);
} catch (bp::error_already_set) {
PyErr_Print();
throw;
}
}

private:
PyObject* self_;
bp::object self_;
};

} // namespace caffe
Expand Down

0 comments on commit d1abf9d

Please sign in to comment.