1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ |
16 | #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ |
17 | |
18 | #include <Python.h> |
19 | |
20 | #include "tensorflow/c/eager/c_api.h" |
21 | #include "tensorflow/core/framework/types.pb.h" |
22 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | // Converts PyObject* values to Tensors. |
27 | // |
28 | // This converter attempts to convert values as efficiently as possible; but |
29 | // it has fallback paths to handle any PyObject* value for which tensor |
30 | // conversion is defined. |
31 | class PythonTensorConverter { |
32 | public: |
33 | // Constructs a new PythonTensorConverter. |
34 | // |
35 | // Note: the arguments to this constructor may change in the future, as |
36 | // we move more of python tensor conversion from the Python layer to the |
37 | // c++ layer. |
38 | // |
39 | // Args: |
40 | // py_eager_context: the value of context.context() from eager/context.py. |
41 | // ctx: The c++ eager context, or nullptr in graph mode. |
42 | // device_name: The current device name. |
43 | // |
44 | // All three argument values must remain alive until `this` is deleted. |
45 | PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx, |
46 | const char* device_name) |
47 | : py_eager_context_(py_eager_context), |
48 | ctx_(ctx), |
49 | device_name_(device_name) {} |
50 | |
51 | // Converts `src` to a tensor (if it's not already one), and returns a new |
52 | // reference to the converted value. |
53 | // |
54 | // Args: |
55 | // src: The object that should be converted to a Tensor. |
56 | // dtype: The requested dtype. Use `DT_INVALID` if the dtype should be |
57 | // inferred from the `src` value (in which case `dtype` will be updated |
58 | // in-place to be the actual dtype of the converted value). |
59 | // used_fallback: Output parameter used to record whether the conversion |
60 | // was done by falling back to the Python `tf.convert_to_tensor()` |
61 | // function. This is for testing/logging purposes only. May be null. |
62 | // |
63 | // If `src` can't be converted to a tensor with the requested dtype, sets a |
64 | // Python exception and returns nullptr. |
65 | Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype, |
66 | bool* used_fallback = nullptr) const; |
67 | |
68 | private: |
69 | PyObject* py_eager_context_; |
70 | TFE_Context* ctx_; |
71 | const char* device_name_; |
72 | }; |
73 | |
74 | } // namespace tensorflow |
75 | |
76 | #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_ |
77 | |