1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_PARAMETER_CONVERTER_H_
16#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_PARAMETER_CONVERTER_H_
17
18#include <Python.h>
19
20#include <map>
21#include <string>
22#include <vector>
23
24#include "absl/types/span.h"
25#include "tensorflow/core/framework/op_def.pb.h"
26#include "tensorflow/core/framework/types.pb.h"
27#include "tensorflow/core/platform/status.h"
28#include "tensorflow/python/framework/op_def_util.h"
29#include "tensorflow/python/framework/python_api_info.h"
30#include "tensorflow/python/framework/python_tensor_converter.h"
31#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
32
33namespace tensorflow {
34
35// Converts the canoncialized parameters to the expected types (in place).
36//
37// * Input parameters (i.e., parameters that expect tensor values) are
38// converted to tensors (or lists of tensors) using
39// `tensor_converter.Convert`.
40// * Attribute parameters are converted to the expected type.
41// * Inferred attributes are written to `inferred_attrs`. (Can be
42// nullptr if inferred attributes are not needed.)
43// * If there's a "name" parameter, then its value is not modified.
44//
45// Note: for list-of-tensor parameters, the elements of the list will be
46// converted in-place. Therefore, any list-of-tensor parameters should have
47// their values copied to new lists before calling this method. (See
48// `CopyPythonAPITensorLists`.)
49//
50// Any values that are removed from `params` have their reference count
51// decremented, and any objects added to `params` are new references.
52//
53// Returns true on success, or sets an exception and returns false on error.
54ABSL_MUST_USE_RESULT
55bool ConvertPythonAPIParameters(
56 const PythonAPIInfo& api_info,
57 const PythonTensorConverter& tensor_converter, absl::Span<PyObject*> params,
58 PythonAPIInfo::InferredAttributes* inferred_attrs);
59
60// Copies any parameters that expect a list of tensors to a new list.
61// This ensures that any iterable value can be used, and also ensures that
62// `ConvertPythonAPIParameters` can safely convert tensors in-place.
63//
64// Any values that are removed from `params` have their reference count
65// decremented, and any objects added to `params` are new references.
66//
67// Returns true on success, or sets an exception and returns false on error.
68ABSL_MUST_USE_RESULT
69bool CopyPythonAPITensorLists(const PythonAPIInfo& api_info,
70 absl::Span<PyObject*> params);
71
72} // namespace tensorflow
73
74#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_PARAMETER_CONVERTER_H_
75