1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_ |
16 | #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_ |
17 | |
18 | #include <Python.h> |
19 | |
20 | #include <map> |
21 | #include <string> |
22 | #include <vector> |
23 | |
24 | #include "absl/types/span.h" |
25 | #include "tensorflow/core/framework/op_def.pb.h" |
26 | #include "tensorflow/core/framework/types.pb.h" |
27 | #include "tensorflow/core/platform/status.h" |
28 | #include "tensorflow/python/framework/op_def_util.h" |
29 | #include "tensorflow/python/framework/python_tensor_converter.h" |
30 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | // Precomputed information about a TensorFlow Python API. |
35 | // |
36 | // PythonAPIInfo records information about a single TensorFlow Python API, |
37 | // in order to allow calls to the API to be executed more efficiently. This |
38 | // information includes: |
39 | // |
40 | // * The name of the API. (E.g. "tf.math.add") |
41 | // |
42 | // * The name of the registered op that implements the API, if applicable |
43 | // (e.g. "AddV2"). |
44 | // |
45 | // * Information about the API's parameters. Parameters are divided into two |
46 | // "kinds": inputs and attributes. An *input* is a parameter that |
47 | // expects a Tensor or list of Tensors, and it is described by an `ArgDef`. |
48 | // An *attribute* is a parameter that expects any other value type, and it is |
49 | // described by an `AttrDef`. |
50 | // |
51 | // * Default values for the API's attribute parameters. |
52 | // |
53 | // * Information about "inferred attributes" -- attributes whose values are |
54 | // inferred from `input` parameters. There are two kinds of inferred |
55 | // attributes: Tensor dtypes, which are inferred from tensor and list(tensor) |
56 | // parameters; and list lengths, which are inferred from list(tensor) |
57 | // parameters. |
58 | class PythonAPIInfo { |
59 | public: |
60 | // The index of a parameter in the canonicalized parameter list. The |
61 | // canonicalized parameter list includes inputs and attributes (but does |
62 | // not include inferred attributes). `-1` is used for inferred attributes. |
63 | using ParamIndex = int; |
64 | |
65 | // Information about a parameter that expects a non-Tensor value. |
66 | struct Attribute { |
67 | ParamIndex index; // -1 if this is an inferred attribute |
68 | AttributeType type; |
69 | const char* name; // Interned python string |
70 | int inferred_index; // index to store attribute in InferredAttributes |
71 | }; |
72 | |
73 | // Information about a parameter that expects a Tensor or list(Tensor). |
74 | // Additional information about tensor parameters is stored in types |
75 | // defined below, in order to simplify dtype/length inference: |
76 | // * FixedDTypeInput: inputs with fixed dtypes. |
77 | // * InputsWithTypeAttr: groups inputs that use a type_attr for dtype. |
78 | // * InputsWithTypeListAttr: groups inputs that use a type_list_attr. |
79 | // * InputsWithNumberAttr: groups inputs by a number_attr for length. |
80 | struct Input { |
81 | ParamIndex index; |
82 | bool is_list; |
83 | }; |
84 | |
85 | // Information about a Tensor parameter w/ fixed dtype. |
86 | struct InputWithFixedDType { |
87 | DataType dtype; |
88 | ParamIndex index; |
89 | bool is_list; |
90 | }; |
91 | |
92 | // Information about Tensor parameters whose DType is specified by a single |
93 | // `type_attr` attribute. |
94 | struct InputsWithTypeAttr { |
95 | Attribute* type_attr; // not owned. |
96 | DataType default_dtype; // DT_INVALID if no default. |
97 | std::vector<ParamIndex> tensor_params; // single-tensor inputs. |
98 | std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. |
99 | std::vector<DataType> ok_dtypes; |
100 | }; |
101 | |
102 | // Information about Tensor parameters whose DType is specified by a single |
103 | // `type_list_attr` attribute. |
104 | struct InputsWithTypeListAttr { |
105 | Attribute* type_list_attr; // not owned. |
106 | std::vector<DataType> default_dtypes; // empty if no default. |
107 | std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. |
108 | std::vector<DataType> ok_dtypes; |
109 | }; |
110 | |
111 | // Information about Tensor-list parameters whose length is specified by a |
112 | // single `int` attribute. |
113 | struct InputsWithNumberAttr { |
114 | Attribute* number_attr; // not owned. |
115 | int64_t default_length; // -1 for no default. |
116 | std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. |
117 | }; |
118 | |
119 | // Structure used to return inferred attribute values. |
120 | // * types[i] is the inferred value for inferred_type_attrs()[i] |
121 | // * type_lists[i] is the inferred value for inferred_type_list_attrs()[i] |
122 | // * lengths[i] is the inferred value for inferred_length_attrs()[i] |
123 | struct InferredAttributes { |
124 | std::vector<DataType> types; |
125 | std::vector<std::vector<DataType>> type_lists; |
126 | std::vector<int64_t> lengths; |
127 | }; |
128 | |
129 | // Constructs a new PythonAPIInfo. |
130 | // |
131 | // Note: One of the `Initialize()` functions must be called before the |
132 | // `PythonAPIInfo` is used. |
133 | // |
134 | // Args: |
135 | // api_name: The fully-qualified name of the python API (e.g., tf.math.sum). |
136 | explicit PythonAPIInfo(const std::string& api_name); |
137 | |
138 | // Initializes this PythonAPIInfo. |
139 | // |
140 | // Args: |
141 | // op_def: Contains information about the parameters. |
142 | // param_names: The argument names for the python API, in canonical order. |
143 | // defaults_tuple: Tuple containing default values for the parameters, |
144 | // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default |
145 | // for `param_names[-i]`. |
146 | Status Initialize(const OpDef& op_def, const std::vector<string> param_names, |
147 | PyObject* defaults_tuple); |
148 | |
149 | // Initialize this PythonAPIInfo based on the registered OpDef for the given |
150 | // operation. |
151 | // |
152 | // Args: |
153 | // op_name: The registered name of the operation (e.g. "AddV2"). |
154 | Status InitializeFromRegisteredOp(const std::string& op_name); |
155 | |
156 | // Initializes this PythonAPIInfo based on a set of parameter specifications. |
157 | // |
158 | // Args: |
159 | // input_specs: Mapping from parameter name to specification string for |
160 | // each input (parameter that expects a tensor value). |
161 | // attr_specs: Mapping from parameter name to specification string for |
162 | // each attribute (parameter that expects a non-tensor value). |
163 | // param_names: The argument names for the python API, in canonical order. |
164 | // defaults_tuple: Tuple containing default values for the parameters, |
165 | // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default |
166 | // for `param_names[-i]`. |
167 | // |
168 | // Note: the `name` parameter should not be included in `input_specs` or |
169 | // `attr_specs`. |
170 | Status InitializeFromParamSpecs( |
171 | const std::map<std::string, std::string>& input_specs, |
172 | const std::map<std::string, std::string>& attr_specs, |
173 | const std::vector<string> param_names, PyObject* defaults_tuple); |
174 | |
175 | // The name of the API that is described by this PythonAPIInfo. |
176 | const char* api_name() const { return api_name_; } |
177 | |
178 | // The ordered names of the canononical parameters that this API expects. |
179 | const std::vector<const char*>& param_names() const { return param_names_; } |
180 | |
181 | // A Python tuple containing the default values for parameters. This is |
182 | // right-aligned with `param_name` -- i.e., `defaults[-i]` is the default |
183 | // for `param_names[-i]`. |
184 | const PyObject* defaults_tuple() const { return defaults_tuple_.get(); } |
185 | |
186 | // Information about the attribute (non-tensor) parameters for this API. |
187 | const std::vector<Attribute>& attributes() const { return attributes_; } |
188 | |
189 | // Information about the input (tensor) parameters for this API. |
190 | const std::vector<Input>& inputs() const { return inputs_; } |
191 | const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const { |
192 | return inputs_with_fixed_dtype_; |
193 | } |
194 | const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const { |
195 | return inputs_with_type_attrs_; |
196 | } |
197 | const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs() |
198 | const { |
199 | return inputs_with_type_list_attrs_; |
200 | } |
201 | const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const { |
202 | return inputs_with_number_attrs_; |
203 | } |
204 | |
205 | // Names of inferred attributes. |
206 | const std::vector<const char*>& inferred_type_attrs() const { |
207 | return inferred_type_attrs_; |
208 | } |
209 | const std::vector<const char*>& inferred_type_list_attrs() const { |
210 | return inferred_type_list_attrs_; |
211 | } |
212 | const std::vector<const char*>& inferred_length_attrs() const { |
213 | return inferred_length_attrs_; |
214 | } |
215 | |
216 | // Returns a string summarizing the internal state of this type converter. |
217 | string DebugInfo() const; |
218 | |
219 | private: |
220 | // Adds an entry to the attributes_ vector based on the given `AttrDef`. |
221 | // |
222 | // If `attr_def` describes a type attribute, then adds a value to |
223 | // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any |
224 | // tensor inputs that use this dtype). |
225 | // |
226 | // If `attr_def` describes an int attribute, then adds a value to |
227 | // inputs_with_number_attrs_ (to record any tensor inputs that use this |
228 | // value as a list length). |
229 | Status InitializeAttribute( |
230 | const OpDef::AttrDef& attr_def, |
231 | const std::map<std::string, ParamIndex>& param_name_to_index); |
232 | |
233 | // Adds an entry to the inputs_ vector based on the given `ArgDef`. |
234 | // |
235 | // If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`. |
236 | // |
237 | // If `arg_def`'s dtype is described by a `type` attr, then updates the |
238 | // appropriate value in `inputs_with_type_attrs_` with information about the |
239 | // `arg_def`. |
240 | // |
241 | // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the |
242 | // appropriate value in `inputs_with_type_list_attrs_` with information about |
243 | // the `arg_def`. |
244 | Status InitializeInput(const OpDef::ArgDef& arg_def, |
245 | const std::map<std::string, int>& param_name_to_index); |
246 | |
247 | // Checks that the OpDef used to initialize this PythonAPIInfo |
248 | // had an AttrDef or ArgDef specification for each parameter. |
249 | Status CheckParamNames() const; |
250 | |
251 | // Searches inputs_with_type_attrs_ for an input with the given name. |
252 | InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name); |
253 | |
254 | // Searches inputs_with_type_list_attrs_ for an input with the given name. |
255 | InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name); |
256 | |
257 | // Searches inputs_with_type_list_attrs_ for an input with the given name. |
258 | InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name); |
259 | |
260 | ABSL_MUST_USE_RESULT |
261 | bool InferLengthAttributes(const absl::Span<PyObject*> params, |
262 | std::vector<int64_t>& inferred_length_attrs) const; |
263 | |
264 | // ========================================================================== |
265 | // Member Variables |
266 | // ========================================================================== |
267 | |
268 | // The name of the API that is described by this PythonAPIInfo. |
269 | // (Interned python string). |
270 | const char* api_name_; |
271 | |
272 | // The names of the parameters that this API expects. |
273 | // (Interned python strings.) |
274 | std::vector<const char*> param_names_; |
275 | |
276 | // Tuple containing default values for the parameters, right-aligned with |
277 | // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`. |
278 | Safe_PyObjectPtr defaults_tuple_; |
279 | |
280 | // Information about the non-tensor-valued parameters that this API expects. |
281 | std::vector<Attribute> attributes_; |
282 | |
283 | // Information about the tensor-valued parameters that this API expects. |
284 | std::vector<Input> inputs_; |
285 | std::vector<InputWithFixedDType> inputs_with_fixed_dtype_; |
286 | std::vector<InputsWithTypeAttr> inputs_with_type_attrs_; |
287 | std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_; |
288 | std::vector<InputsWithNumberAttr> inputs_with_number_attrs_; |
289 | |
290 | // Names of inferred attributes. (Interned python strings.) |
291 | std::vector<const char*> inferred_type_attrs_; |
292 | std::vector<const char*> inferred_type_list_attrs_; |
293 | std::vector<const char*> inferred_length_attrs_; |
294 | }; |
295 | |
296 | } // namespace tensorflow |
297 | |
298 | #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_INFO_H_ |
299 | |