1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
16#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
17
18#include <Python.h>
19
20#include "tensorflow/c/eager/c_api.h"
21#include "tensorflow/core/framework/types.pb.h"
22#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23
24namespace tensorflow {
25
26// Converts PyObject* values to Tensors.
27//
28// This converter attempts to convert values as efficiently as possible; but
29// it has fallback paths to handle any PyObject* value for which tensor
30// conversion is defined.
31class PythonTensorConverter {
32 public:
33 // Constructs a new PythonTensorConverter.
34 //
35 // Note: the arguments to this constructor may change in the future, as
36 // we move more of python tensor conversion from the Python layer to the
37 // c++ layer.
38 //
39 // Args:
40 // py_eager_context: the value of context.context() from eager/context.py.
41 // ctx: The c++ eager context, or nullptr in graph mode.
42 // device_name: The current device name.
43 //
44 // All three argument values must remain alive until `this` is deleted.
45 PythonTensorConverter(PyObject* py_eager_context, TFE_Context* ctx,
46 const char* device_name)
47 : py_eager_context_(py_eager_context),
48 ctx_(ctx),
49 device_name_(device_name) {}
50
51 // Converts `src` to a tensor (if it's not already one), and returns a new
52 // reference to the converted value.
53 //
54 // Args:
55 // src: The object that should be converted to a Tensor.
56 // dtype: The requested dtype. Use `DT_INVALID` if the dtype should be
57 // inferred from the `src` value (in which case `dtype` will be updated
58 // in-place to be the actual dtype of the converted value).
59 // used_fallback: Output parameter used to record whether the conversion
60 // was done by falling back to the Python `tf.convert_to_tensor()`
61 // function. This is for testing/logging purposes only. May be null.
62 //
63 // If `src` can't be converted to a tensor with the requested dtype, sets a
64 // Python exception and returns nullptr.
65 Safe_PyObjectPtr Convert(PyObject* src, DataType& dtype,
66 bool* used_fallback = nullptr) const;
67
68 private:
69 PyObject* py_eager_context_;
70 TFE_Context* ctx_;
71 const char* device_name_;
72};
73
74} // namespace tensorflow
75
76#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_TENSOR_CONVERTER_H_
77