1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
16#define TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
17
18#include <Python.h>
19
20#include <string>
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/tensor.pb.h"
24#include "tensorflow/core/framework/tensor_shape.pb.h"
25#include "tensorflow/core/framework/types.pb.h"
26#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
27
28namespace tensorflow {
29
30// Enumerated type corresponding with string values in AttrDef::type.
31enum class AttributeType {
32 UNKNOWN,
33 ANY, // "any"
34 FLOAT, // "float"
35 INT, // "int"
36 STRING, // "string"
37 BOOL, // "bool"
38 DTYPE, // "type" (tf.dtypes.DType)
39 SHAPE, // "shape" (tf.TensorShape)
40 TENSOR, // "tensor" (tf.TensorProto)
41 LIST_ANY, // "list(any)"
42 LIST_FLOAT, // "list(float)"
43 LIST_INT, // "list(int)"
44 LIST_STRING, // "list(string)"
45 LIST_BOOL, // "list(bool)"
46 LIST_DTYPE, // "list(dtype)"
47 LIST_SHAPE, // "list(shape)"
48 LIST_TENSOR // "list(tensor)"
49};
50
51// Returns the enumerated value corresponding to a given string (e.g.
52// "string" or "list(string)".
53AttributeType AttributeTypeFromName(const std::string& type_name);
54
55// Returns the string corresponding to a given enumerated value.
56std::string AttributeTypeToName(AttributeType attr_type);
57
58// Converts `value` to the specified type and returns a new reference to the
59// converted value (if possible); or sets a Python exception and returns
60// nullptr. This function is optimized to be fast if `value` already has the
61// desired type.
62//
63// * 'any' values are returned as-is.
64// * 'float' values are converted by calling float(value).
65// * 'int' values are converted by calling int(value).
66// * 'string' values are returned as-is if they are (bytes, unicode);
67// otherwise, an exception is raised.
68// * 'bool' values are returned as-is if they are boolean; otherwise, an
69// exception is raised.
70// * 'dtype' values are converted using `dtypes.as_dtype`.
71// * 'shape' values are converted using `tensor_shape.as_shape`.
72// * 'tensor' values are returned as-is if they are a `TensorProto`; or are
73// parsed into `TensorProto` using `textformat.merge` if they are a string.
74// Otherwise, an exception is raised.
75// * 'list(*)' values are copied to a new list, and then each element is
76// converted (in-place) as described above. (If the value is not iterable,
77// or if conversion fails for any item, then an exception is raised.)
78Safe_PyObjectPtr ConvertPyObjectToAttributeType(PyObject* value,
79 AttributeType type);
80
81// Converts a c++ `AttrValue` protobuf message to a Python object; or sets a
82// Python exception and returns nullptr if an error occurs.
83Safe_PyObjectPtr AttrValueToPyObject(const AttrValue& attr_value);
84
85// Converts a c++ `DataType` protobuf enum to a Python object; or sets a
86// Python exception and returns nullptr if an error occurs.
87Safe_PyObjectPtr DataTypeToPyObject(const DataType& data_type);
88
89// Converts a c++ `TensorShapeProto` message to a Python object; or sets a
90// Python exception and returns nullptr if an error occurs.
91Safe_PyObjectPtr TensorShapeProtoToPyObject(
92 const TensorShapeProto& tensor_shape);
93
94// TODO(edloper): Define TensorProtoToPyObject?
95
96} // namespace tensorflow
97
98#endif // TENSORFLOW_PYTHON_FRAMEWORK_OP_DEF_UTIL_H_
99