1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_ |
17 | #define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_ |
18 | |
19 | #include <Python.h> |
20 | |
21 | #include "pybind11/cast.h" |
22 | #include "pybind11/pybind11.h" |
23 | #include "tensorflow/c/tf_status_internal.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/platform/errors.h" |
26 | #include "tensorflow/core/platform/statusor.h" |
27 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
28 | #include "tensorflow/python/lib/core/py_exception_registry.h" |
29 | |
30 | namespace tsl { |
31 | |
32 | namespace internal { |
33 | |
34 | inline PyObject* CodeToPyExc(const int code) { |
35 | switch (code) { |
36 | case error::Code::INVALID_ARGUMENT: |
37 | return PyExc_ValueError; |
38 | case error::Code::OUT_OF_RANGE: |
39 | return PyExc_IndexError; |
40 | case error::Code::UNIMPLEMENTED: |
41 | return PyExc_NotImplementedError; |
42 | default: |
43 | return PyExc_RuntimeError; |
44 | } |
45 | } |
46 | |
47 | inline PyObject* StatusToPyExc(const Status& status) { |
48 | return CodeToPyExc(status.code()); |
49 | } |
50 | |
51 | inline PyObject* TFStatusToPyExc(const TF_Status* status) { |
52 | return CodeToPyExc(TF_GetCode(status)); |
53 | } |
54 | |
55 | inline pybind11::dict StatusPayloadToDict(const Status& status) { |
56 | pybind11::dict dict; |
57 | const auto& payloads = errors::GetPayloads(status); |
58 | for (auto& pair : payloads) { |
59 | dict[PyBytes_FromString(pair.first.c_str())] = |
60 | PyBytes_FromString(pair.second.c_str()); |
61 | } |
62 | return dict; |
63 | } |
64 | |
65 | inline pybind11::dict TFStatusPayloadToDict(TF_Status* status) { |
66 | return StatusPayloadToDict(status->status); |
67 | } |
68 | |
69 | } // namespace internal |
70 | |
71 | inline void MaybeRaiseFromStatus(const Status& status) { |
72 | if (!status.ok()) { |
73 | PyErr_SetString(internal::StatusToPyExc(status), |
74 | status.error_message().c_str()); |
75 | throw pybind11::error_already_set(); |
76 | } |
77 | } |
78 | |
79 | inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) { |
80 | PyErr_SetObject(tensorflow::PyExceptionRegistry::Lookup(status.code()), |
81 | pybind11::make_tuple(pybind11::none(), pybind11::none(), |
82 | status.error_message(), |
83 | internal::StatusPayloadToDict(status)) |
84 | .ptr()); |
85 | } |
86 | |
87 | inline void SetRegisteredErrFromTFStatus(TF_Status* status) { |
88 | PyErr_SetObject(tensorflow::PyExceptionRegistry::Lookup(TF_GetCode(status)), |
89 | pybind11::make_tuple(pybind11::none(), pybind11::none(), |
90 | TF_Message(status), |
91 | internal::TFStatusPayloadToDict(status)) |
92 | .ptr()); |
93 | } |
94 | |
95 | inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) { |
96 | if (!status.ok()) { |
97 | SetRegisteredErrFromStatus(status); |
98 | throw pybind11::error_already_set(); |
99 | } |
100 | } |
101 | |
102 | inline void MaybeRaiseRegisteredFromStatusWithGIL( |
103 | const tensorflow::Status& status) { |
104 | if (!status.ok()) { |
105 | // Acquire GIL for throwing exception. |
106 | pybind11::gil_scoped_acquire acquire; |
107 | SetRegisteredErrFromStatus(status); |
108 | throw pybind11::error_already_set(); |
109 | } |
110 | } |
111 | |
112 | inline void MaybeRaiseFromTFStatus(TF_Status* status) { |
113 | TF_Code code = TF_GetCode(status); |
114 | if (code != TF_OK) { |
115 | PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status)); |
116 | throw pybind11::error_already_set(); |
117 | } |
118 | } |
119 | |
120 | inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) { |
121 | TF_Code code = TF_GetCode(status); |
122 | if (code != TF_OK) { |
123 | SetRegisteredErrFromTFStatus(status); |
124 | throw pybind11::error_already_set(); |
125 | } |
126 | } |
127 | |
128 | inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) { |
129 | TF_Code code = TF_GetCode(status); |
130 | if (code != TF_OK) { |
131 | // Acquire GIL for throwing exception. |
132 | pybind11::gil_scoped_acquire acquire; |
133 | SetRegisteredErrFromTFStatus(status); |
134 | throw pybind11::error_already_set(); |
135 | } |
136 | } |
137 | |
138 | } // namespace tsl |
139 | |
140 | namespace tensorflow { |
141 | |
142 | using tsl::MaybeRaiseFromStatus; |
143 | using tsl::MaybeRaiseFromTFStatus; |
144 | using tsl::MaybeRaiseRegisteredFromStatus; |
145 | using tsl::MaybeRaiseRegisteredFromStatusWithGIL; |
146 | using tsl::MaybeRaiseRegisteredFromTFStatus; |
147 | using tsl::MaybeRaiseRegisteredFromTFStatusWithGIL; |
148 | using tsl::SetRegisteredErrFromStatus; |
149 | using tsl::SetRegisteredErrFromTFStatus; |
150 | } // namespace tensorflow |
151 | |
152 | namespace pybind11 { |
153 | namespace detail { |
154 | |
155 | // Convert tensorflow::Status |
156 | // |
157 | // Raise an exception if a given status is not OK, otherwise return None. |
158 | // |
159 | // The correspondence between status codes and exception classes is given |
160 | // by PyExceptionRegistry. Note that the registry should be initialized |
161 | // in order to be used, see PyExceptionRegistry::Init. |
162 | template <> |
163 | struct type_caster<tensorflow::Status> { |
164 | public: |
165 | PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status" )); |
166 | static handle cast(tensorflow::Status status, return_value_policy, handle) { |
167 | tensorflow::MaybeRaiseFromStatus(status); |
168 | return none().inc_ref(); |
169 | } |
170 | }; |
171 | |
172 | // Convert tensorflow::StatusOr |
173 | // |
174 | // Uses the same logic as the Abseil implementation: raise an exception if the |
175 | // status is not OK, otherwise return its payload. |
176 | template <typename PayloadType> |
177 | struct type_caster<tensorflow::StatusOr<PayloadType>> { |
178 | public: |
179 | using PayloadCaster = make_caster<PayloadType>; |
180 | using StatusCaster = make_caster<tensorflow::Status>; |
181 | static constexpr auto name = PayloadCaster::name; |
182 | |
183 | static handle cast(const tensorflow::StatusOr<PayloadType>* src, |
184 | return_value_policy policy, handle parent) { |
185 | if (!src) return none().release(); |
186 | return cast_impl(*src, policy, parent); |
187 | } |
188 | |
189 | static handle cast(const tensorflow::StatusOr<PayloadType>& src, |
190 | return_value_policy policy, handle parent) { |
191 | return cast_impl(src, policy, parent); |
192 | } |
193 | |
194 | static handle cast(tensorflow::StatusOr<PayloadType>&& src, |
195 | return_value_policy policy, handle parent) { |
196 | return cast_impl(std::move(src), policy, parent); |
197 | } |
198 | |
199 | private: |
200 | template <typename CType> |
201 | static handle cast_impl(CType&& src, return_value_policy policy, |
202 | handle parent) { |
203 | if (src.ok()) { |
204 | // Convert and return the payload. |
205 | return PayloadCaster::cast(std::forward<CType>(src).value(), policy, |
206 | parent); |
207 | } else { |
208 | // Convert and return the error. |
209 | return StatusCaster::cast(std::forward<CType>(src).status(), |
210 | return_value_policy::move, parent); |
211 | } |
212 | } |
213 | }; |
214 | |
215 | } // namespace detail |
216 | } // namespace pybind11 |
217 | |
218 | #endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_ |
219 | |