1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/framework/python_api_info.h"
16
17#include <Python.h>
18
19#include "absl/strings/str_cat.h"
20#include "tensorflow/core/framework/op.h"
21#include "tensorflow/core/lib/gtl/map_util.h"
22#include "tensorflow/python/eager/pywrap_tensor.h"
23#include "tensorflow/python/eager/pywrap_tfe.h"
24#include "tensorflow/python/framework/op_def_util.h"
25#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
26#include "tensorflow/python/util/util.h"
27
28namespace tensorflow {
29
// Python 2/3 compatibility shims for the string and integer C-API helpers.
// PY_STRING_AS_CSTR returns a pointer into the Python object's own buffer,
// so that pointer is only valid while the object itself is alive.
// NOTE(review): PY_STRING_CHECK and PY_INT_AS_LONG are not referenced
// anywhere in this file.
#if PY_MAJOR_VERSION < 3
// Python 2.x:
#define PY_STRING_CHECK(x) (PyString_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyInt_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyString_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyString_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyString_AsString(x))
#else
// Python 3.x:
#define PY_STRING_CHECK(x) (PyBytes_Check(x) || PyUnicode_Check(x))
#define PY_INT_AS_LONG(x) (PyLong_AsLong(x))
#define PY_STRING_FROMSTRING(x) (PyUnicode_FromString(x))
#define PY_STRING_INTERN_FROM_STRING(x) (PyUnicode_InternFromString(x))
#define PY_STRING_AS_CSTR(x) (PyUnicode_AsUTF8AndSize((x), nullptr))
#endif
45
46namespace {
47
48// Converts the given object to an interned Python string, and returns its
49// data pointer. (This means we don't need to worry about ownership for
50// this string.)
51const char* InternPyString(const std::string& s) {
52 Safe_PyObjectPtr interned(PY_STRING_INTERN_FROM_STRING(s.c_str()));
53 return PY_STRING_AS_CSTR(interned.get());
54}
55
// Erases, in place, every element of `*vec` for which `p` returns true.
// The surviving elements keep their relative order (std::remove_if is
// stable).
template <typename T, typename UnaryPredicate>
void RemoveIf(UnaryPredicate p, std::vector<T>* vec) {
  std::vector<T>& items = *vec;
  const auto first_removed = std::remove_if(items.begin(), items.end(), p);
  items.erase(first_removed, items.end());
}
60
61struct DataTypeFormatter {
62 void operator()(std::string* out, DataType dtype) const {
63 out->append(DataType_Name(dtype));
64 }
65};
66
67// Populates `param_names` and `defaults_tuple` based on the given OpDef.
68void GetOpDefNamesAndDefaults(const tensorflow::OpDef& op_def,
69 std::vector<string>& param_names,
70 Safe_PyObjectPtr& defaults_tuple) {
71 param_names.reserve(op_def.input_arg_size() + op_def.attr_size());
72 std::set<std::string> inferred_attrs;
73
74 // Input parameters come first, in the order they occur in the OpDef.
75 for (const auto& input : op_def.input_arg()) {
76 param_names.push_back(input.name());
77 if (!input.type_attr().empty()) {
78 inferred_attrs.insert(input.type_attr());
79 }
80 if (!input.type_list_attr().empty()) {
81 inferred_attrs.insert(input.type_list_attr());
82 }
83 if (!input.number_attr().empty()) {
84 inferred_attrs.insert(input.number_attr());
85 }
86 }
87
88 // Next come attribute params without defaults, followed by attributes with
89 // defaults (but inferred attributes are not included).
90 std::vector<std::string> param_names_with_default;
91 std::vector<Safe_PyObjectPtr> defaults;
92 for (const auto& attr : op_def.attr()) {
93 if (inferred_attrs.count(attr.name()) == 0) {
94 if (attr.has_default_value()) {
95 param_names_with_default.push_back(attr.name());
96 defaults.push_back(AttrValueToPyObject(attr.default_value()));
97 } else {
98 param_names.push_back(attr.name());
99 }
100 }
101 }
102 param_names.insert(param_names.end(), param_names_with_default.begin(),
103 param_names_with_default.end());
104
105 // Finally, the 'name' parameter comes at the end, and its default value
106 // is the operation's name.
107 param_names.push_back("name");
108 defaults.emplace_back(PY_STRING_FROMSTRING(op_def.name().c_str()));
109
110 defaults_tuple.reset(PyTuple_New(defaults.size()));
111 for (int i = 0; i < defaults.size(); ++i) {
112 PyTuple_SET_ITEM(defaults_tuple.get(), i, defaults[i].release());
113 }
114}
115
116} // namespace
117
// Constructs a PythonAPIInfo for the API named `api_name`. The name is
// interned (see InternPyString) so that `api_name_` can be kept as a plain
// `const char*`.
PythonAPIInfo::PythonAPIInfo(const std::string& api_name)
    : api_name_(InternPyString(api_name)) {}
120
121Status PythonAPIInfo::Initialize(const OpDef& op_def,
122 const std::vector<string> param_names,
123 PyObject* defaults_tuple) {
124 // Intern the parameter names.
125 param_names_.reserve(param_names.size());
126 for (const auto& param_name : param_names) {
127 param_names_.push_back(InternPyString(param_name));
128 }
129
130 Py_INCREF(defaults_tuple);
131 defaults_tuple_.reset(defaults_tuple);
132
133 // Build an index to look up parameter index by name. (Does not include
134 // inferred attributes.)
135 std::map<std::string, int> param_name_to_index;
136 for (int i = 0; i < param_names_.size(); ++i) {
137 param_name_to_index[param_names_[i]] = i;
138 }
139
140 // Initialize each attribute & input parameter.
141 attributes_.reserve(op_def.attr_size());
142 for (const auto& attr_def : op_def.attr()) {
143 TF_RETURN_IF_ERROR(InitializeAttribute(attr_def, param_name_to_index));
144 }
145
146 inputs_.reserve(op_def.input_arg_size());
147 for (const auto& arg_def : op_def.input_arg()) {
148 TF_RETURN_IF_ERROR(InitializeInput(arg_def, param_name_to_index));
149 }
150
151 TF_RETURN_IF_ERROR(CheckParamNames());
152
153 // Filter out any unused entries from inputs_with_*_attrs_.
154 RemoveIf(
155 [](const InputsWithTypeAttr& input) {
156 return input.tensor_params.empty() && input.tensor_list_params.empty();
157 },
158 &inputs_with_type_attrs_);
159 RemoveIf(
160 [](const InputsWithTypeListAttr& input) {
161 return input.tensor_list_params.empty();
162 },
163 &inputs_with_type_list_attrs_);
164 RemoveIf(
165 [](const InputsWithNumberAttr& input) {
166 return input.tensor_list_params.empty();
167 },
168 &inputs_with_number_attrs_);
169
170 return OkStatus();
171}
172
173Status PythonAPIInfo::CheckParamNames() const {
174 std::vector<bool> param_found(param_names_.size());
175 for (const auto& attr : attributes_) {
176 if (attr.index != -1) {
177 param_found[attr.index] = true;
178 }
179 }
180 for (const auto& input : inputs_) {
181 param_found[input.index] = true;
182 }
183
184 for (int i = 0; i < param_names_.size(); ++i) {
185 if (param_names_[i] == std::string("name")) {
186 continue;
187 }
188 if (!param_found[i]) {
189 return errors::InvalidArgument(
190 api_name_, ": missing specification for parameter ", param_names_[i]);
191 }
192 }
193 return OkStatus();
194}
195
196Status PythonAPIInfo::InitializeFromRegisteredOp(const std::string& op_name) {
197 const tensorflow::OpDef* op_def = nullptr;
198 TF_RETURN_IF_ERROR(
199 tensorflow::OpRegistry::Global()->LookUpOpDef(op_name, &op_def));
200 std::vector<std::string> param_names;
201 Safe_PyObjectPtr defaults_tuple;
202 GetOpDefNamesAndDefaults(*op_def, param_names, defaults_tuple);
203 TF_RETURN_IF_ERROR(Initialize(*op_def, param_names, defaults_tuple.get()));
204 return OkStatus();
205}
206
207Status PythonAPIInfo::InitializeFromParamSpecs(
208 const std::map<std::string, std::string>& input_specs,
209 const std::map<std::string, std::string>& attr_specs,
210 const std::vector<string> param_names, PyObject* defaults_tuple) {
211 OpDefBuilder op_def_builder(api_name_);
212 op_def_builder.AllowAttrTypeAny();
213 for (const auto& attr_spec : attr_specs) {
214 op_def_builder.Attr(absl::StrCat(attr_spec.first, ": ", attr_spec.second));
215 }
216 for (const auto& input_spec : input_specs) {
217 op_def_builder.Input(
218 absl::StrCat(input_spec.first, ": ", input_spec.second));
219 }
220 OpRegistrationData op_reg_data;
221 TF_RETURN_IF_ERROR(op_def_builder.Finalize(&op_reg_data));
222
223 TF_RETURN_IF_ERROR(
224 Initialize(op_reg_data.op_def, param_names, defaults_tuple));
225
226 return OkStatus();
227}
228
// Records the attribute described by `attr_def` by appending it to
// `attributes_`. It is also registered with the list that matches its type:
//   * `type` attributes       -> inputs_with_type_attrs_ (with the default
//                                dtype, if any, and allowed dtypes);
//   * `list(type)` attributes -> inputs_with_type_list_attrs_ (with default
//                                dtypes and allowed dtypes);
//   * `int` attributes        -> inputs_with_number_attrs_ (with the default
//                                length, or -1 if there is none).
// If `param_name_to_index` has no entry for the attribute, it is treated as
// inferred, and its name is appended to the matching `inferred_*_attrs_`
// list. Attributes of any other type must have a parameter.
//
// NOTE: the `inputs_with_*_attrs_` entries keep a raw pointer to the element
// just appended to `attributes_`. `attributes_` must therefore not
// reallocate while attributes are being added. (Initialize() reserves
// `op_def.attr_size()` entries before calling this, presumably for that
// reason.)
//
// Returns InvalidArgument in three cases: the attribute is named `name`, its
// type is not recognized, or it has no parameter and cannot be inferred.
Status PythonAPIInfo::InitializeAttribute(
    const OpDef::AttrDef& attr_def,
    const std::map<std::string, int>& param_name_to_index) {
  if (attr_def.name() == "name") {
    return errors::InvalidArgument(
        api_name_, ": Reserved parameter `name` was used as an attribute.");
  }
  const char* name = InternPyString(attr_def.name());

  // -1 means the attribute has no parameter of its own (it is inferred).
  const int param_index =
      gtl::FindWithDefault(param_name_to_index, attr_def.name(), -1);
  const AttributeType dtype = AttributeTypeFromName(attr_def.type());
  const int inferred_index = -1;
  attributes_.push_back({param_index, dtype, name, inferred_index});
  // Pointers to `attr` are stored below; see the NOTE above.
  Attribute& attr = attributes_.back();
  if (attr.type == AttributeType::UNKNOWN) {
    return errors::InvalidArgument(api_name_, ": Bad attribute type for ",
                                   attr_def.name(), ": '", attr_def.type(),
                                   "'");
  }
  // Set only for dtype-valued attributes; receives their allowed dtypes.
  std::vector<DataType>* ok_dtypes = nullptr;

  if (attr.type == AttributeType::DTYPE) {
    DataType default_dtype = attr_def.has_default_value()
                                 ? attr_def.default_value().type()
                                 : DT_INVALID;
    inputs_with_type_attrs_.push_back({&attr, default_dtype});
    ok_dtypes = &inputs_with_type_attrs_.back().ok_dtypes;

  } else if (attr.type == AttributeType::LIST_DTYPE) {
    inputs_with_type_list_attrs_.push_back({&attr});
    for (int d : attr_def.default_value().list().type()) {
      inputs_with_type_list_attrs_.back().default_dtypes.push_back(
          static_cast<DataType>(d));
    }
    ok_dtypes = &inputs_with_type_list_attrs_.back().ok_dtypes;
  }

  // Copy the allowed-values constraint (if any) for dtype-valued attributes.
  if (attr_def.has_allowed_values() && ok_dtypes) {
    const auto& dtypes = attr_def.allowed_values().list();
    for (int i = 0; i < dtypes.type_size(); ++i) {
      ok_dtypes->push_back(dtypes.type(i));
    }
  }

  if (attr.type == AttributeType::INT) {
    int64_t default_len =
        attr_def.has_default_value() ? attr_def.default_value().i() : -1;
    inputs_with_number_attrs_.push_back({&attr, default_len});
  }

  // If this is an inferred attribute, then record its name and index.
  if (attr.index == -1) {
    std::vector<const char*>* inferred_attr_names =
        attr.type == AttributeType::DTYPE        ? &inferred_type_attrs_
        : attr.type == AttributeType::LIST_DTYPE ? &inferred_type_list_attrs_
        : attr.type == AttributeType::INT        ? &inferred_length_attrs_
                                                 : nullptr;
    if (inferred_attr_names == nullptr) {
      return errors::InvalidArgument(
          api_name_, ": Missing specification for parameter ", attr_def.name());
    } else {
      attr.inferred_index = inferred_attr_names->size();
      inferred_attr_names->push_back(attr.name);
    }
  }

  return OkStatus();
}
298
299Status PythonAPIInfo::InitializeInput(
300 const OpDef::ArgDef& arg_def,
301 const std::map<std::string, ParamIndex>& param_name_to_index) {
302 if (arg_def.name() == "name") {
303 return errors::InvalidArgument(
304 api_name_, ": Reserved parameter `name` was used as a tensor input.");
305 }
306 const ParamIndex param_index =
307 gtl::FindWithDefault(param_name_to_index, arg_def.name(), -1);
308 if (param_index == -1) {
309 return errors::InvalidArgument(
310 api_name_, ": Missing specification for parameter ", arg_def.name());
311 }
312 if (arg_def.is_ref()) {
313 // TODO(b/164980194): Support reference parameters.
314 // - Pass as_ref to convert_to_tensor
315 // - Check that values for ref inputs have ref types.
316 return errors::InvalidArgument(api_name_,
317 ": PythonAPIInfo doesn't support reference "
318 "parameters yet.");
319 }
320 bool is_list =
321 !arg_def.number_attr().empty() || !arg_def.type_list_attr().empty();
322 inputs_.push_back({param_index, is_list});
323
324 if (!arg_def.type_list_attr().empty()) {
325 // list(input) with dtypes specified by a `list(type)` attribute.
326 InputsWithTypeListAttr* input =
327 FindInputsWithTypeListAttr(arg_def.type_list_attr());
328 if (!input) {
329 return errors::InvalidArgument(
330 api_name_, ": Type attribute ", arg_def.type_list_attr(),
331 " for parameter ", arg_def.name(), " not found.");
332 }
333 input->tensor_list_params.push_back(param_index);
334 } else if (!arg_def.type_attr().empty()) {
335 InputsWithTypeAttr* input = FindInputsWithTypeAttr(arg_def.type_attr());
336 // input or list(input) with dtype specified by a `type` attribute.
337 if (!input) {
338 return errors::InvalidArgument(api_name_, ": Type attribute ",
339 arg_def.type_attr(), " for parameter ",
340 arg_def.name(), " not found.");
341 }
342 if (arg_def.number_attr().empty()) {
343 input->tensor_params.push_back(param_index);
344 } else {
345 input->tensor_list_params.push_back(param_index);
346 }
347 } else {
348 // input or list(input) with fixed dtype
349 inputs_with_fixed_dtype_.push_back({arg_def.type(), param_index, is_list});
350 }
351
352 if (!arg_def.number_attr().empty()) {
353 InputsWithNumberAttr* input =
354 FindInputsWithNumberAttr(arg_def.number_attr());
355 if (!input) {
356 return errors::InvalidArgument(api_name_, ": Length attribute ",
357 arg_def.number_attr(), " for parameter ",
358 arg_def.name(), " not found.");
359 }
360 input->tensor_list_params.push_back(param_index);
361 }
362
363 return OkStatus();
364}
365
366PythonAPIInfo::InputsWithTypeAttr* PythonAPIInfo::FindInputsWithTypeAttr(
367 const string& name) {
368 for (auto& input : inputs_with_type_attrs_) {
369 if (name == input.type_attr->name) {
370 return &input;
371 }
372 }
373 return nullptr;
374}
375
376PythonAPIInfo::InputsWithTypeListAttr*
377PythonAPIInfo::FindInputsWithTypeListAttr(const string& name) {
378 for (auto& input : inputs_with_type_list_attrs_) {
379 if (name == input.type_list_attr->name) {
380 return &input;
381 }
382 }
383 return nullptr;
384}
385
386PythonAPIInfo::InputsWithNumberAttr* PythonAPIInfo::FindInputsWithNumberAttr(
387 const string& name) {
388 for (auto& input : inputs_with_number_attrs_) {
389 if (name == input.number_attr->name) {
390 return &input;
391 }
392 }
393 return nullptr;
394}
395
// Returns a multi-line, human-readable dump of this PythonAPIInfo for
// debugging. The header, `param_names` and `defaults_tuple` lines are always
// emitted. Every other section is emitted only when its list is non-empty.
string PythonAPIInfo::DebugInfo() const {
  string s = absl::StrCat("DebugInfo for ", api_name_, ":\n");
  absl::StrAppend(&s, " param_names=[", absl::StrJoin(param_names_, ", "),
                  "]\n");
  // NOTE(review): PyObject_Repr may return null on failure; confirm that
  // TFE_GetPythonString tolerates a null argument.
  Safe_PyObjectPtr defaults_repr(PyObject_Repr(defaults_tuple_.get()));
  absl::StrAppend(
      &s, " defaults_tuple=", TFE_GetPythonString(defaults_repr.get()), "\n");
  if (!attributes_.empty()) {
    absl::StrAppend(&s, " attributes=[");
    for (const auto& attrib : attributes_) {
      // Attributes with a parameter show their parameter index; inferred
      // attributes show their index into the matching inferred_* list.
      if (attrib.index != -1) {
        absl::StrAppend(&s, "\n {index=", attrib.index);
        DCHECK_EQ(attrib.inferred_index, -1);
      } else {
        absl::StrAppend(&s, "\n {inferred_index=", attrib.inferred_index);
      }
      absl::StrAppend(&s, ", name=", attrib.name,
                      ", type=", AttributeTypeToName(attrib.type), "},");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inputs_.empty()) {
    absl::StrAppend(&s, " inputs=[");
    for (const auto& input : inputs_) {
      absl::StrAppend(&s, "\n {index=", input.index,
                      ", name=", param_names_[input.index],
                      ", is_list=", input.is_list, "},");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inputs_with_fixed_dtype_.empty()) {
    absl::StrAppend(&s, " inputs_with_fixed_dtype=[");
    for (const auto& input : inputs_with_fixed_dtype_) {
      absl::StrAppend(&s, "\n {index=", input.index,
                      ", dtype=", DataType_Name(input.dtype),
                      ", is_list=", input.is_list, "},");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inputs_with_type_attrs_.empty()) {
    absl::StrAppend(&s, " inputs_with_type_attr=[");
    for (const auto& input : inputs_with_type_attrs_) {
      absl::StrAppend(&s, "\n {type_attr=", input.type_attr->name);
      if (input.default_dtype != DT_INVALID) {
        absl::StrAppend(&s,
                        ", default_dtype=", DataType_Name(input.default_dtype));
      }
      if (!input.tensor_params.empty()) {
        absl::StrAppend(&s, ", tensor_params=[",
                        absl::StrJoin(input.tensor_params, ", "), "]");
      }
      if (!input.tensor_list_params.empty()) {
        absl::StrAppend(&s, ", tensor_list_params=[",
                        absl::StrJoin(input.tensor_list_params, ", "), "]");
      }
      if (!input.ok_dtypes.empty()) {
        absl::StrAppend(
            &s, ", ok_dtypes=[",
            absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
      }
      absl::StrAppend(&s, "},");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inputs_with_type_list_attrs_.empty()) {
    absl::StrAppend(&s, " inputs_with_type_list_attrs=[");
    for (const auto& input : inputs_with_type_list_attrs_) {
      absl::StrAppend(&s, "\n {type_list_attr=", input.type_list_attr->name);
      if (!input.default_dtypes.empty()) {
        absl::StrAppend(
            &s, ", default_dtypes=[",
            absl::StrJoin(input.default_dtypes, ", ", DataTypeFormatter()),
            "]");
      }
      if (!input.tensor_list_params.empty()) {
        absl::StrAppend(&s, ", tensor_list_params=[",
                        absl::StrJoin(input.tensor_list_params, ", "), "]");
      }
      if (!input.ok_dtypes.empty()) {
        absl::StrAppend(
            &s, ", ok_dtypes=[",
            absl::StrJoin(input.ok_dtypes, ", ", DataTypeFormatter()), "]");
      }
      absl::StrAppend(&s, "},");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inputs_with_number_attrs_.empty()) {
    absl::StrAppend(&s, " inputs_with_number_attrs=[");
    for (const auto& input : inputs_with_number_attrs_) {
      // NOTE(review): unlike the other sections, each entry here opens with
      // "{" but ends with "],\n" instead of "]},". The braces are therefore
      // unbalanced, and an empty line appears between consecutive entries.
      // Left unchanged in case tests compare this output verbatim.
      absl::StrAppend(&s, "\n {number_attr=", input.number_attr->name,
                      ", default_length=", input.default_length,
                      ", tensor_list_params=[",
                      absl::StrJoin(input.tensor_list_params, ", "), "],\n");
    }
    absl::StrAppend(&s, "]\n");
  }
  if (!inferred_type_attrs_.empty()) {
    absl::StrAppend(&s, " inferred_type_attrs=[",
                    absl::StrJoin(inferred_type_attrs_, ", "), "]\n");
  }
  if (!inferred_type_list_attrs_.empty()) {
    absl::StrAppend(&s, " inferred_type_list_attrs=[",
                    absl::StrJoin(inferred_type_list_attrs_, ", "), "]\n");
  }
  if (!inferred_length_attrs_.empty()) {
    absl::StrAppend(&s, " inferred_length_attrs=[",
                    absl::StrJoin(inferred_length_attrs_, ", "), "]\n");
  }
  return s;
}
507
508} // namespace tensorflow
509