1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/python/framework/python_op_gen.h"
16
17#include <stdio.h>
18
19#include <sstream>
20#include <string>
21#include <unordered_map>
22
23#include "absl/strings/escaping.h"
24#include "tensorflow/core/framework/api_def.pb.h"
25#include "tensorflow/core/framework/attr_value.pb.h"
26#include "tensorflow/core/framework/op.h"
27#include "tensorflow/core/framework/op_def.pb.h"
28#include "tensorflow/core/framework/op_def_util.h"
29#include "tensorflow/core/framework/op_gen_lib.h"
30#include "tensorflow/core/framework/tensor.pb.h"
31#include "tensorflow/core/framework/types.h"
32#include "tensorflow/core/framework/types.pb.h"
33#include "tensorflow/core/lib/gtl/map_util.h"
34#include "tensorflow/core/lib/strings/str_util.h"
35#include "tensorflow/core/lib/strings/strcat.h"
36#include "tensorflow/core/lib/strings/stringprintf.h"
37#include "tensorflow/core/platform/logging.h"
38#include "tensorflow/core/platform/macros.h"
39#include "tensorflow/core/platform/types.h"
40#include "tensorflow/python/framework/python_op_gen_internal.h"
41
42namespace tensorflow {
43namespace {
44
45const int kRightMargin = 78;
46
47constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
48
49// Maps C++ dtype enum values to Python DType classes
50const std::unordered_map<string, string> dtype_type{
51 {"_dtypes.float16", "_dtypes.Float16"},
52 {"_dtypes.half", "_dtypes.Half"},
53 {"_dtypes.float32", "_dtypes.Float32"},
54 {"_dtypes.float64", "_dtypes.Float64"},
55 {"_dtypes.bfloat16", "_dtypes.BFloat16"},
56 {"_dtypes.complex64", "_dtypes.Complex64"},
57 {"_dtypes.complex128", "_dtypes.Complex128"},
58 {"_dtypes.int8", "_dtypes.Int8"},
59 {"_dtypes.uint8", "_dtypes.UInt8"},
60 {"_dtypes.uint16", "_dtypes.UInt16"},
61 {"_dtypes.uint32", "_dtypes.UInt32"},
62 {"_dtypes.uint64", "_dtypes.UInt64"},
63 {"_dtypes.int16", "_dtypes.Int16"},
64 {"_dtypes.int32", "_dtypes.Int32"},
65 {"_dtypes.int64", "_dtypes.Int64"},
66 {"_dtypes.bool", "_dtypes.Bool"},
67 {"_dtypes.string", "_dtypes.String"},
68 {"_dtypes.qint8", "_dtypes.QInt8"},
69 {"_dtypes.quint8", "_dtypes.QUInt8"},
70 {"_dtypes.qint16", "_dtypes.QInt16"},
71 {"_dtypes.quint16", "_dtypes.QUInt16"},
72 {"_dtypes.qint32", "_dtypes.QInt32"},
73 {"_dtypes.resource", "_dtypes.Resource"},
74 {"_dtypes.variant", "_dtypes.Variant"}};
75
76string AttrVarName(const string& attr_name,
77 std::unordered_map<string, string>* attr_expressions) {
78 const string var = strings::StrCat("_attr_", attr_name);
79 if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var;
80 return var;
81}
82
83void AddInferredAttr(const string& indentation, const string& attr_name,
84 const string& value_expression, string* result,
85 std::unordered_map<string, string>* attr_expressions) {
86 strings::StrAppend(result, indentation,
87 AttrVarName(attr_name, attr_expressions), " = ",
88 value_expression, "\n");
89}
90
91string VectorToTuple(const std::vector<string>& l) {
92 if (l.size() == 1) return strings::StrCat("(", l.front(), ",)");
93 string ret = "(";
94 for (int i = 0, end = l.size(); i < end; ++i) {
95 if (i > 0) {
96 strings::StrAppend(&ret, ", ");
97 }
98 strings::StrAppend(&ret, l[i]);
99 }
100 strings::StrAppend(&ret, ")");
101 return ret;
102}
103
104void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
105 const string& var, string* result) {
106 for (int i = 0, end = output_sizes.size(); i < end; ++i) {
107 if (!output_sizes[i].empty()) {
108 strings::StrAppend(result, prefix, var, " = ");
109 if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
110 if (i + 1 < end) {
111 // Special case i == 0 to avoid "0 +" in the generated code.
112 if (i == 0) {
113 strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
114 var, "[", output_sizes[i], ":]");
115 } else {
116 strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
117 output_sizes[i], "]] + ", var, "[", i, " + ",
118 output_sizes[i], ":]");
119 }
120 } else {
121 strings::StrAppend(result, "[", var, "[", i, ":]]");
122 }
123 strings::StrAppend(result, "\n");
124 }
125 }
126}
127
128string TensorPBString(const TensorProto& pb) {
129 // Explicitly not using ShortDebugString, because ShortDebugString should
130 // not be used as a format for transporting information (it's e.g. subject
131 // to redaction of sensitive information). There is a PrintShortTextProto
132 // helper, but it's not feasible to depend on that library).
133
134 std::string message_short_text;
135
136 ::tensorflow::protobuf::TextFormat::Printer printer;
137 printer.SetSingleLineMode(true);
138 printer.SetExpandAny(true);
139
140 printer.PrintToString(pb, &message_short_text);
141
142 // Note: This gets used in the argument list, and so must survive naive
143 // word wrapping.
144 return strings::StrCat("\"\"\"", message_short_text, "\"\"\"");
145}
146
147class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
148 public:
149 GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
150 const string& function_name, bool add_type_annotations)
151 : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
152 add_type_annotations) {
153 op_name_ = function_name_;
154 absl::ConsumePrefix(&op_name_, "_");
155 }
156 ~GenEagerPythonOp() override {}
157
158 string Code() override;
159
160 protected:
161 void HandleGraphMode(const string& function_setup,
162 const std::vector<string>& output_sizes);
163
164 string GetEagerNotAllowedError();
165 void ExpectListArg(const string& indentation, const string& arg_name,
166 string* output);
167 bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
168 void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
169 string* num_outputs_expr);
170
171 void AddEagerFunctionTeardown(const string& indentation,
172 const std::vector<string>& output_sizes,
173 bool execute_record_gradient);
174
175 bool AddEagerFastPathAndGraphCode(
176 const string& parameters, const std::vector<string>& output_sizes,
177 const string& eager_not_allowed_error,
178 const std::unordered_map<string, string>& type_annotations);
179 bool AddEagerFallbackCode(
180 const string& parameters, const std::vector<string>& output_sizes,
181 const string& num_outputs_expr, const string& eager_not_allowed_error,
182 const std::unordered_map<string, string>& type_annotations);
183 void AddEagerFastPathExecute();
184
185 void AddEagerInferredAttrs(const string& indentation);
186 void AddEagerInputCasts(const string& indentation);
187 void AddEagerAttrs(const string& indentation);
188 void AddEagerExecute(const string& indentation,
189 const string& num_outputs_expr);
190 void AddFallbackDispatch(const string& prefix);
191 void AddTypeBasedDispatch(const string& prefix);
192 void AddTypeBasedDispatcherAlias();
193
194 void AddRawOpExport(const string& parameters);
195
196 std::unordered_map<string, string> GetTypeAnnotations();
197
198 void GenerateTypeVars(
199 const std::unordered_map<string, string>& type_annotations);
200
201 void AddReturnTypeAnnotation(
202 const std::unordered_map<string, string>& type_annotations);
203
204 void AddAttrForArg(const string& attr, int arg_index) {
205 gtl::InsertIfNotPresent(&inferred_attrs_, attr,
206 op_def_.input_arg(arg_index).name());
207 auto iter = attr_to_args_.find(attr);
208 if (iter == attr_to_args_.end()) {
209 attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
210 } else {
211 iter->second.push_back(arg_index);
212 }
213 }
214
215 // Returns a string expression representing a flattened list of all
216 // the inputs given by `*input_indices` (or all inputs if
217 // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
218 string FlattenInputs(const std::vector<int>* input_indices,
219 std::vector<string>* output_sizes) const;
220
221 StringPiece op_name_;
222 typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
223 AttrToArgMap attr_to_args_;
224 std::unordered_map<string, string> attr_expressions_;
225 // This has all the input args followed by those attrs that don't have
226 // defaults.
227 std::vector<python_op_gen_internal::ParamNames> params_no_default_;
228 // The parameters with defaults (these have to be listed after those without).
229 // No input args are included, just attrs.
230 std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
231 params_with_default_;
232};
233
234string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
235 const string& function_name,
236 bool add_type_annotations) {
237 return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations)
238 .Code();
239}
240
241string GenEagerPythonOp::FlattenInputs(
242 const std::vector<int>* input_indices,
243 std::vector<string>* output_sizes) const {
244 string inputs;
245 enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
246 const int n = input_indices != nullptr ? input_indices->size()
247 : op_def_.input_arg_size();
248 for (int j = 0; j < n; ++j) {
249 const int i = input_indices ? (*input_indices)[j] : j;
250 const auto& arg(op_def_.input_arg(i));
251 const bool is_list =
252 !arg.type_list_attr().empty() || !arg.number_attr().empty();
253 if (is_list) {
254 if (inputs_state == WAS_SOLO_INPUT) {
255 strings::StrAppend(&inputs, "] + ");
256 } else if (inputs_state == WAS_LIST_INPUT) {
257 strings::StrAppend(&inputs, " + ");
258 }
259 strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
260 inputs_state = WAS_LIST_INPUT;
261 if (output_sizes != nullptr) {
262 if (!arg.number_attr().empty()) {
263 output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
264 } else {
265 output_sizes->emplace_back(
266 strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
267 }
268 }
269 } else {
270 if (inputs_state == WAS_SOLO_INPUT) {
271 strings::StrAppend(&inputs, ", ");
272 } else if (inputs_state == WAS_LIST_INPUT) {
273 strings::StrAppend(&inputs, " + [");
274 } else {
275 strings::StrAppend(&inputs, "[");
276 }
277 strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
278 inputs_state = WAS_SOLO_INPUT;
279 if (output_sizes != nullptr) output_sizes->emplace_back();
280 }
281 }
282 if (inputs_state == STARTING) return "[]";
283 if (inputs_state == WAS_SOLO_INPUT) {
284 strings::StrAppend(&inputs, "]");
285 }
286 return inputs;
287}
288
// Generates the Python source for this op.
// Returns "" when the ApiDef visibility is SKIP. Otherwise it:
//   1. collects the Python parameters (inputs, then attrs that cannot be
//      inferred from inputs, split by whether they have a default);
//   2. fills attrs_, param_names_ and attr_expressions_ for the helpers;
//   3. emits the eager fast path / graph-mode wrapper, then the eager
//      fallback function.
// If either emission step returns false, result_ is returned without prelude_.
// On success the return value is prelude_ + result_.
string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  // Every input, in ApiDef arg_order, becomes a required parameter. The attrs
  // it determines (type or type-list attr, and length attr) are recorded as
  // inferable from it.
  // NOTE(review): `i` is the position within arg_order, but AddAttrForArg()
  // uses it to index op_def_.input_arg(). The two only agree when arg_order
  // matches the OpDef input order; confirm reordered ops are handled.
  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  // Attrs that are not inferred become parameters as well. If the ApiDef
  // gives a default, a Python default expression is stored; otherwise the
  // attr is a required parameter.
  // Assumes api_def_.attr(i) describes the same attr as op_def_.attr(i).
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          // Tensor defaults are rebuilt from their text proto at call time.
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          // Same as above, applied to each element of a tuple of text protos.
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred), with
  // those that have defaults at the end. The attrs without defaults are the
  // entries of params_no_default_ after the inputs; the attrs in
  // params_with_default_ are appended after them.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
       i < end; ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  // param_names_ follows Python signature order: inputs, attrs without
  // defaults, attrs with defaults. Entry input_arg_size() + k therefore
  // corresponds to attrs_[k].
  // TODO(slebedev): call AvoidPythonReserved on each param?
  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  std::unordered_map<string, string> type_annotations;
  // Only populate map for allowlisted ops
  if (add_type_annotations_) {
    type_annotations = GetTypeAnnotations();
  }

  // Two renderings of the signature are built:
  //  - `parameters`: names (plus annotations) only; passed to
  //    AddEagerFallbackCode().
  //  - `parameters_with_defaults`: also carries "=<default>"; passed to
  //    AddEagerFastPathAndGraphCode().
  string parameters;
  // Param can be an input or an attr
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());

    if (type_annotations.find(param.GetName()) != type_annotations.end()) {
      strings::StrAppend(&parameters, ": ",
                         type_annotations.at(param.GetName()));
    }
  }

  string parameters_with_defaults = parameters;
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    if (!parameters_with_defaults.empty())
      strings::StrAppend(&parameters_with_defaults, ", ");

    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
    strings::StrAppend(&parameters_with_defaults,
                       param_and_default.first.GetRenameTo());
    if (type_annotations.find(param_and_default.first.GetName()) !=
        type_annotations.end()) {
      const string param_type =
          type_annotations.at(param_and_default.first.GetName());
      // Append to parameters and parameters_with_defaults because multiple
      // functions are generated by AddEagerFastPathAndGraphCode() and
      // AddEagerFallbackCode()
      // NOTE(review): ": " is used for `parameters` but ":" (no space) for
      // `parameters_with_defaults`, giving "x:int=1"; PEP 8 prefers
      // "x: int = 1". Kept as-is because generated output may be compared
      // against golden files.
      strings::StrAppend(&parameters, ": ", param_type);
      strings::StrAppend(&parameters_with_defaults, ":", param_type);
    }

    strings::StrAppend(&parameters_with_defaults, "=",
                       param_and_default.second);
  }

  // Every generated function also takes a trailing `name` argument.
  strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name");
  strings::StrAppend(&parameters_with_defaults,
                     parameters_with_defaults.empty() ? "" : ", ", "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  // Only int attrs (list lengths) get an "_attr_<name>" variable here.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  // Empty unless some input or output is a ref.
  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
                                    eager_not_allowed_error,
                                    type_annotations)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error, type_annotations)) {
    return result_;
  }

  return prelude_ + result_;
}
452
453std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() {
454 std::unordered_map<string, string> type_annotations;
455 // Map attrs to TypeVars
456 for (const auto& attr : op_def_.attr()) {
457 if (attr.type() == "type") {
458 const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name();
459 type_annotations[attr.name()] = type_var_name;
460 } else if (attr.type() == "bool" || attr.type() == "float" ||
461 attr.type() == "int" || attr.type() == "bytes") {
462 type_annotations[attr.name()] = attr.type();
463 } else if (attr.type() == "string") {
464 type_annotations[attr.name()] = "str";
465 }
466 }
467
468 // Map input Tensors to their types
469 for (const auto& arg : op_def_.input_arg()) {
470 // TODO(rahulkamat): Add type annotations to args that accept a sequence of
471 // Tensors
472 if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue;
473 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
474 }
475
476 // TODO(rahulkamat): Add type annotations to handle return types of a sequence
477 // of Tensors. Map output Tensor to its type
478 if (op_def_.output_arg_size() == 1) {
479 const auto& arg = op_def_.output_arg(0);
480 if (arg.number_attr().empty() && arg.type_list_attr().empty())
481 type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations);
482 }
483
484 return type_annotations;
485}
486
487// Generate TypeVars using attrs
488void GenEagerPythonOp::GenerateTypeVars(
489 const std::unordered_map<string, string>& type_annotations) {
490 bool added_typevar = false;
491 for (const auto& attr : op_def_.attr()) {
492 if (attr.type() == "type") {
493 std::vector<string> allowed_types;
494 for (int t : attr.allowed_values().list().type()) {
495 DataType dtype = static_cast<DataType>(t);
496 const string py_dtype =
497 python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
498 allowed_types.emplace_back(dtype_type.at(py_dtype));
499 }
500
501 // When a Tensor does not have any dtypes specified, all dtypes are
502 // allowed
503 if (allowed_types.empty()) {
504 for (std::pair<string, string> map_dtype : dtype_type) {
505 allowed_types.emplace_back(map_dtype.second);
506 }
507 }
508
509 std::sort(allowed_types.begin(), allowed_types.end());
510
511 string typevar_dtypes;
512 for (std::vector<string>::iterator it = allowed_types.begin();
513 it != allowed_types.end(); ++it) {
514 if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", ");
515 strings::StrAppend(&typevar_dtypes, *it);
516 }
517
518 const string type_var_name = type_annotations.at(attr.name());
519 strings::StrAppend(&result_, type_var_name, " = TypeVar(\"",
520 type_var_name, "\", ", typevar_dtypes, ")\n");
521 added_typevar = true;
522 }
523 }
524
525 if (added_typevar) strings::StrAppend(&result_, "\n");
526}
527
528void GenEagerPythonOp::AddReturnTypeAnnotation(
529 const std::unordered_map<string, string>& type_annotations) {
530 if (op_def_.output_arg_size() == 1) {
531 const auto& arg = op_def_.output_arg(0);
532 if (arg.number_attr().empty() && arg.type_list_attr().empty()) {
533 const string return_type = type_annotations.at(arg.name());
534 // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to
535 // avoid erasing ":\n" from the end of the def line
536 result_.erase(result_.length() - 2);
537 strings::StrAppend(&result_, " -> ", return_type, ":\n");
538 }
539 }
540}
541
// Appends the graph-mode part of the wrapper to result_:
//  - for VISIBLE ops, opens an "else:" branch and adds type-based dispatch;
//  - emits `function_setup`, then a call to
//    _op_def_library._apply_op_helper (wrapped in "try:" for VISIBLE ops,
//    followed by the fallback dispatch);
//  - if the op has outputs, records the gradient when
//    _execute.must_record_gradient() is true, reshapes _result (destructures
//    a single output, or regroups list outputs via Unflatten() into the
//    "_<Op>Output" named tuple) and returns it; otherwise returns _op.
// `output_sizes` holds one length expression per output ("" for single
// tensors), as produced by GetOutputSizesAndNumOutputsExpr().
void GenEagerPythonOp::HandleGraphMode(
    const string& function_setup, const std::vector<string>& output_sizes) {
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, " else:\n");
    AddTypeBasedDispatch(" ");
  }
  strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, " try:\n ");
  }
  strings::StrAppend(
      &result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
  AddFallbackDispatch(" ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, " _result = _outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         " if not _result:\n"
                         " return _op\n");
    }

    // Compute graph-mode attrs when we need to record a gradient.
    // The attrs are read back from the created op as a flat
    // ("name", value, ...) tuple, using the typed getter where one exists.
    strings::StrAppend(&result_, " if _execute.must_record_gradient():\n");
    if (op_def_.attr_size() > 0) {
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        if (op_def_.attr(i).type() == "type") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_type(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "bool") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_bool(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "int") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_int(\"", attr_name, "\")");
        } else {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op.get_attr(\"", attr_name, "\")");
        }
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap(" _attrs = (", attr_values, kRightMargin),
                         "\n");

    } else {
      strings::StrAppend(&result_, " _attrs = ()\n");
    }

    strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
    strings::StrAppend(&result_, " _execute.record_gradient(\n",
                       " \"", op_def_.name(),
                       "\", _inputs_flat, _attrs, _result)\n");

    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, " ", "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(" ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, " _result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
    strings::StrAppend(&result_, " return _result\n\n");
  } else {
    // Ops without outputs return the created op itself.
    strings::StrAppend(&result_, " return _op\n");
  }
}
630
631string GenEagerPythonOp::GetEagerNotAllowedError() {
632 bool eager_allowed = true;
633 string ref_arg;
634 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
635 const auto& arg = op_def_.input_arg(i);
636 if (arg.is_ref()) {
637 eager_allowed = false;
638 DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name());
639 ref_arg = api_def_.in_arg(i).rename_to();
640 }
641 }
642 for (int i = 0; i < op_def_.output_arg_size(); ++i) {
643 const auto& arg = op_def_.output_arg(i);
644 if (arg.is_ref()) {
645 eager_allowed = false;
646 DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name());
647 ref_arg = api_def_.out_arg(i).rename_to();
648 }
649 }
650
651 if (eager_allowed) return "";
652
653 return strings::StrCat("raise RuntimeError(\"", op_name_,
654 " op does not support eager execution. ", "Arg '",
655 ref_arg, "' is a ref.\")\n");
656}
657
658void GenEagerPythonOp::ExpectListArg(const string& indentation,
659 const string& arg_name, string* output) {
660 strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
661 ", (list, tuple)):\n", indentation, " raise TypeError(\n",
662 indentation, " \"Expected list for '", arg_name,
663 "' argument to \"\n", indentation, " \"'", op_name_,
664 "' Op, not %r.\" % ", arg_name, ")\n");
665}
666
667bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation,
668 string* function_setup) {
669 // Validate list inputs, infer length attrs.
670 for (int i = 0; i < op_def_.attr_size(); ++i) {
671 const auto& attr(op_def_.attr(i));
672 if (attr.type() == "int") {
673 auto arg_list = attr_to_args_.find(attr.name());
674 if (arg_list != attr_to_args_.end()) {
675 // Inferred int attrs are the lengths of inputs. Validate those
676 // inputs are lists and have the same length.
677 for (auto iter = arg_list->second.begin();
678 iter != arg_list->second.end(); ++iter) {
679 const string& arg_api_name = param_names_[*iter].GetRenameTo();
680 ExpectListArg(indentation, arg_api_name, function_setup);
681 if (iter == arg_list->second.begin()) {
682 AddInferredAttr(indentation, attr.name(),
683 strings::StrCat("len(", arg_api_name, ")"),
684 function_setup, &attr_expressions_);
685 } else {
686 const auto& attr_var = attr_expressions_[attr.name()];
687 strings::StrAppend(
688 function_setup, indentation, "if len(", arg_api_name,
689 ") != ", attr_var, ":\n", indentation, " raise ValueError(\n",
690 indentation, " \"List argument '", arg_api_name, "' to '",
691 op_name_, "' Op with length %d \"\n", indentation,
692 " \"must match length %d of argument '",
693 inferred_attrs_[attr.name()], "'.\" %\n", indentation,
694 " (len(", arg_api_name, "), ", attr_var, "))\n");
695 }
696 }
697 }
698 }
699 }
700
701 for (int i = 0, end = attrs_.size(); i < end; ++i) {
702 const string& attr_name = attrs_[i];
703 const auto& param = param_names_[i + op_def_.input_arg_size()];
704 const auto& attr = *FindAttr(attr_name, op_def_);
705 const string& attr_api_name = param.GetRenameTo();
706 StringPiece attr_type = attr.type();
707 attr_expressions_[attr_name] = attr_api_name;
708 const int default_index = i - (attrs_.size() - params_with_default_.size());
709 if (default_index >= 0) {
710 const string& default_value = params_with_default_[default_index].second;
711 strings::StrAppend(function_setup, indentation, "if ", attr_api_name,
712 " is None:\n");
713 strings::StrAppend(function_setup, indentation, " ", attr_api_name,
714 " = ", default_value, "\n");
715 }
716 if (absl::StartsWith(attr_type, "list(")) {
717 ExpectListArg(indentation, attr_api_name, function_setup);
718 }
719
720 if (attr_type == "string") {
721 strings::StrAppend(function_setup, indentation, attr_api_name,
722 " = _execute.make_str(", attr_api_name, ", \"",
723 attr_api_name, "\")\n");
724 } else if (attr_type == "list(string)") {
725 strings::StrAppend(function_setup, indentation, attr_api_name,
726 " = [_execute.make_str(_s, \"", attr_api_name,
727 "\") for _s in ", attr_api_name, "]\n");
728 } else if (attr_type == "int") {
729 strings::StrAppend(function_setup, indentation, attr_api_name,
730 " = _execute.make_int(", attr_api_name, ", \"",
731 attr_api_name, "\")\n");
732 } else if (attr_type == "list(int)") {
733 strings::StrAppend(function_setup, indentation, attr_api_name,
734 " = [_execute.make_int(_i, \"", attr_api_name,
735 "\") for _i in ", attr_api_name, "]\n");
736 } else if (attr_type == "float") {
737 strings::StrAppend(function_setup, indentation, attr_api_name,
738 " = _execute.make_float(", attr_api_name, ", \"",
739 attr_api_name, "\")\n");
740 } else if (attr_type == "list(float)") {
741 strings::StrAppend(function_setup, indentation, attr_api_name,
742 " = [_execute.make_float(_f, \"", attr_api_name,
743 "\") for _f in ", attr_api_name, "]\n");
744 } else if (attr_type == "bool") {
745 strings::StrAppend(function_setup, indentation, attr_api_name,
746 " = _execute.make_bool(", attr_api_name, ", \"",
747 attr_api_name, "\")\n");
748 } else if (attr_type == "list(bool)") {
749 strings::StrAppend(function_setup, indentation, attr_api_name,
750 " = [_execute.make_bool(_b, \"", attr_api_name,
751 "\") for _b in ", attr_api_name, "]\n");
752 } else if (attr_type == "type") {
753 strings::StrAppend(function_setup, indentation, attr_api_name,
754 " = _execute.make_type(", attr_api_name, ", \"",
755 attr_api_name, "\")\n");
756 } else if (attr_type == "list(type)") {
757 strings::StrAppend(function_setup, indentation, attr_api_name,
758 " = [_execute.make_type(_t, \"", attr_api_name,
759 "\") for _t in ", attr_api_name, "]\n");
760 } else if (attr_type == "shape") {
761 strings::StrAppend(function_setup, indentation, attr_api_name,
762 " = _execute.make_shape(", attr_api_name, ", \"",
763 attr_api_name, "\")\n");
764 } else if (attr_type == "list(shape)") {
765 strings::StrAppend(function_setup, indentation, attr_api_name,
766 " = [_execute.make_shape(_s, \"", attr_api_name,
767 "\") for _s in ", attr_api_name, "]\n");
768 } else if (attr_type == "tensor") {
769 strings::StrAppend(function_setup, indentation, attr_api_name,
770 " = _execute.make_tensor(", attr_api_name, ", \"",
771 attr_api_name, "\")\n");
772 } else if (attr_type == "list(tensor)") {
773 strings::StrAppend(function_setup, indentation, attr_api_name,
774 " = [_execute.make_tensor(_t, \"", attr_api_name,
775 "\") for _t in ", attr_api_name, "]\n");
776 } else if (attr_type != "func" && attr_type != "list(func)") {
777 *function_setup =
778 strings::StrCat("# No definition for ", function_name_,
779 " since we don't support attrs with type\n"
780 "# '",
781 attr_type, "' right now.\n\n");
782 return false;
783 }
784 }
785 return true;
786}
787
788// If output i is list output, output_sizes[i] will be set to a
789// string with the python expression that will evaluate to its
790// length. output_sizes[i] is empty for non-list outputs.
791void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr(
792 std::vector<string>* output_sizes, string* num_outputs_expr) {
793 // Expression representing the number of outputs.
794 int num_fixed_outputs = 0;
795 for (int i = 0; i < num_outs_; ++i) {
796 const auto& arg(op_def_.output_arg(i));
797 if (!arg.number_attr().empty()) {
798 if (!num_outputs_expr->empty()) {
799 strings::StrAppend(num_outputs_expr, " + ");
800 }
801 (*output_sizes)[i] = attr_expressions_[arg.number_attr()];
802 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
803 } else if (!arg.type_list_attr().empty()) {
804 if (!num_outputs_expr->empty()) {
805 strings::StrAppend(num_outputs_expr, " + ");
806 }
807 // Have to be careful to use an expression that works in both
808 // graph and eager paths here.
809 const auto iter = inferred_attrs_.find(arg.type_list_attr());
810 if (iter == inferred_attrs_.end()) {
811 (*output_sizes)[i] = strings::StrCat(
812 "len(", attr_expressions_[arg.type_list_attr()], ")");
813 } else {
814 (*output_sizes)[i] = strings::StrCat("len(", iter->second, ")");
815 }
816 strings::StrAppend(num_outputs_expr, (*output_sizes)[i]);
817 } else {
818 ++num_fixed_outputs;
819 }
820 }
821 if (num_fixed_outputs > 0) {
822 if (!num_outputs_expr->empty()) {
823 strings::StrAppend(num_outputs_expr, " + ");
824 }
825 strings::StrAppend(num_outputs_expr, num_fixed_outputs);
826 } else if (num_outputs_expr->empty()) {
827 *num_outputs_expr = "0";
828 }
829}
830
831void GenEagerPythonOp::AddEagerFunctionTeardown(
832 const string& indentation, const std::vector<string>& output_sizes,
833 bool execute_record_gradient) {
834 if (num_outs_ > 0) {
835 if (execute_record_gradient) {
836 strings::StrAppend(&result_, indentation,
837 "if _execute.must_record_gradient():\n");
838 strings::StrAppend(&result_, indentation, " _execute.record_gradient(\n",
839 " \"", op_def_.name(),
840 "\", _inputs_flat, _attrs, _result)\n");
841 }
842 if (num_outs_ == 1 && !output_sizes[0].empty()) {
843 // Single list result.
844 } else if (num_outs_ == 1) {
845 // Execute returns a single-element list which we need to destructure.
846 strings::StrAppend(&result_, indentation, "_result, = _result\n");
847 } else {
848 // Have multiple outputs, so we will need to reformat the return
849 // value of execute() to be a list with one entry per op output
850 // (that entry will be a list of tensors if that output is of list
851 // type).
852 // For list outputs, convert the right subrange of _result into a list.
853 Unflatten(indentation, output_sizes, "_result", &result_);
854 // Convert to a named tuple.
855 strings::StrAppend(
856 &result_, indentation, "_result = _",
857 python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
858 "Output._make(_result)\n");
859 }
860 } else {
861 strings::StrAppend(&result_, indentation, "_result = None\n");
862 }
863 strings::StrAppend(&result_, indentation, "return _result\n\n");
864}
865
866bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
867 const string& parameters, const std::vector<string>& output_sizes,
868 const string& eager_not_allowed_error,
869 const std::unordered_map<string, string>& type_annotations) {
870 if (add_type_annotations_) {
871 GenerateTypeVars(type_annotations);
872 }
873 if (api_def_.visibility() == ApiDef::VISIBLE) {
874 strings::StrAppend(&result_, "@_dispatch.add_fallback_dispatch_list\n");
875 strings::StrAppend(&result_, "@_dispatch.add_type_based_api_dispatcher\n");
876 }
877
878 AddExport();
879 AddDefLine(function_name_, parameters);
880 if (add_type_annotations_) {
881 AddReturnTypeAnnotation(type_annotations);
882 }
883 AddDocStringDescription();
884 AddDocStringArgs();
885 AddDocStringInputs();
886 AddDocStringAttrs();
887 AddDocStringNameArg();
888 AddOutputGlobals(); // Added to prelude_
889 AddDocStringOutputs();
890 strings::StrAppend(&result_, " \"\"\"\n");
891
892 strings::StrAppend(&result_,
893 " _ctx = _context._context or _context.context()\n"
894 " tld = _ctx._thread_local_data\n",
895 " if tld.is_eager:", "\n");
896 if (eager_not_allowed_error.empty()) {
897 AddEagerFastPathExecute();
898 } else {
899 strings::StrAppend(&result_, " ", eager_not_allowed_error);
900 }
901
902 // Handle graph-mode case
903 string function_setup;
904 if (!GetEagerFunctionSetup(" ", &function_setup)) {
905 result_ = function_setup;
906 return false;
907 }
908 HandleGraphMode(function_setup, output_sizes);
909
910 AddRawOpExport(parameters);
911 AddTypeBasedDispatcherAlias();
912 strings::StrAppend(&result_, "\n\n");
913 return true;
914}
915
916bool GenEagerPythonOp::AddEagerFallbackCode(
917 const string& parameters, const std::vector<string>& output_sizes,
918 const string& num_outputs_expr, const string& eager_not_allowed_error,
919 const std::unordered_map<string, string>& type_annotations) {
920 AddDefLine(
921 strings::StrCat(function_name_, kEagerFallbackSuffix),
922 strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
923 if (add_type_annotations_) {
924 AddReturnTypeAnnotation(type_annotations);
925 }
926 if (!eager_not_allowed_error.empty()) {
927 strings::StrAppend(&result_, " ", eager_not_allowed_error);
928 return true;
929 }
930
931 string function_setup;
932 if (!GetEagerFunctionSetup(" ", &function_setup)) {
933 result_ = function_setup;
934 return false;
935 }
936 strings::StrAppend(&result_, function_setup);
937
938 AddEagerInferredAttrs(" ");
939 AddEagerInputCasts(" ");
940 strings::StrAppend(
941 &result_, " _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
942 AddEagerAttrs(" ");
943 AddEagerExecute(" ", num_outputs_expr);
944
945 AddEagerFunctionTeardown(" ", output_sizes,
946 true /* execute_record_gradient */);
947
948 return true;
949}
950
951void GenEagerPythonOp::AddEagerFastPathExecute() {
952 string fastpath_execute_params =
953 strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
954 string fallback_params;
955
956 for (int i = 0; i < api_def_.in_arg_size(); i++) {
957 const string param_name = param_names_[i].GetRenameTo();
958 strings::StrAppend(&fastpath_execute_params, ", ", param_name);
959 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
960 strings::StrAppend(&fallback_params, param_name);
961 }
962
963 for (const auto& attr : api_def_.attr()) {
964 if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
965 strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
966 attr.rename_to());
967
968 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
969 strings::StrAppend(&fallback_params, attr.rename_to(), "=",
970 attr.rename_to());
971 }
972 }
973
974 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
975 strings::StrAppend(&fallback_params, "name=name");
976
977 strings::StrAppend(&result_, " try:\n");
978 strings::StrAppend(
979 &result_, " ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
980 WordWrap(strings::StrCat(" "),
981 strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
982 "\n");
983
984 if (op_def_.output_arg_size() > 1) {
985 const string output_tuple_name = strings::StrCat(
986 "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
987 "Output");
988 strings::StrAppend(&result_, " ", "_result = ", output_tuple_name,
989 "._make(_result)\n");
990 }
991 strings::StrAppend(&result_, " ", "return _result\n");
992
993 // Handle fallback.
994 if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
995 strings::StrAppend(&fallback_params, "ctx=_ctx");
996
997 // Any errors thrown from execute need to be unwrapped from
998 // _NotOkStatusException.
999 strings::StrAppend(&result_, " ",
1000 "except _core._NotOkStatusException as e:\n");
1001 strings::StrAppend(&result_, " ",
1002 "_ops.raise_from_not_ok_status(e, name)\n");
1003
1004 strings::StrAppend(&result_, " ", "except _core._FallbackException:\n");
1005 strings::StrAppend(&result_, " pass\n");
1006 strings::StrAppend(&result_, " try:\n");
1007 AddTypeBasedDispatch(" ");
1008 strings::StrAppend(
1009 &result_, " ", "return ", function_name_, kEagerFallbackSuffix,
1010 "(\n",
1011 WordWrap(strings::StrCat(" "),
1012 strings::StrCat(fallback_params, ")"), kRightMargin),
1013 "\n");
1014 strings::StrAppend(&result_, " except _core._SymbolicException:\n");
1015 strings::StrAppend(&result_,
1016 " pass # Add nodes to the TensorFlow graph.\n");
1017 AddFallbackDispatch(" ");
1018}
1019
// Emits Python that converts attr-typed inputs to eager tensors and, in the
// same statement, binds the value of the attr that determines their dtype.
//
// Only attrs found in `attr_to_args_` (attr name -> indices of the input args
// that reference it) are handled, and only those of type "type" or
// "list(type)". All other attrs are skipped here. Each generated line is
// appended to result_ prefixed with `indentation`. The attr value is bound to
// the Python variable name returned by AttrVarName().
void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    // NOTE(review): the ApiDef attr is looked up with the OpDef attr index,
    // which assumes both attr lists are in the same order -- confirm.
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        // All inputs that share this dtype attr are flattened into one
        // Python expression, so a single args_to_matching_eager call both
        // converts them and determines the common dtype.
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", ctx");

        // Allowed dtypes of the attr, emitted as a Python list literal
        // (the trailing ", " before "]" is valid Python).
        strings::StrAppend(&conversion, ", [");
        for (int t : attr.allowed_values().list().type()) {
          DataType dtype = static_cast<DataType>(t);
          const string py_dtype =
              python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
          strings::StrAppend(&conversion, py_dtype, ", ");
        }
        strings::StrAppend(&conversion, "]");

        // NOTE(review): whether a default exists is read from the OpDef attr,
        // but the default value itself is taken from the ApiDef attr.
        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            // Empty size: the converted value is unpacked as a one-element
            // sequence `(x,)` (presumably a single non-list input).
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            // Non-empty size: the converted list is assigned back directly.
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          // Several inputs share the attr: convert into a temporary
          // `_inputs_<attr>` list first.
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        // Emits `<attr_var>, <inputs> = <conversion>(<inputs>, ctx)`.
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", ctx)\n");
      }
    }
  }
}
1107
1108void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) {
1109 // Cast remaining args to eager tensors
1110 for (int i = 0; i < op_def_.input_arg_size(); ++i) {
1111 const auto& arg(op_def_.input_arg(i));
1112 if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue;
1113 const string& param = param_names_[i].GetRenameTo();
1114 const string fn = arg.number_attr().empty() ? "" : "n_";
1115 const string dtype =
1116 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1117 strings::StrAppend(&result_, indentation, param, " = _ops.convert_", fn,
1118 "to_tensor(", param, ", ", dtype, ")\n");
1119 }
1120}
1121
1122void GenEagerPythonOp::AddEagerAttrs(const string& indentation) {
1123 // Compute eager attrs
1124 if (op_def_.attr_size() > 0) {
1125 string attr_values;
1126 for (int i = 0; i < op_def_.attr_size(); ++i) {
1127 if (i > 0) strings::StrAppend(&attr_values, ", ");
1128 const auto& attr_name(op_def_.attr(i).name());
1129 strings::StrAppend(&attr_values, "\"", attr_name, "\", ",
1130 attr_expressions_[attr_name]);
1131 }
1132 strings::StrAppend(&attr_values, ")");
1133 strings::StrAppend(
1134 &result_,
1135 WordWrap(indentation, strings::StrCat("_attrs = (", attr_values),
1136 kRightMargin),
1137 "\n");
1138 } else {
1139 strings::StrAppend(&result_, indentation, "_attrs = None\n");
1140 }
1141}
1142
1143void GenEagerPythonOp::AddEagerExecute(const string& indentation,
1144 const string& num_outputs_expr) {
1145 const string return_prefix =
1146 strings::StrCat(indentation, "_result = _execute.execute(");
1147 const string return_args = strings::StrCat(
1148 "b\"", op_def_.name(), "\", ", num_outputs_expr,
1149 ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)");
1150 strings::StrAppend(&result_,
1151 // Wrap the arguments, and indent to the (.
1152 WordWrap(return_prefix, return_args, kRightMargin), "\n");
1153}
1154
1155void GenEagerPythonOp::AddFallbackDispatch(const string& prefix) {
1156 if (api_def_.visibility() != ApiDef::VISIBLE) return;
1157
1158 strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
1159 strings::StrAppend(&result_, prefix, " _result = _dispatch.dispatch(\n");
1160 AddBodyNoReturn(strings::StrCat(prefix, " ", function_name_,
1161 ", "
1162 "(), dict("));
1163 strings::StrAppend(&result_, prefix, " )\n");
1164 strings::StrAppend(&result_, prefix,
1165 " if _result is not "
1166 "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
1167 strings::StrAppend(&result_, prefix, " return _result\n");
1168 strings::StrAppend(&result_, prefix, " raise\n");
1169}
1170
1171void GenEagerPythonOp::AddTypeBasedDispatcherAlias() {
1172 // It's possible for the name of a parameter to be the same as the name of
1173 // an op, in which case the parameter shadows the op's function. To avoid
1174 // this, we add a private variable with the dispatcher, and access that
1175 // directly.
1176 if (api_def_.visibility() == ApiDef::VISIBLE) {
1177 strings::StrAppend(&result_, "_dispatcher_for_", function_name_, " = ",
1178 function_name_, "._tf_type_based_dispatcher.Dispatch\n");
1179 }
1180}
1181void GenEagerPythonOp::AddTypeBasedDispatch(const string& prefix) {
1182 if (api_def_.visibility() != ApiDef::VISIBLE) return;
1183 std::string args("(");
1184 for (const auto& name : param_names_) {
1185 strings::StrAppend(&args, name.GetRenameTo(), ", ");
1186 }
1187 strings::StrAppend(&args, "name,), None");
1188
1189 strings::StrAppend(
1190 &result_, prefix, "_result = ", "_dispatcher_for_", function_name_, "(\n",
1191 WordWrap(strings::StrCat(prefix, " "), args, kRightMargin), ")\n");
1192 strings::StrAppend(&result_, prefix, "if _result is not NotImplemented:\n",
1193 prefix, " return _result\n");
1194}
1195
1196void GenEagerPythonOp::AddRawOpExport(const string& parameters) {
1197 // Example:
1198 //
1199 // Identity = tf_export("raw_ops.Identity")(_ops._to_raw_op(identity))
1200 const string raw_function_name =
1201 python_op_gen_internal::AvoidPythonReserved(op_def_.name());
1202 strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops.",
1203 raw_function_name, "\")", "(_ops.to_raw_op(",
1204 function_name_, "))\n");
1205}
1206
1207string GetPythonOpsImpl(
1208 const OpList& ops, const ApiDefMap& api_defs,
1209 const std::vector<string>& hidden_ops, const string& source_file_name = "",
1210 const std::unordered_set<string> type_annotate_ops = {}) {
1211 string result;
1212 // Header
1213 // TODO(josh11b): Mention the library for which wrappers are being generated.
1214 strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.
1215
1216This file is MACHINE GENERATED! Do not edit.
1217)");
1218
1219 // Mention the original source file so someone tracing back through
1220 // generated Python code will know where to look next.
1221 if (!source_file_name.empty()) {
1222 strings::StrAppend(&result, "Original C++ source file: ");
1223 strings::StrAppend(&result, source_file_name);
1224 strings::StrAppend(&result, "\n");
1225 }
1226
1227 strings::StrAppend(&result, R"("""
1228
1229import collections
1230
1231from tensorflow.python import pywrap_tfe as pywrap_tfe
1232from tensorflow.python.eager import context as _context
1233from tensorflow.python.eager import core as _core
1234from tensorflow.python.eager import execute as _execute
1235from tensorflow.python.framework import dtypes as _dtypes
1236
1237from tensorflow.python.framework import op_def_registry as _op_def_registry
1238from tensorflow.python.framework import ops as _ops
1239from tensorflow.python.framework import op_def_library as _op_def_library
1240from tensorflow.python.util.deprecation import deprecated_endpoints
1241from tensorflow.python.util import dispatch as _dispatch
1242from tensorflow.python.util.tf_export import tf_export
1243
1244from typing import TypeVar
1245)");
1246
1247 for (const auto& op_def : ops.op()) {
1248 const auto* api_def = api_defs.GetApiDef(op_def.name());
1249
1250 if (api_def->visibility() == ApiDef::SKIP) {
1251 continue;
1252 }
1253 // An op is hidden if either its ApiDef visibility is HIDDEN
1254 // or it is in the hidden_ops list.
1255 bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
1256 bool hidden_by_api_def = is_hidden;
1257 if (!is_hidden) {
1258 for (const string& hidden : hidden_ops) {
1259 if (op_def.name() == hidden) {
1260 is_hidden = true;
1261 break;
1262 }
1263 }
1264 }
1265
1266 string function_name;
1267 python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
1268 &function_name);
1269 bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);
1270
1271 // Prefix an op with underscore if the op is listed in hidden_ops or
1272 // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
1273 // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
1274 // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
1275 if (is_hidden) {
1276 if (!hidden_by_api_def || is_reserved ||
1277 python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
1278 function_name = strings::StrCat("_", function_name);
1279 }
1280 } else if (is_reserved) {
1281 // When users create custom python wrappers, they may link in the
1282 // default op registry by accident, and because they can't
1283 // enumerate all 'hidden' symbols, this guard is to prevent
1284 // instantiating a python reserved word in their wrapper.
1285 continue;
1286 }
1287
1288 auto iter = type_annotate_ops.find(op_def.name());
1289 bool add_type_annotations = iter != type_annotate_ops.end();
1290
1291 strings::StrAppend(&result,
1292 GetEagerPythonOp(op_def, *api_def, function_name,
1293 add_type_annotations));
1294 }
1295
1296 return result;
1297}
1298
1299} // namespace
1300
1301string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1302 const std::vector<string>& hidden_ops,
1303 const string& source_file_name,
1304 const std::unordered_set<string> type_annotate_ops) {
1305 return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
1306 type_annotate_ops);
1307}
1308
1309void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs,
1310 const std::vector<string>& hidden_ops,
1311 const string& source_file_name,
1312 const std::unordered_set<string> type_annotate_ops) {
1313 printf("%s", GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
1314 type_annotate_ops)
1315 .c_str());
1316}
1317
1318string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
1319 OpList ops;
1320 ops.ParseFromArray(op_list_buf, op_list_len);
1321
1322 ApiDefMap api_def_map(ops);
1323 return GetPythonOpsImpl(ops, api_def_map, {});
1324}
1325
1326string GetArgAnnotation(
1327 const OpDef::ArgDef& arg,
1328 const std::unordered_map<string, string>& type_annotations) {
1329 if (!arg.type_attr().empty()) {
1330 // Get the correct TypeVar if arg maps to an attr
1331 return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]";
1332 } else {
1333 // Get the dtype of the Tensor
1334 const string py_dtype =
1335 python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes.");
1336 return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]";
1337 }
1338
1339 return "Any";
1340}
1341
1342} // namespace tensorflow
1343