1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/framework/python_api_info.h" |
16 | |
17 | #include <Python.h> |
18 | |
19 | #include "absl/strings/str_cat.h" |
20 | #include "tensorflow/core/framework/op.h" |
21 | #include "tensorflow/core/lib/gtl/map_util.h" |
22 | #include "tensorflow/python/eager/pywrap_tensor.h" |
23 | #include "tensorflow/python/eager/pywrap_tfe.h" |
24 | #include "tensorflow/python/framework/op_def_util.h" |
25 | #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" |
26 | #include "tensorflow/python/util/util.h" |
27 | |
28 | namespace tensorflow { |
29 | |
// Compatibility shims so the code below can use one spelling for the string
// and integer helpers of both the Python 2 and Python 3 C APIs.
#if PY_MAJOR_VERSION >= 3
// Python 3.x:
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr))
#else
// Python 2.x:
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyString_AsString(x))
#endif  // PY_MAJOR_VERSION >= 3
45 | |
46 | namespace { |
47 | |
48 | // Converts the given object to an interned Python string, and returns its |
49 | // data pointer. (This means we don't need to worry about ownership for |
50 | // this string.) |
51 | const char* InternPyString(const std::string& s) { |
52 | Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str())); |
53 | return PY_STRING_AS_CSTR(interned.get()); |
54 | } |
55 | |
// Erases, in place, every element of `*vec` for which `p` returns true.
// Surviving elements keep their relative order (erase-remove idiom).
template <typename T, typename UnaryPredicate>
void RemoveIf(UnaryPredicate p, std::vector<T>* vec) {
  const auto first_removed = std::remove_if(vec->begin(), vec->end(), p);
  vec->erase(first_removed, vec->end());
}
60 | |
61 | struct DataTypeFormatter { |
62 | void operator()(std::string* out, DataType dtype) const { |
63 | out->append(DataType_Name(dtype)); |
64 | } |
65 | }; |
66 | |
67 | // Populates `param_names` and `defaults_tuple` based on the given OpDef. |
68 | void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def, |
69 | std::vector<string>& param_names, |
70 | Safe_PyObjectPtr& defaults_tuple) { |
71 | param_names.reserve(op_def.input_arg_size() + op_def.attr_size()); |
72 | std::set<std::string> inferred_attrs; |
73 | |
74 | // Input parameters come first, in the order they occur in the OpDef. |
75 | for (const auto& input : op_def.input_arg()) { |
76 | param_names.push_back(input.name()); |
77 | if (!input.type_attr().empty()) { |
78 | inferred_attrs.insert(input.type_attr()); |
79 | } |
80 | if (!input.type_list_attr().empty()) { |
81 | inferred_attrs.insert(input.type_list_attr()); |
82 | } |
83 | if (!input.number_attr().empty()) { |
84 | inferred_attrs.insert(input.number_attr()); |
85 | } |
86 | } |
87 | |
88 | // Next come attribute params without defaults, followed by attributes with |
89 | // defaults (but inferred attributes are not included). |
90 | std::vector<std::string> param_names_with_default; |
91 | std::vector<Safe_PyObjectPtr> defaults; |
92 | for (const auto& attr : op_def.attr()) { |
93 | if (inferred_attrs.count(attr.name()) == 0) { |
94 | if (attr.has_default_value()) { |
95 | param_names_with_default.push_back(attr.name()); |
96 | defaults.push_back(AttrValueToPyObject(attr.default_value())); |
97 | } else { |
98 | param_names.push_back(attr.name()); |
99 | } |
100 | } |
101 | } |
102 | param_names.insert(param_names.end(), param_names_with_default.begin(), |
103 | param_names_with_default.end()); |
104 | |
105 | // Finally, the 'name' parameter comes at the end, and its default value |
106 | // is the operation's name. |
107 | param_names.push_back("name" ); |
108 | defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str())); |
109 | |
110 | defaults_tuple.reset(PyTuple_New(defaults.size())); |
111 | for (int i = 0; i < defaults.size(); ++i) { |
112 | PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release()); |
113 | } |
114 | } |
115 | |
116 | } // namespace |
117 | |
118 | PythonAPIInfo::PythonAPIInfo(const std::string& api_name) |
119 | : api_name_(InternPyString(api_name)) {} |
120 | |
121 | Status PythonAPIInfo::Initialize(const OpDef& op_def, |
122 | const std::vector<string> param_names, |
123 | PyObject* defaults_tuple) { |
124 | // Intern the parameter names. |
125 | param_names_.reserve(param_names.size()); |
126 | for (const auto& param_name : param_names) { |
127 | param_names_.push_back(InternPyString(param_name)); |
128 | } |
129 | |
130 | Py_INCREF(defaults_tuple); |
131 | defaults_tuple_.reset(defaults_tuple); |
132 | |
133 | // Build an index to look up parameter index by name. (Does not include |
134 | // inferred attributes.) |
135 | std::map<std::string, int> param_name_to_index; |
136 | for (int i = 0; i < param_names_.size(); ++i) { |
137 | param_name_to_index[param_names_[i]] = i; |
138 | } |
139 | |
140 | // Initialize each attribute & input parameter. |
141 | attributes_.reserve(op_def.attr_size()); |
142 | for (const auto& attr_def : op_def.attr()) { |
143 | TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index)); |
144 | } |
145 | |
146 | inputs_.reserve(op_def.input_arg_size()); |
147 | for (const auto& arg_def : op_def.input_arg()) { |
148 | TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index)); |
149 | } |
150 | |
151 | TF_RETURN_IF_ERROR(CheckParamNames()); |
152 | |
153 | // Filter out any unused entries from inputs_with_*_attrs_. |
154 | RemoveIf( |
155 | [](const InputsWithTypeAttr& input) { |
156 | return input.tensor_params.empty() && input.tensor_list_params.empty(); |
157 | }, |
158 | &inputs_with_type_attrs_); |
159 | RemoveIf( |
160 | [](const InputsWithTypeListAttr& input) { |
161 | return input.tensor_list_params.empty(); |
162 | }, |
163 | &inputs_with_type_list_attrs_); |
164 | RemoveIf( |
165 | [](const InputsWithNumberAttr& input) { |
166 | return input.tensor_list_params.empty(); |
167 | }, |
168 | &inputs_with_number_attrs_); |
169 | |
170 | return OkStatus(); |
171 | } |
172 | |
173 | Status PythonAPIInfo::CheckParamNames() const { |
174 | std::vector<bool> param_found(param_names_.size()); |
175 | for (const auto& attr : attributes_) { |
176 | if (attr.index != -1) { |
177 | param_found[attr.index] = true; |
178 | } |
179 | } |
180 | for (const auto& input : inputs_) { |
181 | param_found[input.index] = true; |
182 | } |
183 | |
184 | for (int i = 0; i < param_names_.size(); ++i) { |
185 | if (param_names_[i] == std::string("name" )) { |
186 | continue; |
187 | } |
188 | if (!param_found[i]) { |
189 | return errors::InvalidArgument( |
190 | api_name_, ": missing specification for parameter " , param_names_[i]); |
191 | } |
192 | } |
193 | return OkStatus(); |
194 | } |
195 | |
196 | Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) { |
197 | const tensorflow::OpDef* op_def = nullptr; |
198 | TF_RETURN_IF_ERROR( |
199 | tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def)); |
200 | std::vector<std::string> param_names; |
201 | Safe_PyObjectPtr defaults_tuple; |
202 | GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple); |
203 | TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get())); |
204 | return OkStatus(); |
205 | } |
206 | |
207 | Status PythonAPIInfo::InitializeFromParamSpecs( |
208 | const std::map<std::string, std::string>& input_specs, |
209 | const std::map<std::string, std::string>& attr_specs, |
210 | const std::vector<string> param_names, PyObject* defaults_tuple) { |
211 | OpDefBuilder op_def_builder(api_name_); |
212 | op_def_builder.AllowAttrTypeAny(); |
213 | for (const auto& attr_spec : attr_specs) { |
214 | op_def_builder.Attr(absl::StrCat(attr_spec.first, ": " , attr_spec.second)); |
215 | } |
216 | for (const auto& input_spec : input_specs) { |
217 | op_def_builder.Input( |
218 | absl::StrCat(input_spec.first, ": " , input_spec.second)); |
219 | } |
220 | OpRegistrationData op_reg_data; |
221 | TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data)); |
222 | |
223 | TF_RETURN_IF_ERROR( |
224 | Initialize(op_reg_data.op_def, param_names, defaults_tuple)); |
225 | |
226 | return OkStatus(); |
227 | } |
228 | |
// Registers one OpDef attribute.
//
// Always appends an Attribute entry to attributes_. This happens even when the
// attribute type turns out to be UNKNOWN and an error is returned.
//
// Depending on the attribute's AttributeType, it also appends an entry that
// InitializeInput later attaches tensor parameters to:
//   - DTYPE:      an InputsWithTypeAttr, with the default dtype (DT_INVALID
//                 if the attribute has no default);
//   - LIST_DTYPE: an InputsWithTypeListAttr, with the default dtype list;
//   - INT:        an InputsWithNumberAttr, with the default length (-1 if
//                 the attribute has no default).
// For DTYPE and LIST_DTYPE, the attribute's allowed_values are copied into
// ok_dtypes.
//
// An attribute with no entry in `param_name_to_index` is treated as inferred.
// Only DTYPE, LIST_DTYPE and INT attributes may be inferred; each one gets an
// inferred_index into the matching inferred_*_attrs_ list.
//
// The entries above store a pointer to the new attributes_ element, so
// attributes_ must not reallocate while attributes are being added.
// Initialize reserves op_def.attr_size() before calling this.
//
// Returns InvalidArgument if the attribute is named `name`, has an unknown
// type, or has no parameter and cannot be inferred.
Status PythonAPIInfo::InitializeAttribute(
    const OpDef::AttrDef& attr_def,
    const std::map<std::string, int>& param_name_to_index) {
  if (attr_def.name() == "name") {
    return errors::InvalidArgument(
        api_name_, ": Reserved parameter `name` was used as an attribute.");
  }
  const char* name = InternPyString(attr_def.name());

  // -1 means the attribute is not an explicit parameter (i.e. inferred).
  const int param_index =
      gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1);
  // Note: despite its name, `dtype` is the attribute's AttributeType.
  const AttributeType dtype = AttributeTypeFromName(attr_def.type());
  const int inferred_index = -1;
  attributes_.push_back({param_index, dtype, name, inferred_index});
  // Reference stays valid only because attributes_ was reserved by the caller.
  Attribute& attr = attributes_.back();
  if (attr.type == AttributeType::UNKNOWN) {
    return errors::InvalidArgument(api_name_, ": Bad attribute type for ",
                                   attr_def.name(), ": '", attr_def.type(),
                                   "'");
  }
  // Points at the ok_dtypes list of the entry pushed just below (if any); it
  // is used before any further push to the same vector.
  std::vector<DataType>* ok_dtypes = nullptr;

  if (attr.type == AttributeType::DTYPE) {
    DataType default_dtype = attr_def.has_default_value()
                                 ? attr_def.default_value().type()
                                 : DT_INVALID;
    inputs_with_type_attrs_.push_back({&attr, default_dtype});
    ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes;

  } else if (attr.type == AttributeType::LIST_DTYPE) {
    inputs_with_type_list_attrs_.push_back({&attr});
    // With no default value this iterates over an empty list.
    for (int d : attr_def.default_value().list().type()) {
      inputs_with_type_list_attrs_.back().default_dtypes.push_back(
          static_cast<DataType>(d));
    }
    ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes;
  }

  if (attr_def.has_allowed_values() && ok_dtypes) {
    const auto& dtypes = attr_def.allowed_values().list();
    for (int i = 0; i < dtypes.type_size(); ++i) {
      ok_dtypes->push_back(dtypes.type(i));
    }
  }

  if (attr.type == AttributeType::INT) {
    int64_t default_len =
        attr_def.has_default_value() ? attr_def.default_value().i() : -1;
    inputs_with_number_attrs_.push_back({&attr, default_len});
  }

  // If this is an inferred attribute, then record its name and index.
  if (attr.index == -1) {
    std::vector<const char*>* inferred_attr_names =
        attr.type == AttributeType::DTYPE        ? &inferred_type_attrs_
        : attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_
        : attr.type == AttributeType::INT        ? &inferred_length_attrs_
                                                 : nullptr;
    if (inferred_attr_names == nullptr) {
      return errors::InvalidArgument(
          api_name_, ": Missing specification for parameter ", attr_def.name());
    } else {
      // inferred_index is the position within this attribute kind's list.
      attr.inferred_index = inferred_attr_names->size();
      inferred_attr_names->push_back(attr.name);
    }
  }

  return OkStatus();
}
298 | |
299 | Status PythonAPIInfo::InitializeInput( |
300 | const OpDef::ArgDef& arg_def, |
301 | const std::map<std::string, ParamIndex>& param_name_to_index) { |
302 | if (arg_def.name() == "name" ) { |
303 | return errors::InvalidArgument( |
304 | api_name_, ": Reserved parameter `name` was used as a tensor input." ); |
305 | } |
306 | const ParamIndex param_index = |
307 | gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1); |
308 | if (param_index == -1) { |
309 | return errors::InvalidArgument( |
310 | api_name_, ": Missing specification for parameter " , arg_def.name()); |
311 | } |
312 | if (arg_def.is_ref()) { |
313 | // TODO(b/164980194): Support reference parameters. |
314 | // - Pass as_ref to convert_to_tensor |
315 | // - Check that values for ref inputs have ref types. |
316 | return errors::InvalidArgument(api_name_, |
317 | ": PythonAPIInfo doesn't support reference " |
318 | "parameters yet." ); |
319 | } |
320 | bool is_list = |
321 | !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty(); |
322 | inputs_.push_back({param_index, is_list}); |
323 | |
324 | if (!arg_def.type_list_attr().empty()) { |
325 | // list(input) with dtypes specified by a `list(type)` attribute. |
326 | InputsWithTypeListAttr* input = |
327 | FindInputsWithTypeListAttr(arg_def.type_list_attr()); |
328 | if (!input) { |
329 | return errors::InvalidArgument( |
330 | api_name_, ": Type attribute " , arg_def.type_list_attr(), |
331 | " for parameter " , arg_def.name(), " not found." ); |
332 | } |
333 | input->tensor_list_params.push_back(param_index); |
334 | } else if (!arg_def.type_attr().empty()) { |
335 | InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr()); |
336 | // input or list(input) with dtype specified by a `type` attribute. |
337 | if (!input) { |
338 | return errors::InvalidArgument(api_name_, ": Type attribute " , |
339 | arg_def.type_attr(), " for parameter " , |
340 | arg_def.name(), " not found." ); |
341 | } |
342 | if (arg_def.number_attr().empty()) { |
343 | input->tensor_params.push_back(param_index); |
344 | } else { |
345 | input->tensor_list_params.push_back(param_index); |
346 | } |
347 | } else { |
348 | // input or list(input) with fixed dtype |
349 | inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list}); |
350 | } |
351 | |
352 | if (!arg_def.number_attr().empty()) { |
353 | InputsWithNumberAttr* input = |
354 | FindInputsWithNumberAttr(arg_def.number_attr()); |
355 | if (!input) { |
356 | return errors::InvalidArgument(api_name_, ": Length attribute " , |
357 | arg_def.number_attr(), " for parameter " , |
358 | arg_def.name(), " not found." ); |
359 | } |
360 | input->tensor_list_params.push_back(param_index); |
361 | } |
362 | |
363 | return OkStatus(); |
364 | } |
365 | |
366 | PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr( |
367 | const string& name) { |
368 | for (auto& input : inputs_with_type_attrs_) { |
369 | if (name == input.type_attr->name) { |
370 | return &input; |
371 | } |
372 | } |
373 | return nullptr; |
374 | } |
375 | |
376 | PythonAPIInfo::InputsWithTypeListAttr* |
377 | PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) { |
378 | for (auto& input : inputs_with_type_list_attrs_) { |
379 | if (name == input.type_list_attr->name) { |
380 | return &input; |
381 | } |
382 | } |
383 | return nullptr; |
384 | } |
385 | |
386 | PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr( |
387 | const string& name) { |
388 | for (auto& input : inputs_with_number_attrs_) { |
389 | if (name == input.number_attr->name) { |
390 | return &input; |
391 | } |
392 | } |
393 | return nullptr; |
394 | } |
395 | |
396 | string PythonAPIInfo::DebugInfo() const { |
397 | string s = absl::StrCat("DebugInfo for " , api_name_, ":\n" ); |
398 | absl::StrAppend(&s, " param_names=[" , absl::StrJoin(param_names_, ", " ), |
399 | "]\n" ); |
400 | Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get())); |
401 | absl::StrAppend( |
402 | &s, " defaults_tuple=" , TFE_GetPythonString(defaults_repr.get()), "\n" ); |
403 | if (!attributes_.empty()) { |
404 | absl::StrAppend(&s, " attributes=[" ); |
405 | for (const auto& attrib : attributes_) { |
406 | if (attrib.index != -1) { |
407 | absl::StrAppend(&s, "\n {index=" , attrib.index); |
408 | DCHECK_EQ(attrib.inferred_index, -1); |
409 | } else { |
410 | absl::StrAppend(&s, "\n {inferred_index=" , attrib.inferred_index); |
411 | } |
412 | absl::StrAppend(&s, ", name=" , attrib.name, |
413 | ", type=" , AttributeTypeToName(attrib.type), "}," ); |
414 | } |
415 | absl::StrAppend(&s, "]\n" ); |
416 | } |
417 | if (!inputs_.empty()) { |
418 | absl::StrAppend(&s, " inputs=[" ); |
419 | for (const auto& input : inputs_) { |
420 | absl::StrAppend(&s, "\n {index=" , input.index, |
421 | ", name=" , param_names_[input.index], |
422 | ", is_list=" , input.is_list, "}," ); |
423 | } |
424 | absl::StrAppend(&s, "]\n" ); |
425 | } |
426 | if (!inputs_with_fixed_dtype_.empty()) { |
427 | absl::StrAppend(&s, " inputs_with_fixed_dtype=[" ); |
428 | for (const auto& input : inputs_with_fixed_dtype_) { |
429 | absl::StrAppend(&s, "\n {index=" , input.index, |
430 | ", dtype=" , DataType_Name(input.dtype), |
431 | ", is_list=" , input.is_list, "}," ); |
432 | } |
433 | absl::StrAppend(&s, "]\n" ); |
434 | } |
435 | if (!inputs_with_type_attrs_.empty()) { |
436 | absl::StrAppend(&s, " inputs_with_type_attr=[" ); |
437 | for (const auto& input : inputs_with_type_attrs_) { |
438 | absl::StrAppend(&s, "\n {type_attr=" , input.type_attr->name); |
439 | if (input.default_dtype != DT_INVALID) { |
440 | absl::StrAppend(&s, |
441 | ", default_dtype=" , DataType_Name(input.default_dtype)); |
442 | } |
443 | if (!input.tensor_params.empty()) { |
444 | absl::StrAppend(&s, ", tensor_params=[" , |
445 | absl::StrJoin(input.tensor_params, ", " ), "]" ); |
446 | } |
447 | if (!input.tensor_list_params.empty()) { |
448 | absl::StrAppend(&s, ", tensor_list_params=[" , |
449 | absl::StrJoin(input.tensor_list_params, ", " ), "]" ); |
450 | } |
451 | if (!input.ok_dtypes.empty()) { |
452 | absl::StrAppend( |
453 | &s, ", ok_dtypes=[" , |
454 | absl::StrJoin(input.ok_dtypes, ", " , DataTypeFormatter()), "]" ); |
455 | } |
456 | absl::StrAppend(&s, "}," ); |
457 | } |
458 | absl::StrAppend(&s, "]\n" ); |
459 | } |
460 | if (!inputs_with_type_list_attrs_.empty()) { |
461 | absl::StrAppend(&s, " inputs_with_type_list_attrs=[" ); |
462 | for (const auto& input : inputs_with_type_list_attrs_) { |
463 | absl::StrAppend(&s, "\n {type_list_attr=" , input.type_list_attr->name); |
464 | if (!input.default_dtypes.empty()) { |
465 | absl::StrAppend( |
466 | &s, ", default_dtypes=[" , |
467 | absl::StrJoin(input.default_dtypes, ", " , DataTypeFormatter()), |
468 | "]" ); |
469 | } |
470 | if (!input.tensor_list_params.empty()) { |
471 | absl::StrAppend(&s, ", tensor_list_params=[" , |
472 | absl::StrJoin(input.tensor_list_params, ", " ), "]" ); |
473 | } |
474 | if (!input.ok_dtypes.empty()) { |
475 | absl::StrAppend( |
476 | &s, ", ok_dtypes=[" , |
477 | absl::StrJoin(input.ok_dtypes, ", " , DataTypeFormatter()), "]" ); |
478 | } |
479 | absl::StrAppend(&s, "}," ); |
480 | } |
481 | absl::StrAppend(&s, "]\n" ); |
482 | } |
483 | if (!inputs_with_number_attrs_.empty()) { |
484 | absl::StrAppend(&s, " inputs_with_number_attrs=[" ); |
485 | for (const auto& input : inputs_with_number_attrs_) { |
486 | absl::StrAppend(&s, "\n {number_attr=" , input.number_attr->name, |
487 | ", default_length=" , input.default_length, |
488 | ", tensor_list_params=[" , |
489 | absl::StrJoin(input.tensor_list_params, ", " ), "],\n" ); |
490 | } |
491 | absl::StrAppend(&s, "]\n" ); |
492 | } |
493 | if (!inferred_type_attrs_.empty()) { |
494 | absl::StrAppend(&s, " inferred_type_attrs=[" , |
495 | absl::StrJoin(inferred_type_attrs_, ", " ), "]\n" ); |
496 | } |
497 | if (!inferred_type_list_attrs_.empty()) { |
498 | absl::StrAppend(&s, " inferred_type_list_attrs=[" , |
499 | absl::StrJoin(inferred_type_list_attrs_, ", " ), "]\n" ); |
500 | } |
501 | if (!inferred_length_attrs_.empty()) { |
502 | absl::StrAppend(&s, " inferred_length_attrs=[" , |
503 | absl::StrJoin(inferred_length_attrs_, ", " ), "]\n" ); |
504 | } |
505 | return s; |
506 | } |
507 | |
508 | } // namespace tensorflow |
509 | |