1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
16#define TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
17
18#include <Python.h>
19
20#include <map>
21
22#include "tensorflow/c/tf_status.h"
23#include "tensorflow/core/lib/core/error_codes.pb.h"
24
25namespace tensorflow {
26
27// Global registry mapping C API error codes to the corresponding custom Python
28// exception type. This is used to expose the exception types to C extension
29// code (i.e. so we can raise custom exceptions via SWIG).
30//
31// Init() must be called exactly once at the beginning of the process before
32// Lookup() can be used.
33//
34// Example usage:
35// TF_Status* status = TF_NewStatus();
36// TF_Foo(..., status);
37//
38// if (TF_GetCode(status) != TF_OK) {
39// PyObject* exc_type = PyExceptionRegistry::Lookup(TF_GetCode(status));
40// // Arguments to OpError base class. Set `node_def` and `op` to None.
41// PyObject* args =
42// Py_BuildValue("sss", nullptr, nullptr, TF_Message(status));
43// PyErr_SetObject(exc_type, args);
44// Py_DECREF(args);
45// TF_DeleteStatus(status);
46// return NULL;
47// }
48class PyExceptionRegistry {
49 public:
50 // Initializes the process-wide registry. Should be called exactly once near
51 // the beginning of the process. The arguments are the various Python
52 // exception types (e.g. `cancelled_exc` corresponds to
53 // errors.CancelledError).
54 static void Init(PyObject* code_to_exc_type_map);
55
56 // Returns the Python exception type corresponding to `code`. Init() must be
57 // called before using this function. `code` should not be TF_OK.
58 static PyObject* Lookup(TF_Code code);
59
60 static inline PyObject* Lookup(error::Code code) {
61 return Lookup(static_cast<TF_Code>(code));
62 }
63
64 private:
65 static PyExceptionRegistry* singleton_;
66 PyExceptionRegistry() = default;
67
68 // Maps error codes to the corresponding Python exception type.
69 std::map<TF_Code, PyObject*> exc_types_;
70};
71
72} // namespace tensorflow
73
74#endif // TENSORFLOW_PYTHON_LIB_CORE_PY_EXCEPTION_REGISTRY_H_
75