1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
17#define TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
18
19#include <Python.h>
20
21#include "pybind11/cast.h"
22#include "pybind11/pybind11.h"
23#include "tensorflow/c/tf_status_internal.h"
24#include "tensorflow/core/lib/core/status.h"
25#include "tensorflow/core/platform/errors.h"
26#include "tensorflow/core/platform/statusor.h"
27#include "tensorflow/core/protobuf/error_codes.pb.h"
28#include "tensorflow/python/lib/core/py_exception_registry.h"
29
30namespace tsl {
31
32namespace internal {
33
34inline PyObject* CodeToPyExc(const int code) {
35 switch (code) {
36 case error::Code::INVALID_ARGUMENT:
37 return PyExc_ValueError;
38 case error::Code::OUT_OF_RANGE:
39 return PyExc_IndexError;
40 case error::Code::UNIMPLEMENTED:
41 return PyExc_NotImplementedError;
42 default:
43 return PyExc_RuntimeError;
44 }
45}
46
47inline PyObject* StatusToPyExc(const Status& status) {
48 return CodeToPyExc(status.code());
49}
50
51inline PyObject* TFStatusToPyExc(const TF_Status* status) {
52 return CodeToPyExc(TF_GetCode(status));
53}
54
55inline pybind11::dict StatusPayloadToDict(const Status& status) {
56 pybind11::dict dict;
57 const auto& payloads = errors::GetPayloads(status);
58 for (auto& pair : payloads) {
59 dict[PyBytes_FromString(pair.first.c_str())] =
60 PyBytes_FromString(pair.second.c_str());
61 }
62 return dict;
63}
64
65inline pybind11::dict TFStatusPayloadToDict(TF_Status* status) {
66 return StatusPayloadToDict(status->status);
67}
68
69} // namespace internal
70
71inline void MaybeRaiseFromStatus(const Status& status) {
72 if (!status.ok()) {
73 PyErr_SetString(internal::StatusToPyExc(status),
74 status.error_message().c_str());
75 throw pybind11::error_already_set();
76 }
77}
78
79inline void SetRegisteredErrFromStatus(const tensorflow::Status& status) {
80 PyErr_SetObject(tensorflow::PyExceptionRegistry::Lookup(status.code()),
81 pybind11::make_tuple(pybind11::none(), pybind11::none(),
82 status.error_message(),
83 internal::StatusPayloadToDict(status))
84 .ptr());
85}
86
87inline void SetRegisteredErrFromTFStatus(TF_Status* status) {
88 PyErr_SetObject(tensorflow::PyExceptionRegistry::Lookup(TF_GetCode(status)),
89 pybind11::make_tuple(pybind11::none(), pybind11::none(),
90 TF_Message(status),
91 internal::TFStatusPayloadToDict(status))
92 .ptr());
93}
94
95inline void MaybeRaiseRegisteredFromStatus(const tensorflow::Status& status) {
96 if (!status.ok()) {
97 SetRegisteredErrFromStatus(status);
98 throw pybind11::error_already_set();
99 }
100}
101
102inline void MaybeRaiseRegisteredFromStatusWithGIL(
103 const tensorflow::Status& status) {
104 if (!status.ok()) {
105 // Acquire GIL for throwing exception.
106 pybind11::gil_scoped_acquire acquire;
107 SetRegisteredErrFromStatus(status);
108 throw pybind11::error_already_set();
109 }
110}
111
112inline void MaybeRaiseFromTFStatus(TF_Status* status) {
113 TF_Code code = TF_GetCode(status);
114 if (code != TF_OK) {
115 PyErr_SetString(internal::TFStatusToPyExc(status), TF_Message(status));
116 throw pybind11::error_already_set();
117 }
118}
119
120inline void MaybeRaiseRegisteredFromTFStatus(TF_Status* status) {
121 TF_Code code = TF_GetCode(status);
122 if (code != TF_OK) {
123 SetRegisteredErrFromTFStatus(status);
124 throw pybind11::error_already_set();
125 }
126}
127
128inline void MaybeRaiseRegisteredFromTFStatusWithGIL(TF_Status* status) {
129 TF_Code code = TF_GetCode(status);
130 if (code != TF_OK) {
131 // Acquire GIL for throwing exception.
132 pybind11::gil_scoped_acquire acquire;
133 SetRegisteredErrFromTFStatus(status);
134 throw pybind11::error_already_set();
135 }
136}
137
138} // namespace tsl
139
140namespace tensorflow {
141
142using tsl::MaybeRaiseFromStatus;
143using tsl::MaybeRaiseFromTFStatus;
144using tsl::MaybeRaiseRegisteredFromStatus;
145using tsl::MaybeRaiseRegisteredFromStatusWithGIL;
146using tsl::MaybeRaiseRegisteredFromTFStatus;
147using tsl::MaybeRaiseRegisteredFromTFStatusWithGIL;
148using tsl::SetRegisteredErrFromStatus;
149using tsl::SetRegisteredErrFromTFStatus;
150} // namespace tensorflow
151
152namespace pybind11 {
153namespace detail {
154
155// Convert tensorflow::Status
156//
157// Raise an exception if a given status is not OK, otherwise return None.
158//
159// The correspondence between status codes and exception classes is given
160// by PyExceptionRegistry. Note that the registry should be initialized
161// in order to be used, see PyExceptionRegistry::Init.
162template <>
163struct type_caster<tensorflow::Status> {
164 public:
165 PYBIND11_TYPE_CASTER(tensorflow::Status, _("Status"));
166 static handle cast(tensorflow::Status status, return_value_policy, handle) {
167 tensorflow::MaybeRaiseFromStatus(status);
168 return none().inc_ref();
169 }
170};
171
172// Convert tensorflow::StatusOr
173//
174// Uses the same logic as the Abseil implementation: raise an exception if the
175// status is not OK, otherwise return its payload.
176template <typename PayloadType>
177struct type_caster<tensorflow::StatusOr<PayloadType>> {
178 public:
179 using PayloadCaster = make_caster<PayloadType>;
180 using StatusCaster = make_caster<tensorflow::Status>;
181 static constexpr auto name = PayloadCaster::name;
182
183 static handle cast(const tensorflow::StatusOr<PayloadType>* src,
184 return_value_policy policy, handle parent) {
185 if (!src) return none().release();
186 return cast_impl(*src, policy, parent);
187 }
188
189 static handle cast(const tensorflow::StatusOr<PayloadType>& src,
190 return_value_policy policy, handle parent) {
191 return cast_impl(src, policy, parent);
192 }
193
194 static handle cast(tensorflow::StatusOr<PayloadType>&& src,
195 return_value_policy policy, handle parent) {
196 return cast_impl(std::move(src), policy, parent);
197 }
198
199 private:
200 template <typename CType>
201 static handle cast_impl(CType&& src, return_value_policy policy,
202 handle parent) {
203 if (src.ok()) {
204 // Convert and return the payload.
205 return PayloadCaster::cast(std::forward<CType>(src).value(), policy,
206 parent);
207 } else {
208 // Convert and return the error.
209 return StatusCaster::cast(std::forward<CType>(src).status(),
210 return_value_policy::move, parent);
211 }
212 }
213};
214
215} // namespace detail
216} // namespace pybind11
217
218#endif // TENSORFLOW_PYTHON_LIB_CORE_PYBIND11_STATUS_H_
219