1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_
16#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_
17
18#include <Python.h>
19
20#include <map>
21#include <string>
22#include <vector>
23
24#include "absl/types/span.h"
25#include "tensorflow/core/framework/op_def.pb.h"
26#include "tensorflow/core/framework/types.pb.h"
27#include "tensorflow/core/platform/status.h"
28#include "tensorflow/python/framework/op_def_util.h"
29#include "tensorflow/python/framework/python_tensor_converter.h"
30#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
31
32namespace tensorflow {
33
34// Precomputed information about a TensorFlow Python API.
35//
36// PythonAPIInfo records information about a single TensorFlow Python API,
37// in order to allow calls to the API to be executed more efficiently. This
38// information includes:
39//
40// * The name of the API. (E.g. "tf.math.add")
41//
42// * The name of the registered op that implements the API, if applicable
43// (e.g. "AddV2").
44//
45// * Information about the API's parameters. Parameters are divided into two
46// "kinds": inputs and attributes. An *input* is a parameter that
47// expects a Tensor or list of Tensors, and it is described by an `ArgDef`.
48// An *attribute* is a parameter that expects any other value type, and it is
49// described by an `AttrDef`.
50//
51// * Default values for the API's attribute parameters.
52//
53// * Information about "inferred attributes" -- attributes whose values are
54// inferred from `input` parameters. There are two kinds of inferred
55// attributes: Tensor dtypes, which are inferred from tensor and list(tensor)
56// parameters; and list lengths, which are inferred from list(tensor)
57// parameters.
58class PythonAPIInfo {
59 public:
60 // The index of a parameter in the canonicalized parameter list. The
61 // canonicalized parameter list includes inputs and attributes (but does
62 // not include inferred attributes). `-1` is used for inferred attributes.
63 using ParamIndex = int;
64
65 // Information about a parameter that expects a non-Tensor value.
66 struct Attribute {
67 ParamIndex index; // -1 if this is an inferred attribute
68 AttributeType type;
69 const char* name; // Interned python string
70 int inferred_index; // index to store attribute in InferredAttributes
71 };
72
73 // Information about a parameter that expects a Tensor or list(Tensor).
74 // Additional information about tensor parameters is stored in types
75 // defined below, in order to simplify dtype/length inference:
76 // * FixedDTypeInput: inputs with fixed dtypes.
77 // * InputsWithTypeAttr: groups inputs that use a type_attr for dtype.
78 // * InputsWithTypeListAttr: groups inputs that use a type_list_attr.
79 // * InputsWithNumberAttr: groups inputs by a number_attr for length.
80 struct Input {
81 ParamIndex index;
82 bool is_list;
83 };
84
85 // Information about a Tensor parameter w/ fixed dtype.
86 struct InputWithFixedDType {
87 DataType dtype;
88 ParamIndex index;
89 bool is_list;
90 };
91
92 // Information about Tensor parameters whose DType is specified by a single
93 // `type_attr` attribute.
94 struct InputsWithTypeAttr {
95 Attribute* type_attr; // not owned.
96 DataType default_dtype; // DT_INVALID if no default.
97 std::vector<ParamIndex> tensor_params; // single-tensor inputs.
98 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
99 std::vector<DataType> ok_dtypes;
100 };
101
102 // Information about Tensor parameters whose DType is specified by a single
103 // `type_list_attr` attribute.
104 struct InputsWithTypeListAttr {
105 Attribute* type_list_attr; // not owned.
106 std::vector<DataType> default_dtypes; // empty if no default.
107 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
108 std::vector<DataType> ok_dtypes;
109 };
110
111 // Information about Tensor-list parameters whose length is specified by a
112 // single `int` attribute.
113 struct InputsWithNumberAttr {
114 Attribute* number_attr; // not owned.
115 int64_t default_length; // -1 for no default.
116 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs.
117 };
118
119 // Structure used to return inferred attribute values.
120 // * types[i] is the inferred value for inferred_type_attrs()[i]
121 // * type_lists[i] is the inferred value for inferred_type_list_attrs()[i]
122 // * lengths[i] is the inferred value for inferred_length_attrs()[i]
123 struct InferredAttributes {
124 std::vector<DataType> types;
125 std::vector<std::vector<DataType>> type_lists;
126 std::vector<int64_t> lengths;
127 };
128
129 // Constructs a new PythonAPIInfo.
130 //
131 // Note: One of the `Initialize()` functions must be called before the
132 // `PythonAPIInfo` is used.
133 //
134 // Args:
135 // api_name: The fully-qualified name of the python API (e.g., tf.math.sum).
136 explicit PythonAPIInfo(const std::string& api_name);
137
138 // Initializes this PythonAPIInfo.
139 //
140 // Args:
141 // op_def: Contains information about the parameters.
142 // param_names: The argument names for the python API, in canonical order.
143 // defaults_tuple: Tuple containing default values for the parameters,
144 // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
145 // for `param_names[-i]`.
146 Status Initialize(const OpDef& op_def, const std::vector<string> param_names,
147 PyObject* defaults_tuple);
148
149 // Initialize this PythonAPIInfo based on the registered OpDef for the given
150 // operation.
151 //
152 // Args:
153 // op_name: The registered name of the operation (e.g. "AddV2").
154 Status InitializeFromRegisteredOp(const std::string& op_name);
155
156 // Initializes this PythonAPIInfo based on a set of parameter specifications.
157 //
158 // Args:
159 // input_specs: Mapping from parameter name to specification string for
160 // each input (parameter that expects a tensor value).
161 // attr_specs: Mapping from parameter name to specification string for
162 // each attribute (parameter that expects a non-tensor value).
163 // param_names: The argument names for the python API, in canonical order.
164 // defaults_tuple: Tuple containing default values for the parameters,
165 // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
166 // for `param_names[-i]`.
167 //
168 // Note: the `name` parameter should not be included in `input_specs` or
169 // `attr_specs`.
170 Status InitializeFromParamSpecs(
171 const std::map<std::string, std::string>& input_specs,
172 const std::map<std::string, std::string>& attr_specs,
173 const std::vector<string> param_names, PyObject* defaults_tuple);
174
175 // The name of the API that is described by this PythonAPIInfo.
176 const char* api_name() const { return api_name_; }
177
178 // The ordered names of the canononical parameters that this API expects.
179 const std::vector<const char*>& param_names() const { return param_names_; }
180
181 // A Python tuple containing the default values for parameters. This is
182 // right-aligned with `param_name` -- i.e., `defaults[-i]` is the default
183 // for `param_names[-i]`.
184 const PyObject* defaults_tuple() const { return defaults_tuple_.get(); }
185
186 // Information about the attribute (non-tensor) parameters for this API.
187 const std::vector<Attribute>& attributes() const { return attributes_; }
188
189 // Information about the input (tensor) parameters for this API.
190 const std::vector<Input>& inputs() const { return inputs_; }
191 const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const {
192 return inputs_with_fixed_dtype_;
193 }
194 const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const {
195 return inputs_with_type_attrs_;
196 }
197 const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs()
198 const {
199 return inputs_with_type_list_attrs_;
200 }
201 const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const {
202 return inputs_with_number_attrs_;
203 }
204
205 // Names of inferred attributes.
206 const std::vector<const char*>& inferred_type_attrs() const {
207 return inferred_type_attrs_;
208 }
209 const std::vector<const char*>& inferred_type_list_attrs() const {
210 return inferred_type_list_attrs_;
211 }
212 const std::vector<const char*>& inferred_length_attrs() const {
213 return inferred_length_attrs_;
214 }
215
216 // Returns a string summarizing the internal state of this type converter.
217 string DebugInfo() const;
218
219 private:
220 // Adds an entry to the attributes_ vector based on the given `AttrDef`.
221 //
222 // If `attr_def` describes a type attribute, then adds a value to
223 // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any
224 // tensor inputs that use this dtype).
225 //
226 // If `attr_def` describes an int attribute, then adds a value to
227 // inputs_with_number_attrs_ (to record any tensor inputs that use this
228 // value as a list length).
229 Status InitializeAttribute(
230 const OpDef::AttrDef& attr_def,
231 const std::map<std::string, ParamIndex>& param_name_to_index);
232
233 // Adds an entry to the inputs_ vector based on the given `ArgDef`.
234 //
235 // If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`.
236 //
237 // If `arg_def`'s dtype is described by a `type` attr, then updates the
238 // appropriate value in `inputs_with_type_attrs_` with information about the
239 // `arg_def`.
240 //
241 // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the
242 // appropriate value in `inputs_with_type_list_attrs_` with information about
243 // the `arg_def`.
244 Status InitializeInput(const OpDef::ArgDef& arg_def,
245 const std::map<std::string, int>& param_name_to_index);
246
247 // Checks that the OpDef used to initialize this PythonAPIInfo
248 // had an AttrDef or ArgDef specification for each parameter.
249 Status CheckParamNames() const;
250
251 // Searches inputs_with_type_attrs_ for an input with the given name.
252 InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name);
253
254 // Searches inputs_with_type_list_attrs_ for an input with the given name.
255 InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name);
256
257 // Searches inputs_with_type_list_attrs_ for an input with the given name.
258 InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name);
259
260 ABSL_MUST_USE_RESULT
261 bool InferLengthAttributes(const absl::Span<PyObject*> params,
262 std::vector<int64_t>& inferred_length_attrs) const;
263
264 // ==========================================================================
265 // Member Variables
266 // ==========================================================================
267
268 // The name of the API that is described by this PythonAPIInfo.
269 // (Interned python string).
270 const char* api_name_;
271
272 // The names of the parameters that this API expects.
273 // (Interned python strings.)
274 std::vector<const char*> param_names_;
275
276 // Tuple containing default values for the parameters, right-aligned with
277 // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`.
278 Safe_PyObjectPtr defaults_tuple_;
279
280 // Information about the non-tensor-valued parameters that this API expects.
281 std::vector<Attribute> attributes_;
282
283 // Information about the tensor-valued parameters that this API expects.
284 std::vector<Input> inputs_;
285 std::vector<InputWithFixedDType> inputs_with_fixed_dtype_;
286 std::vector<InputsWithTypeAttr> inputs_with_type_attrs_;
287 std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_;
288 std::vector<InputsWithNumberAttr> inputs_with_number_attrs_;
289
290 // Names of inferred attributes. (Interned python strings.)
291 std::vector<const char*> inferred_type_attrs_;
292 std::vector<const char*> inferred_type_list_attrs_;
293 std::vector<const char*> inferred_length_attrs_;
294};
295
296} // namespace tensorflow
297
298#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_
299