1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/python/framework/python_op_gen.h" |
16 | |
17 | #include <stdio.h> |
18 | |
19 | #include <sstream> |
20 | #include <string> |
21 | #include <unordered_map> |
22 | |
23 | #include "absl/strings/escaping.h" |
24 | #include "tensorflow/core/framework/api_def.pb.h" |
25 | #include "tensorflow/core/framework/attr_value.pb.h" |
26 | #include "tensorflow/core/framework/op.h" |
27 | #include "tensorflow/core/framework/op_def.pb.h" |
28 | #include "tensorflow/core/framework/op_def_util.h" |
29 | #include "tensorflow/core/framework/op_gen_lib.h" |
30 | #include "tensorflow/core/framework/tensor.pb.h" |
31 | #include "tensorflow/core/framework/types.h" |
32 | #include "tensorflow/core/framework/types.pb.h" |
33 | #include "tensorflow/core/lib/gtl/map_util.h" |
34 | #include "tensorflow/core/lib/strings/str_util.h" |
35 | #include "tensorflow/core/lib/strings/strcat.h" |
36 | #include "tensorflow/core/lib/strings/stringprintf.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | #include "tensorflow/core/platform/macros.h" |
39 | #include "tensorflow/core/platform/types.h" |
40 | #include "tensorflow/python/framework/python_op_gen_internal.h" |
41 | |
42 | namespace tensorflow { |
43 | namespace { |
44 | |
// Column at which generated Python source lines are word-wrapped.
const int kRightMargin = 78;

// Suffix appended to an op's function name for its eager-fallback variant.
constexpr char kEagerFallbackSuffix[] = "_eager_fallback";
48 | |
// Maps Python dtype instance expressions (e.g. "_dtypes.float32") to the
// corresponding Python DType class names (e.g. "_dtypes.Float32"). Used when
// emitting type annotations (TypeVar bounds) in generated op wrappers.
const std::unordered_map<string, string> dtype_type{
    {"_dtypes.float16", "_dtypes.Float16"},
    {"_dtypes.half", "_dtypes.Half"},
    {"_dtypes.float32", "_dtypes.Float32"},
    {"_dtypes.float64", "_dtypes.Float64"},
    {"_dtypes.bfloat16", "_dtypes.BFloat16"},
    {"_dtypes.complex64", "_dtypes.Complex64"},
    {"_dtypes.complex128", "_dtypes.Complex128"},
    {"_dtypes.int8", "_dtypes.Int8"},
    {"_dtypes.uint8", "_dtypes.UInt8"},
    {"_dtypes.uint16", "_dtypes.UInt16"},
    {"_dtypes.uint32", "_dtypes.UInt32"},
    {"_dtypes.uint64", "_dtypes.UInt64"},
    {"_dtypes.int16", "_dtypes.Int16"},
    {"_dtypes.int32", "_dtypes.Int32"},
    {"_dtypes.int64", "_dtypes.Int64"},
    {"_dtypes.bool", "_dtypes.Bool"},
    {"_dtypes.string", "_dtypes.String"},
    {"_dtypes.qint8", "_dtypes.QInt8"},
    {"_dtypes.quint8", "_dtypes.QUInt8"},
    {"_dtypes.qint16", "_dtypes.QInt16"},
    {"_dtypes.quint16", "_dtypes.QUInt16"},
    {"_dtypes.qint32", "_dtypes.QInt32"},
    {"_dtypes.resource", "_dtypes.Resource"},
    {"_dtypes.variant", "_dtypes.Variant"}};
75 | |
76 | string AttrVarName(const string& attr_name, |
77 | std::unordered_map<string, string>* attr_expressions) { |
78 | const string var = strings::StrCat("_attr_" , attr_name); |
79 | if (attr_expressions != nullptr) (*attr_expressions)[attr_name] = var; |
80 | return var; |
81 | } |
82 | |
83 | void AddInferredAttr(const string& indentation, const string& attr_name, |
84 | const string& value_expression, string* result, |
85 | std::unordered_map<string, string>* attr_expressions) { |
86 | strings::StrAppend(result, indentation, |
87 | AttrVarName(attr_name, attr_expressions), " = " , |
88 | value_expression, "\n" ); |
89 | } |
90 | |
91 | string VectorToTuple(const std::vector<string>& l) { |
92 | if (l.size() == 1) return strings::StrCat("(" , l.front(), ",)" ); |
93 | string ret = "(" ; |
94 | for (int i = 0, end = l.size(); i < end; ++i) { |
95 | if (i > 0) { |
96 | strings::StrAppend(&ret, ", " ); |
97 | } |
98 | strings::StrAppend(&ret, l[i]); |
99 | } |
100 | strings::StrAppend(&ret, ")" ); |
101 | return ret; |
102 | } |
103 | |
// Emits Python code that rebuilds the structured output list `var` from its
// flattened form: for every list-typed output slot i (those with a non-empty
// output_sizes[i]), the corresponding slice of `var` is re-wrapped in a list.
// For a slot in the middle, the generated line has the shape:
//   <prefix><var> = <var>[:i] + [<var>[i:i + <size>]] + <var>[i + <size>:]
// Solo (non-list) slots have an empty size string and are left untouched.
void Unflatten(const string& prefix, const std::vector<string>& output_sizes,
               const string& var, string* result) {
  for (int i = 0, end = output_sizes.size(); i < end; ++i) {
    if (!output_sizes[i].empty()) {
      strings::StrAppend(result, prefix, var, " = ");
      if (i > 0) strings::StrAppend(result, var, "[:", i, "] + ");
      if (i + 1 < end) {
        // Special case i == 0 to avoid "0 +" in the generated code.
        if (i == 0) {
          strings::StrAppend(result, "[", var, "[:", output_sizes[i], "]] + ",
                             var, "[", output_sizes[i], ":]");
        } else {
          strings::StrAppend(result, "[", var, "[", i, ":", i, " + ",
                             output_sizes[i], "]] + ", var, "[", i, " + ",
                             output_sizes[i], ":]");
        }
      } else {
        // Last slot: everything from i onward belongs to this list output.
        strings::StrAppend(result, "[", var, "[", i, ":]]");
      }
      strings::StrAppend(result, "\n");
    }
  }
}
127 | |
128 | string TensorPBString(const TensorProto& pb) { |
129 | // Explicitly not using ShortDebugString, because ShortDebugString should |
130 | // not be used as a format for transporting information (it's e.g. subject |
131 | // to redaction of sensitive information). There is a PrintShortTextProto |
132 | // helper, but it's not feasible to depend on that library). |
133 | |
134 | std::string message_short_text; |
135 | |
136 | ::tensorflow::protobuf::TextFormat::Printer printer; |
137 | printer.SetSingleLineMode(true); |
138 | printer.SetExpandAny(true); |
139 | |
140 | printer.PrintToString(pb, &message_short_text); |
141 | |
142 | // Note: This gets used in the argument list, and so must survive naive |
143 | // word wrapping. |
144 | return strings::StrCat("\"\"\"" , message_short_text, "\"\"\"" ); |
145 | } |
146 | |
// Generates the Python wrapper source for one op: the public function with
// the eager fast path and graph-mode branch, plus the eager-fallback
// function. Code() produces the final string.
class GenEagerPythonOp : public python_op_gen_internal::GenPythonOp {
 public:
  GenEagerPythonOp(const OpDef& op_def, const ApiDef& api_def,
                   const string& function_name, bool add_type_annotations)
      : python_op_gen_internal::GenPythonOp(op_def, api_def, function_name,
                                            add_type_annotations) {
    op_name_ = function_name_;
    // Strip a single leading underscore (used to hide ops) for messages and
    // generated identifiers.
    absl::ConsumePrefix(&op_name_, "_");
  }
  ~GenEagerPythonOp() override {}

  // Returns the full generated Python source for this op ("" if the ApiDef
  // visibility is SKIP).
  string Code() override;

 protected:
  // Emits the graph-mode body: apply_op_helper call, attr capture for
  // gradient recording, and result unflattening.
  void HandleGraphMode(const string& function_setup,
                       const std::vector<string>& output_sizes);

  // Returns Python source raising RuntimeError if the op has ref-typed
  // args (not supported eagerly), or "" if eager execution is allowed.
  string GetEagerNotAllowedError();
  // Emits a list/tuple isinstance check that raises TypeError for `arg_name`.
  void ExpectListArg(const string& indentation, const string& arg_name,
                     string* output);
  // Emits attr validation/conversion code shared by the eager paths.
  bool GetEagerFunctionSetup(const string& indentation, string* function_setup);
  // Computes per-output size expressions and the total-outputs expression.
  void GetOutputSizesAndNumOutputsExpr(std::vector<string>* output_sizes,
                                       string* num_outputs_expr);

  // Emits gradient recording and result restructuring after eager execute.
  void AddEagerFunctionTeardown(const string& indentation,
                                const std::vector<string>& output_sizes,
                                bool execute_record_gradient);

  // The two top-level function emitters; each returns false on failure, in
  // which case result_ holds the partial output.
  bool AddEagerFastPathAndGraphCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  bool AddEagerFallbackCode(
      const string& parameters, const std::vector<string>& output_sizes,
      const string& num_outputs_expr, const string& eager_not_allowed_error,
      const std::unordered_map<string, string>& type_annotations);
  void AddEagerFastPathExecute();

  void AddEagerInferredAttrs(const string& indentation);
  void AddEagerInputCasts(const string& indentation);
  void AddEagerAttrs(const string& indentation);
  void AddEagerExecute(const string& indentation,
                       const string& num_outputs_expr);
  void AddFallbackDispatch(const string& prefix);
  void AddTypeBasedDispatch(const string& prefix);
  void AddTypeBasedDispatcherAlias();

  void AddRawOpExport(const string& parameters);

  // Maps attr/arg names to Python type-annotation strings (TypeVar names,
  // "int", "str", ...). Only populated when add_type_annotations_ is set.
  std::unordered_map<string, string> GetTypeAnnotations();

  // Emits module-level TypeVar declarations for "type" attrs.
  void GenerateTypeVars(
      const std::unordered_map<string, string>& type_annotations);

  // Rewrites the just-emitted def line to add a "-> T" return annotation.
  void AddReturnTypeAnnotation(
      const std::unordered_map<string, string>& type_annotations);

  // Records that `attr` is inferred from input arg `arg_index`: the first
  // such arg's name goes into inferred_attrs_, and every index is appended
  // to attr_to_args_[attr].
  void AddAttrForArg(const string& attr, int arg_index) {
    gtl::InsertIfNotPresent(&inferred_attrs_, attr,
                            op_def_.input_arg(arg_index).name());
    auto iter = attr_to_args_.find(attr);
    if (iter == attr_to_args_.end()) {
      attr_to_args_.insert(AttrToArgMap::value_type(attr, {arg_index}));
    } else {
      iter->second.push_back(arg_index);
    }
  }

  // Returns a string expression representing a flattened list of all
  // the inputs given by `*input_indices` (or all inputs if
  // `input_indices` is nullptr). `*output_sizes` can be used to unflatten.
  string FlattenInputs(const std::vector<int>* input_indices,
                       std::vector<string>* output_sizes) const;

  // Op name with any single leading "_" removed; views function_name_.
  StringPiece op_name_;
  typedef std::unordered_map<string, std::vector<int>> AttrToArgMap;
  // Attr name -> indices of the input args it is inferred from.
  AttrToArgMap attr_to_args_;
  // Attr name -> Python expression that yields its value at runtime.
  std::unordered_map<string, string> attr_expressions_;
  // This has all the input args followed by those attrs that don't have
  // defaults.
  std::vector<python_op_gen_internal::ParamNames> params_no_default_;
  // The parameters with defaults (these have to be listed after those without).
  // No input args are included, just attrs.
  std::vector<std::pair<python_op_gen_internal::ParamNames, string>>
      params_with_default_;
};
233 | |
234 | string GetEagerPythonOp(const OpDef& op_def, const ApiDef& api_def, |
235 | const string& function_name, |
236 | bool add_type_annotations) { |
237 | return GenEagerPythonOp(op_def, api_def, function_name, add_type_annotations) |
238 | .Code(); |
239 | } |
240 | |
// Builds a Python expression that evaluates to the flat list of input
// tensors. Runs a small state machine over the args: consecutive solo
// tensors are grouped into one "[a, b]" literal, list-typed args become
// "list(x)" terms, and the pieces are joined with "+". When requested,
// `output_sizes` receives one entry per arg: a length expression for list
// args, or "" for solo args (consumed later by Unflatten).
string GenEagerPythonOp::FlattenInputs(
    const std::vector<int>* input_indices,
    std::vector<string>* output_sizes) const {
  string inputs;
  enum { STARTING, WAS_LIST_INPUT, WAS_SOLO_INPUT } inputs_state = STARTING;
  const int n = input_indices != nullptr ? input_indices->size()
                                         : op_def_.input_arg_size();
  for (int j = 0; j < n; ++j) {
    const int i = input_indices ? (*input_indices)[j] : j;
    const auto& arg(op_def_.input_arg(i));
    const bool is_list =
        !arg.type_list_attr().empty() || !arg.number_attr().empty();
    if (is_list) {
      if (inputs_state == WAS_SOLO_INPUT) {
        // Close the currently open "[...]" group before concatenating.
        strings::StrAppend(&inputs, "] + ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + ");
      }
      strings::StrAppend(&inputs, "list(", param_names_[i].GetRenameTo(), ")");
      inputs_state = WAS_LIST_INPUT;
      if (output_sizes != nullptr) {
        if (!arg.number_attr().empty()) {
          // Length is given by the inferred "_attr_<number_attr>" variable.
          output_sizes->emplace_back(AttrVarName(arg.number_attr(), nullptr));
        } else {
          output_sizes->emplace_back(
              strings::StrCat("len(", param_names_[i].GetRenameTo(), ")"));
        }
      }
    } else {
      if (inputs_state == WAS_SOLO_INPUT) {
        strings::StrAppend(&inputs, ", ");
      } else if (inputs_state == WAS_LIST_INPUT) {
        strings::StrAppend(&inputs, " + [");
      } else {
        strings::StrAppend(&inputs, "[");
      }
      strings::StrAppend(&inputs, param_names_[i].GetRenameTo());
      inputs_state = WAS_SOLO_INPUT;
      // Empty entry marks a solo (non-list) output slot.
      if (output_sizes != nullptr) output_sizes->emplace_back();
    }
  }
  if (inputs_state == STARTING) return "[]";
  if (inputs_state == WAS_SOLO_INPUT) {
    // Close the trailing "[...]" group.
    strings::StrAppend(&inputs, "]");
  }
  return inputs;
}
288 | |
// Top-level driver: collects parameters and attrs, builds the Python
// signatures (with and without defaults / annotations), then emits the
// fast-path+graph function and the eager-fallback function.
string GenEagerPythonOp::Code() {
  if (api_def_.visibility() == ApiDef::SKIP) {
    return "";
  }

  // Collect input args in ApiDef order, recording which attrs are inferred
  // from them (type / type_list / number attrs).
  for (int i = 0; i < api_def_.arg_order_size(); ++i) {
    const auto& arg = *FindInputArg(api_def_.arg_order(i), op_def_);
    const auto& api_def_arg = *FindInputArg(api_def_.arg_order(i), api_def_);
    params_no_default_.emplace_back(api_def_arg.name(),
                                    api_def_arg.rename_to());
    if (!arg.type_attr().empty()) {
      AddAttrForArg(arg.type_attr(), i);
    } else if (!arg.type_list_attr().empty()) {
      AddAttrForArg(arg.type_list_attr(), i);
    }
    if (!arg.number_attr().empty()) {
      AddAttrForArg(arg.number_attr(), i);
    }
  }
  // Partition the remaining attrs into those with and without defaults;
  // tensor-valued defaults are embedded as text protos.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    // Do not add inferred attrs to the Python function signature.
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      if (api_def_attr.has_default_value()) {
        if (attr.type() == "tensor") {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat(
                  "_execute.make_tensor(",
                  TensorPBString(api_def_attr.default_value().tensor()), ", \"",
                  api_def_attr.rename_to(), "\")"));
        } else if (attr.type() == "list(tensor)") {
          std::vector<string> pbtxt;
          for (const auto& pb : api_def_attr.default_value().list().tensor()) {
            pbtxt.emplace_back(TensorPBString(pb));
          }
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              strings::StrCat("[_execute.make_tensor(_pb, \"",
                              api_def_attr.rename_to(), "\") for _pb in ",
                              VectorToTuple(pbtxt), "]"));
        } else {
          params_with_default_.emplace_back(
              python_op_gen_internal::ParamNames(api_def_attr.name(),
                                                 api_def_attr.rename_to()),
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
      } else {
        params_no_default_.emplace_back(api_def_attr.name(),
                                        api_def_attr.rename_to());
      }
    }
  }

  // Save the list of attr parameters (attrs that won't be inferred),
  // those with defaults go at the end.
  // Get the attrs in the order we want by taking the attrs without defaults
  // from the end of params_no_default_, and adding params_no_default_.
  attrs_.reserve(params_no_default_.size() - op_def_.input_arg_size() +
                 params_with_default_.size());
  for (int i = op_def_.input_arg_size(), end = params_no_default_.size();
       i < end; ++i) {
    attrs_.push_back(params_no_default_[i].GetName());
  }
  for (const auto& p : params_with_default_) {
    attrs_.push_back(p.first.GetName());
  }

  // TODO(slebedev): call AvoidPythonReserved on each param?
  param_names_.reserve(params_no_default_.size() + params_with_default_.size());
  param_names_.insert(param_names_.begin(), params_no_default_.begin(),
                      params_no_default_.end());
  for (const auto& param_and_default : params_with_default_) {
    param_names_.push_back(param_and_default.first);
  }

  std::unordered_map<string, string> type_annotations;
  // Only populate map for allowlisted ops
  if (add_type_annotations_) {
    type_annotations = GetTypeAnnotations();
  }

  string parameters;
  // Param can be an input or an attr
  for (const auto& param : params_no_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    strings::StrAppend(&parameters, param.GetRenameTo());

    if (type_annotations.find(param.GetName()) != type_annotations.end()) {
      strings::StrAppend(&parameters, ": ",
                         type_annotations.at(param.GetName()));
    }
  }

  // Two signature variants: one bare (used by the fallback), one with
  // "=default" suffixes (used by the public function).
  string parameters_with_defaults = parameters;
  for (const auto& param_and_default : params_with_default_) {
    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
    if (!parameters_with_defaults.empty())
      strings::StrAppend(&parameters_with_defaults, ", ");

    strings::StrAppend(&parameters, param_and_default.first.GetRenameTo());
    strings::StrAppend(&parameters_with_defaults,
                       param_and_default.first.GetRenameTo());
    if (type_annotations.find(param_and_default.first.GetName()) !=
        type_annotations.end()) {
      const string param_type =
          type_annotations.at(param_and_default.first.GetName());
      // Append to parameters and parameters_with_defaults because multiple
      // functions are generated by AddEagerFastPathAndGraphCode() and
      // AddEagerFallbackCode()
      strings::StrAppend(&parameters, ": ", param_type);
      // NOTE(review): no space after ':' here (unlike the ": " used above),
      // so defaulted params render as "name:type=default" — confirm this is
      // intentional (golden generated files) before normalizing.
      strings::StrAppend(&parameters_with_defaults, ":", param_type);
    }

    strings::StrAppend(&parameters_with_defaults, "=",
                       param_and_default.second);
  }

  // Every generated function takes a trailing "name" argument.
  strings::StrAppend(&parameters, parameters.empty() ? "" : ", ", "name");
  strings::StrAppend(&parameters_with_defaults,
                     parameters_with_defaults.empty() ? "" : ", ", "name=None");

  // Add attr_expressions_ for attrs that are params.
  for (int i = 0, end = attrs_.size(); i < end; ++i) {
    const string& attr_name = attrs_[i];
    const string& attr_api_name =
        param_names_[i + op_def_.input_arg_size()].GetRenameTo();
    attr_expressions_[attr_name] = attr_api_name;
  }
  // Add attr_expressions_ for attrs that are inferred.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    if (attr.type() == "int") {
      auto arg_list = attr_to_args_.find(attr.name());
      if (arg_list != attr_to_args_.end()) {
        AttrVarName(attr.name(), &attr_expressions_);
      }
    }
  }

  string num_outputs_expr;
  // Pre-sized to num_outs_; filled in (by index) by the call below.
  std::vector<string> output_sizes(num_outs_);
  GetOutputSizesAndNumOutputsExpr(&output_sizes, &num_outputs_expr);

  string eager_not_allowed_error = GetEagerNotAllowedError();

  if (!AddEagerFastPathAndGraphCode(parameters_with_defaults, output_sizes,
                                    eager_not_allowed_error,
                                    type_annotations)) {
    return result_;
  }

  if (!AddEagerFallbackCode(parameters, output_sizes, num_outputs_expr,
                            eager_not_allowed_error, type_annotations)) {
    return result_;
  }

  return prelude_ + result_;
}
452 | |
453 | std::unordered_map<string, string> GenEagerPythonOp::GetTypeAnnotations() { |
454 | std::unordered_map<string, string> type_annotations; |
455 | // Map attrs to TypeVars |
456 | for (const auto& attr : op_def_.attr()) { |
457 | if (attr.type() == "type" ) { |
458 | const string type_var_name = "TV_" + op_def_.name() + "_" + attr.name(); |
459 | type_annotations[attr.name()] = type_var_name; |
460 | } else if (attr.type() == "bool" || attr.type() == "float" || |
461 | attr.type() == "int" || attr.type() == "bytes" ) { |
462 | type_annotations[attr.name()] = attr.type(); |
463 | } else if (attr.type() == "string" ) { |
464 | type_annotations[attr.name()] = "str" ; |
465 | } |
466 | } |
467 | |
468 | // Map input Tensors to their types |
469 | for (const auto& arg : op_def_.input_arg()) { |
470 | // TODO(rahulkamat): Add type annotations to args that accept a sequence of |
471 | // Tensors |
472 | if (!arg.number_attr().empty() || !arg.type_list_attr().empty()) continue; |
473 | type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); |
474 | } |
475 | |
476 | // TODO(rahulkamat): Add type annotations to handle return types of a sequence |
477 | // of Tensors. Map output Tensor to its type |
478 | if (op_def_.output_arg_size() == 1) { |
479 | const auto& arg = op_def_.output_arg(0); |
480 | if (arg.number_attr().empty() && arg.type_list_attr().empty()) |
481 | type_annotations[arg.name()] = GetArgAnnotation(arg, type_annotations); |
482 | } |
483 | |
484 | return type_annotations; |
485 | } |
486 | |
487 | // Generate TypeVars using attrs |
488 | void GenEagerPythonOp::GenerateTypeVars( |
489 | const std::unordered_map<string, string>& type_annotations) { |
490 | bool added_typevar = false; |
491 | for (const auto& attr : op_def_.attr()) { |
492 | if (attr.type() == "type" ) { |
493 | std::vector<string> allowed_types; |
494 | for (int t : attr.allowed_values().list().type()) { |
495 | DataType dtype = static_cast<DataType>(t); |
496 | const string py_dtype = |
497 | python_op_gen_internal::DataTypeToPython(dtype, "_dtypes." ); |
498 | allowed_types.emplace_back(dtype_type.at(py_dtype)); |
499 | } |
500 | |
501 | // When a Tensor does not have any dtypes specified, all dtypes are |
502 | // allowed |
503 | if (allowed_types.empty()) { |
504 | for (std::pair<string, string> map_dtype : dtype_type) { |
505 | allowed_types.emplace_back(map_dtype.second); |
506 | } |
507 | } |
508 | |
509 | std::sort(allowed_types.begin(), allowed_types.end()); |
510 | |
511 | string typevar_dtypes; |
512 | for (std::vector<string>::iterator it = allowed_types.begin(); |
513 | it != allowed_types.end(); ++it) { |
514 | if (!typevar_dtypes.empty()) strings::StrAppend(&typevar_dtypes, ", " ); |
515 | strings::StrAppend(&typevar_dtypes, *it); |
516 | } |
517 | |
518 | const string type_var_name = type_annotations.at(attr.name()); |
519 | strings::StrAppend(&result_, type_var_name, " = TypeVar(\"" , |
520 | type_var_name, "\", " , typevar_dtypes, ")\n" ); |
521 | added_typevar = true; |
522 | } |
523 | } |
524 | |
525 | if (added_typevar) strings::StrAppend(&result_, "\n" ); |
526 | } |
527 | |
528 | void GenEagerPythonOp::AddReturnTypeAnnotation( |
529 | const std::unordered_map<string, string>& type_annotations) { |
530 | if (op_def_.output_arg_size() == 1) { |
531 | const auto& arg = op_def_.output_arg(0); |
532 | if (arg.number_attr().empty() && arg.type_list_attr().empty()) { |
533 | const string return_type = type_annotations.at(arg.name()); |
534 | // TODO(rahulkamat): Modify AddDefLine() to add return type annotation to |
535 | // avoid erasing ":\n" from the end of the def line |
536 | result_.erase(result_.length() - 2); |
537 | strings::StrAppend(&result_, " -> " , return_type, ":\n" ); |
538 | } |
539 | } |
540 | } |
541 | |
// Emits the graph-mode branch of the generated function: runs the attr
// setup code, calls _op_def_library._apply_op_helper, optionally records a
// gradient (capturing the op's attrs), and restructures/returns the result.
void GenEagerPythonOp::HandleGraphMode(
    const string& function_setup, const std::vector<string>& output_sizes) {
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, " else:\n");
    AddTypeBasedDispatch(" ");
  }
  strings::StrAppend(&result_, " # Add nodes to the TensorFlow graph.\n");
  strings::StrAppend(&result_, function_setup);
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, " try:\n ");
  }
  strings::StrAppend(
      &result_, " _, _, _op, _outputs = _op_def_library._apply_op_helper(\n");
  AddBodyNoReturn(strings::StrCat(" \"", op_def_.name(), "\", "));
  AddFallbackDispatch(" ");

  if (num_outs_ > 0) {
    strings::StrAppend(&result_, " _result = _outputs[:]\n");
    // Special case handling for stateful op with single list output
    // that might be empty.
    if (num_outs_ == 1 && op_def_.is_stateful() &&
        (!op_def_.output_arg(0).number_attr().empty() ||
         !op_def_.output_arg(0).type_list_attr().empty())) {
      // TODO(josh11b): Can skip this if the number_attr/type_list_attr has
      // a constraint indicating that this can never be empty.
      strings::StrAppend(&result_,
                         " if not _result:\n"
                         " return _op\n");
    }

    // Compute graph-mode attrs when we need to record a gradient.
    strings::StrAppend(&result_, " if _execute.must_record_gradient():\n");
    if (op_def_.attr_size() > 0) {
      // Build the flat ("name", value, "name", value, ...) attr tuple,
      // choosing the typed _op getter for type/bool/int attrs.
      string attr_values;
      for (int i = 0; i < op_def_.attr_size(); ++i) {
        if (i > 0) strings::StrAppend(&attr_values, ", ");
        const auto& attr_name(op_def_.attr(i).name());
        if (op_def_.attr(i).type() == "type") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_type(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "bool") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_bool(\"", attr_name, "\")");
        } else if (op_def_.attr(i).type() == "int") {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op._get_attr_int(\"", attr_name, "\")");
        } else {
          strings::StrAppend(&attr_values, "\"", attr_name,
                             "\", _op.get_attr(\"", attr_name, "\")");
        }
      }
      strings::StrAppend(&attr_values, ")");
      strings::StrAppend(&result_,
                         WordWrap(" _attrs = (", attr_values, kRightMargin),
                         "\n");

    } else {
      strings::StrAppend(&result_, " _attrs = ()\n");
    }

    strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
    strings::StrAppend(&result_, " _execute.record_gradient(\n",
                       " \"", op_def_.name(),
                       "\", _inputs_flat, _attrs, _result)\n");

    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, " ", "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(" ", output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, " _result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
    strings::StrAppend(&result_, " return _result\n\n");
  } else {
    // No outputs: return the Operation itself.
    strings::StrAppend(&result_, " return _op\n");
  }
}
630 | |
631 | string GenEagerPythonOp::GetEagerNotAllowedError() { |
632 | bool eager_allowed = true; |
633 | string ref_arg; |
634 | for (int i = 0; i < op_def_.input_arg_size(); ++i) { |
635 | const auto& arg = op_def_.input_arg(i); |
636 | if (arg.is_ref()) { |
637 | eager_allowed = false; |
638 | DCHECK_EQ(op_def_.input_arg(i).name(), api_def_.in_arg(i).name()); |
639 | ref_arg = api_def_.in_arg(i).rename_to(); |
640 | } |
641 | } |
642 | for (int i = 0; i < op_def_.output_arg_size(); ++i) { |
643 | const auto& arg = op_def_.output_arg(i); |
644 | if (arg.is_ref()) { |
645 | eager_allowed = false; |
646 | DCHECK_EQ(op_def_.output_arg(i).name(), api_def_.out_arg(i).name()); |
647 | ref_arg = api_def_.out_arg(i).rename_to(); |
648 | } |
649 | } |
650 | |
651 | if (eager_allowed) return "" ; |
652 | |
653 | return strings::StrCat("raise RuntimeError(\"" , op_name_, |
654 | " op does not support eager execution. " , "Arg '" , |
655 | ref_arg, "' is a ref.\")\n" ); |
656 | } |
657 | |
// Emits Python code that checks `arg_name` is a list or tuple, raising
// TypeError with the op name otherwise.
void GenEagerPythonOp::ExpectListArg(const string& indentation,
                                     const string& arg_name, string* output) {
  strings::StrAppend(output, indentation, "if not isinstance(", arg_name,
                     ", (list, tuple)):\n", indentation, " raise TypeError(\n",
                     indentation, " \"Expected list for '", arg_name,
                     "' argument to \"\n", indentation, " \"'", op_name_,
                     "' Op, not %r.\" % ", arg_name, ")\n");
}
666 | |
667 | bool GenEagerPythonOp::GetEagerFunctionSetup(const string& indentation, |
668 | string* function_setup) { |
669 | // Validate list inputs, infer length attrs. |
670 | for (int i = 0; i < op_def_.attr_size(); ++i) { |
671 | const auto& attr(op_def_.attr(i)); |
672 | if (attr.type() == "int" ) { |
673 | auto arg_list = attr_to_args_.find(attr.name()); |
674 | if (arg_list != attr_to_args_.end()) { |
675 | // Inferred int attrs are the lengths of inputs. Validate those |
676 | // inputs are lists and have the same length. |
677 | for (auto iter = arg_list->second.begin(); |
678 | iter != arg_list->second.end(); ++iter) { |
679 | const string& arg_api_name = param_names_[*iter].GetRenameTo(); |
680 | ExpectListArg(indentation, arg_api_name, function_setup); |
681 | if (iter == arg_list->second.begin()) { |
682 | AddInferredAttr(indentation, attr.name(), |
683 | strings::StrCat("len(" , arg_api_name, ")" ), |
684 | function_setup, &attr_expressions_); |
685 | } else { |
686 | const auto& attr_var = attr_expressions_[attr.name()]; |
687 | strings::StrAppend( |
688 | function_setup, indentation, "if len(" , arg_api_name, |
689 | ") != " , attr_var, ":\n" , indentation, " raise ValueError(\n" , |
690 | indentation, " \"List argument '" , arg_api_name, "' to '" , |
691 | op_name_, "' Op with length %d \"\n" , indentation, |
692 | " \"must match length %d of argument '" , |
693 | inferred_attrs_[attr.name()], "'.\" %\n" , indentation, |
694 | " (len(" , arg_api_name, "), " , attr_var, "))\n" ); |
695 | } |
696 | } |
697 | } |
698 | } |
699 | } |
700 | |
701 | for (int i = 0, end = attrs_.size(); i < end; ++i) { |
702 | const string& attr_name = attrs_[i]; |
703 | const auto& param = param_names_[i + op_def_.input_arg_size()]; |
704 | const auto& attr = *FindAttr(attr_name, op_def_); |
705 | const string& attr_api_name = param.GetRenameTo(); |
706 | StringPiece attr_type = attr.type(); |
707 | attr_expressions_[attr_name] = attr_api_name; |
708 | const int default_index = i - (attrs_.size() - params_with_default_.size()); |
709 | if (default_index >= 0) { |
710 | const string& default_value = params_with_default_[default_index].second; |
711 | strings::StrAppend(function_setup, indentation, "if " , attr_api_name, |
712 | " is None:\n" ); |
713 | strings::StrAppend(function_setup, indentation, " " , attr_api_name, |
714 | " = " , default_value, "\n" ); |
715 | } |
716 | if (absl::StartsWith(attr_type, "list(" )) { |
717 | ExpectListArg(indentation, attr_api_name, function_setup); |
718 | } |
719 | |
720 | if (attr_type == "string" ) { |
721 | strings::StrAppend(function_setup, indentation, attr_api_name, |
722 | " = _execute.make_str(" , attr_api_name, ", \"" , |
723 | attr_api_name, "\")\n" ); |
724 | } else if (attr_type == "list(string)" ) { |
725 | strings::StrAppend(function_setup, indentation, attr_api_name, |
726 | " = [_execute.make_str(_s, \"" , attr_api_name, |
727 | "\") for _s in " , attr_api_name, "]\n" ); |
728 | } else if (attr_type == "int" ) { |
729 | strings::StrAppend(function_setup, indentation, attr_api_name, |
730 | " = _execute.make_int(" , attr_api_name, ", \"" , |
731 | attr_api_name, "\")\n" ); |
732 | } else if (attr_type == "list(int)" ) { |
733 | strings::StrAppend(function_setup, indentation, attr_api_name, |
734 | " = [_execute.make_int(_i, \"" , attr_api_name, |
735 | "\") for _i in " , attr_api_name, "]\n" ); |
736 | } else if (attr_type == "float" ) { |
737 | strings::StrAppend(function_setup, indentation, attr_api_name, |
738 | " = _execute.make_float(" , attr_api_name, ", \"" , |
739 | attr_api_name, "\")\n" ); |
740 | } else if (attr_type == "list(float)" ) { |
741 | strings::StrAppend(function_setup, indentation, attr_api_name, |
742 | " = [_execute.make_float(_f, \"" , attr_api_name, |
743 | "\") for _f in " , attr_api_name, "]\n" ); |
744 | } else if (attr_type == "bool" ) { |
745 | strings::StrAppend(function_setup, indentation, attr_api_name, |
746 | " = _execute.make_bool(" , attr_api_name, ", \"" , |
747 | attr_api_name, "\")\n" ); |
748 | } else if (attr_type == "list(bool)" ) { |
749 | strings::StrAppend(function_setup, indentation, attr_api_name, |
750 | " = [_execute.make_bool(_b, \"" , attr_api_name, |
751 | "\") for _b in " , attr_api_name, "]\n" ); |
752 | } else if (attr_type == "type" ) { |
753 | strings::StrAppend(function_setup, indentation, attr_api_name, |
754 | " = _execute.make_type(" , attr_api_name, ", \"" , |
755 | attr_api_name, "\")\n" ); |
756 | } else if (attr_type == "list(type)" ) { |
757 | strings::StrAppend(function_setup, indentation, attr_api_name, |
758 | " = [_execute.make_type(_t, \"" , attr_api_name, |
759 | "\") for _t in " , attr_api_name, "]\n" ); |
760 | } else if (attr_type == "shape" ) { |
761 | strings::StrAppend(function_setup, indentation, attr_api_name, |
762 | " = _execute.make_shape(" , attr_api_name, ", \"" , |
763 | attr_api_name, "\")\n" ); |
764 | } else if (attr_type == "list(shape)" ) { |
765 | strings::StrAppend(function_setup, indentation, attr_api_name, |
766 | " = [_execute.make_shape(_s, \"" , attr_api_name, |
767 | "\") for _s in " , attr_api_name, "]\n" ); |
768 | } else if (attr_type == "tensor" ) { |
769 | strings::StrAppend(function_setup, indentation, attr_api_name, |
770 | " = _execute.make_tensor(" , attr_api_name, ", \"" , |
771 | attr_api_name, "\")\n" ); |
772 | } else if (attr_type == "list(tensor)" ) { |
773 | strings::StrAppend(function_setup, indentation, attr_api_name, |
774 | " = [_execute.make_tensor(_t, \"" , attr_api_name, |
775 | "\") for _t in " , attr_api_name, "]\n" ); |
776 | } else if (attr_type != "func" && attr_type != "list(func)" ) { |
777 | *function_setup = |
778 | strings::StrCat("# No definition for " , function_name_, |
779 | " since we don't support attrs with type\n" |
780 | "# '" , |
781 | attr_type, "' right now.\n\n" ); |
782 | return false; |
783 | } |
784 | } |
785 | return true; |
786 | } |
787 | |
788 | // If output i is list output, output_sizes[i] will be set to a |
789 | // string with the python expression that will evaluate to its |
790 | // length. output_sizes[i] is empty for non-list outputs. |
791 | void GenEagerPythonOp::GetOutputSizesAndNumOutputsExpr( |
792 | std::vector<string>* output_sizes, string* num_outputs_expr) { |
793 | // Expression representing the number of outputs. |
794 | int num_fixed_outputs = 0; |
795 | for (int i = 0; i < num_outs_; ++i) { |
796 | const auto& arg(op_def_.output_arg(i)); |
797 | if (!arg.number_attr().empty()) { |
798 | if (!num_outputs_expr->empty()) { |
799 | strings::StrAppend(num_outputs_expr, " + " ); |
800 | } |
801 | (*output_sizes)[i] = attr_expressions_[arg.number_attr()]; |
802 | strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); |
803 | } else if (!arg.type_list_attr().empty()) { |
804 | if (!num_outputs_expr->empty()) { |
805 | strings::StrAppend(num_outputs_expr, " + " ); |
806 | } |
807 | // Have to be careful to use an expression that works in both |
808 | // graph and eager paths here. |
809 | const auto iter = inferred_attrs_.find(arg.type_list_attr()); |
810 | if (iter == inferred_attrs_.end()) { |
811 | (*output_sizes)[i] = strings::StrCat( |
812 | "len(" , attr_expressions_[arg.type_list_attr()], ")" ); |
813 | } else { |
814 | (*output_sizes)[i] = strings::StrCat("len(" , iter->second, ")" ); |
815 | } |
816 | strings::StrAppend(num_outputs_expr, (*output_sizes)[i]); |
817 | } else { |
818 | ++num_fixed_outputs; |
819 | } |
820 | } |
821 | if (num_fixed_outputs > 0) { |
822 | if (!num_outputs_expr->empty()) { |
823 | strings::StrAppend(num_outputs_expr, " + " ); |
824 | } |
825 | strings::StrAppend(num_outputs_expr, num_fixed_outputs); |
826 | } else if (num_outputs_expr->empty()) { |
827 | *num_outputs_expr = "0" ; |
828 | } |
829 | } |
830 | |
// Appends the epilogue shared by the eager-fallback and graph paths:
// optional gradient recording, reshaping of the flat `_result` list that
// execute() returns, and the final `return _result`.
//
// `output_sizes[i]` holds the Python length expression for output i if it is
// a list output, or is empty for single-tensor outputs (as produced by
// GetOutputSizesAndNumOutputsExpr). `execute_record_gradient` controls
// whether a record_gradient call is emitted.
void GenEagerPythonOp::AddEagerFunctionTeardown(
    const string& indentation, const std::vector<string>& output_sizes,
    bool execute_record_gradient) {
  if (num_outs_ > 0) {
    if (execute_record_gradient) {
      strings::StrAppend(&result_, indentation,
                         "if _execute.must_record_gradient():\n");
      strings::StrAppend(&result_, indentation, "  _execute.record_gradient(\n",
                         "        \"", op_def_.name(),
                         "\", _inputs_flat, _attrs, _result)\n");
    }
    if (num_outs_ == 1 && !output_sizes[0].empty()) {
      // Single list result.
    } else if (num_outs_ == 1) {
      // Execute returns a single-element list which we need to destructure.
      strings::StrAppend(&result_, indentation, "_result, = _result\n");
    } else {
      // Have multiple outputs, so we will need to reformat the return
      // value of execute() to be a list with one entry per op output
      // (that entry will be a list of tensors if that output is of list
      // type).
      // For list outputs, convert the right subrange of _result into a list.
      Unflatten(indentation, output_sizes, "_result", &result_);
      // Convert to a named tuple.
      strings::StrAppend(
          &result_, indentation, "_result = _",
          python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
          "Output._make(_result)\n");
    }
  } else {
    // Ops with no outputs return None from the generated wrapper.
    strings::StrAppend(&result_, indentation, "_result = None\n");
  }
  strings::StrAppend(&result_, indentation, "return _result\n\n");
}
865 | |
// Emits the main public wrapper function for the op: dispatch decorators,
// the def line, the docstring, the eager fast-path branch, the graph-mode
// branch, the raw-op export, and the type-based-dispatcher alias.
//
// Returns false — leaving a "# No definition ..." message in result_ — when
// the op has an attr type the generator does not support.
bool GenEagerPythonOp::AddEagerFastPathAndGraphCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  if (add_type_annotations_) {
    GenerateTypeVars(type_annotations);
  }
  // Only VISIBLE ops participate in Python-level dispatch.
  if (api_def_.visibility() == ApiDef::VISIBLE) {
    strings::StrAppend(&result_, "@_dispatch.add_fallback_dispatch_list\n");
    strings::StrAppend(&result_, "@_dispatch.add_type_based_api_dispatcher\n");
  }

  AddExport();
  AddDefLine(function_name_, parameters);
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  // Docstring sections, then the closing triple quote.
  AddDocStringDescription();
  AddDocStringArgs();
  AddDocStringInputs();
  AddDocStringAttrs();
  AddDocStringNameArg();
  AddOutputGlobals();  // Added to prelude_
  AddDocStringOutputs();
  strings::StrAppend(&result_, "  \"\"\"\n");

  // Eager fast path, taken when the thread-local context is in eager mode.
  strings::StrAppend(&result_,
                     "  _ctx = _context._context or _context.context()\n"
                     "  tld = _ctx._thread_local_data\n",
                     "  if tld.is_eager:", "\n");
  if (eager_not_allowed_error.empty()) {
    AddEagerFastPathExecute();
  } else {
    // Op cannot run eagerly; the eager branch just raises.
    strings::StrAppend(&result_, "    ", eager_not_allowed_error);
  }

  // Handle graph-mode case
  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  HandleGraphMode(function_setup, output_sizes);

  AddRawOpExport(parameters);
  AddTypeBasedDispatcherAlias();
  strings::StrAppend(&result_, "\n\n");
  return true;
}
915 | |
// Emits the `<op>_eager_fallback` function, used when the C++ fast path
// raises _FallbackException: it performs attr/input setup in Python and
// calls _execute.execute directly.
//
// Returns false — leaving a "# No definition ..." message in result_ — when
// the op has an attr type the generator does not support.
bool GenEagerPythonOp::AddEagerFallbackCode(
    const string& parameters, const std::vector<string>& output_sizes,
    const string& num_outputs_expr, const string& eager_not_allowed_error,
    const std::unordered_map<string, string>& type_annotations) {
  // Fallback signature is the public signature plus a trailing `ctx`.
  AddDefLine(
      strings::StrCat(function_name_, kEagerFallbackSuffix),
      strings::StrCat(parameters, parameters.empty() ? "" : ", ", "ctx"));
  if (add_type_annotations_) {
    AddReturnTypeAnnotation(type_annotations);
  }
  if (!eager_not_allowed_error.empty()) {
    // The op cannot run eagerly at all; the fallback body just raises.
    strings::StrAppend(&result_, "  ", eager_not_allowed_error);
    return true;
  }

  string function_setup;
  if (!GetEagerFunctionSetup("  ", &function_setup)) {
    result_ = function_setup;
    return false;
  }
  strings::StrAppend(&result_, function_setup);

  // Infer type attrs from inputs, cast remaining inputs to eager tensors,
  // flatten inputs, build the _attrs tuple, and invoke execute().
  AddEagerInferredAttrs("  ");
  AddEagerInputCasts("  ");
  strings::StrAppend(
      &result_, "  _inputs_flat = ", FlattenInputs(nullptr, nullptr), "\n");
  AddEagerAttrs("  ");
  AddEagerExecute("  ", num_outputs_expr);

  AddEagerFunctionTeardown("  ", output_sizes,
                           true /* execute_record_gradient */);

  return true;
}
950 | |
// Emits the eager fast-path body: a try/except around
// pywrap_tfe.TFE_Py_FastPathExecute, falling back to the Python
// `<op>_eager_fallback` function (and finally to dispatch) on failure.
void GenEagerPythonOp::AddEagerFastPathExecute() {
  // Positional params for TFE_Py_FastPathExecute: ctx, op name, op-level
  // name, then inputs, then (attr-name, attr-value) pairs.
  string fastpath_execute_params =
      strings::StrCat("_ctx, \"", op_def_.name(), "\", ", "name");
  // Comma-separated arguments passed through to the fallback function.
  string fallback_params;

  for (int i = 0; i < api_def_.in_arg_size(); i++) {
    const string param_name = param_names_[i].GetRenameTo();
    strings::StrAppend(&fastpath_execute_params, ", ", param_name);
    if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
    strings::StrAppend(&fallback_params, param_name);
  }

  // Only non-inferred attrs are passed explicitly; inferred ones are
  // derived from the inputs.
  for (const auto& attr : api_def_.attr()) {
    if (inferred_attrs_.find(attr.name()) == inferred_attrs_.end()) {
      strings::StrAppend(&fastpath_execute_params, ", \"", attr.name(), "\", ",
                         attr.rename_to());

      if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
      strings::StrAppend(&fallback_params, attr.rename_to(), "=",
                         attr.rename_to());
    }
  }

  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "name=name");

  strings::StrAppend(&result_, "  try:\n");
  strings::StrAppend(
      &result_, "    ", "_result = pywrap_tfe.TFE_Py_FastPathExecute(\n",
      WordWrap(strings::StrCat("      "),
               strings::StrCat(fastpath_execute_params, ")"), kRightMargin),
      "\n");

  // Multi-output ops wrap the flat result list in the generated namedtuple.
  if (op_def_.output_arg_size() > 1) {
    const string output_tuple_name = strings::StrCat(
        "_", python_op_gen_internal::AvoidPythonReserved(op_def_.name()),
        "Output");
    strings::StrAppend(&result_, "    ", "_result = ", output_tuple_name,
                       "._make(_result)\n");
  }
  strings::StrAppend(&result_, "    ", "return _result\n");

  // Handle fallback.
  if (!fallback_params.empty()) strings::StrAppend(&fallback_params, ", ");
  strings::StrAppend(&fallback_params, "ctx=_ctx");

  // Any errors thrown from execute need to be unwrapped from
  // _NotOkStatusException.
  strings::StrAppend(&result_, "  ",
                     "except _core._NotOkStatusException as e:\n");
  strings::StrAppend(&result_, "    ",
                     "_ops.raise_from_not_ok_status(e, name)\n");

  strings::StrAppend(&result_, "  ", "except _core._FallbackException:\n");
  strings::StrAppend(&result_, "    pass\n");
  strings::StrAppend(&result_, "  try:\n");
  AddTypeBasedDispatch("    ");
  strings::StrAppend(
      &result_, "    ", "return ", function_name_, kEagerFallbackSuffix,
      "(\n",
      WordWrap(strings::StrCat("        "),
               strings::StrCat(fallback_params, ")"), kRightMargin),
      "\n");
  strings::StrAppend(&result_, "  except _core._SymbolicException:\n");
  strings::StrAppend(&result_,
                     "    pass  # Add nodes to the TensorFlow graph.\n");
  AddFallbackDispatch("  ");
}
1019 | |
// Emits code that infers the values of type attrs ("type" and "list(type)")
// from the op's input tensors, converting those inputs to eager tensors in
// the process. The resulting Python expressions are recorded in
// attr_expressions_ for later use by AddEagerAttrs.
void GenEagerPythonOp::AddEagerInferredAttrs(const string& indentation) {
  // Figure out values for inferred attrs, and cast to eager tensors.
  for (int i = 0; i < op_def_.attr_size(); ++i) {
    const auto& attr(op_def_.attr(i));
    const auto& api_def_attr(api_def_.attr(i));
    auto arg_list = attr_to_args_.find(attr.name());
    if (arg_list != attr_to_args_.end()) {
      if (attr.type() == "type") {
        std::vector<string> output_sizes;
        const string flattened =
            FlattenInputs(&arg_list->second, &output_sizes);
        string conversion = strings::StrCat("_execute.args_to_matching_eager(",
                                            flattened, ", ctx");

        // Pass the attr's allowed dtypes so conversion can pick one that
        // matches.
        strings::StrAppend(&conversion, ", [");
        for (int t : attr.allowed_values().list().type()) {
          DataType dtype = static_cast<DataType>(t);
          const string py_dtype =
              python_op_gen_internal::DataTypeToPython(dtype, "_dtypes.");
          strings::StrAppend(&conversion, py_dtype, ", ");
        }
        strings::StrAppend(&conversion, "]");

        if (attr.has_default_value()) {
          strings::StrAppend(
              &conversion, ", ",
              python_op_gen_internal::AttrValueToPython(
                  attr.type(), api_def_attr.default_value(), "_dtypes."));
        }
        strings::StrAppend(&conversion, ")");
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        if (output_sizes.size() == 1) {
          // Avoid creating a temporary variable in the case where
          // we can easily assign to the right value directly.
          const string inputs_var =
              param_names_[arg_list->second.front()].GetRenameTo();
          if (output_sizes.front().empty()) {
            // Single tensor input: destructure the one-element list.
            strings::StrAppend(&result_, indentation, var_name, ", (",
                               inputs_var, ",) = ", conversion, "\n");
          } else {
            strings::StrAppend(&result_, indentation, var_name, ", ",
                               inputs_var, " = ", conversion, "\n");
          }
        } else {
          const string inputs_var = strings::StrCat("_inputs_", attr.name());
          strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                             " = ", conversion, "\n");
          // Convert from a flat list of eager tensors back to the
          // parameter variables.
          Unflatten(indentation, output_sizes, inputs_var, &result_);
          std::vector<string> p;
          for (int j : arg_list->second) {
            p.emplace_back(param_names_[j].GetRenameTo());
          }
          strings::StrAppend(&result_, indentation, VectorToTuple(p), " = ",
                             inputs_var, "\n");
        }
      } else if (attr.type() == "list(type)") {
        // NOTE: We ignore default values for these attrs, since it is
        // unclear how you would use it, and the one use case is
        // parse_single_sequence_example which only needs it for
        // backwards compatibility.
        const string var_name = AttrVarName(attr.name(), &attr_expressions_);
        string inputs_var;
        string conversion;
        if (arg_list->second.size() > 1) {
          // If you have more than one list(tensor) argument, their types
          // have to match.
          std::vector<string> lists;
          for (auto iter = arg_list->second.begin();
               iter != arg_list->second.end(); ++iter) {
            lists.push_back(param_names_[*iter].GetRenameTo());
          }
          inputs_var = VectorToTuple(lists);
          conversion = "_execute.args_to_mixed_eager_tensors";
        } else {
          // For one list(tensor) argument, we just convert every
          // element of the list to an eager tensor.
          inputs_var = param_names_[arg_list->second.front()].GetRenameTo();
          conversion = "_execute.convert_to_mixed_eager_tensors";
        }
        strings::StrAppend(&result_, indentation, var_name, ", ", inputs_var,
                           " = ", conversion, "(", inputs_var, ", ctx)\n");
      }
    }
  }
}
1107 | |
1108 | void GenEagerPythonOp::AddEagerInputCasts(const string& indentation) { |
1109 | // Cast remaining args to eager tensors |
1110 | for (int i = 0; i < op_def_.input_arg_size(); ++i) { |
1111 | const auto& arg(op_def_.input_arg(i)); |
1112 | if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) continue; |
1113 | const string& param = param_names_[i].GetRenameTo(); |
1114 | const string fn = arg.number_attr().empty() ? "" : "n_" ; |
1115 | const string dtype = |
1116 | python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes." ); |
1117 | strings::StrAppend(&result_, indentation, param, " = _ops.convert_" , fn, |
1118 | "to_tensor(" , param, ", " , dtype, ")\n" ); |
1119 | } |
1120 | } |
1121 | |
1122 | void GenEagerPythonOp::AddEagerAttrs(const string& indentation) { |
1123 | // Compute eager attrs |
1124 | if (op_def_.attr_size() > 0) { |
1125 | string attr_values; |
1126 | for (int i = 0; i < op_def_.attr_size(); ++i) { |
1127 | if (i > 0) strings::StrAppend(&attr_values, ", " ); |
1128 | const auto& attr_name(op_def_.attr(i).name()); |
1129 | strings::StrAppend(&attr_values, "\"" , attr_name, "\", " , |
1130 | attr_expressions_[attr_name]); |
1131 | } |
1132 | strings::StrAppend(&attr_values, ")" ); |
1133 | strings::StrAppend( |
1134 | &result_, |
1135 | WordWrap(indentation, strings::StrCat("_attrs = (" , attr_values), |
1136 | kRightMargin), |
1137 | "\n" ); |
1138 | } else { |
1139 | strings::StrAppend(&result_, indentation, "_attrs = None\n" ); |
1140 | } |
1141 | } |
1142 | |
1143 | void GenEagerPythonOp::AddEagerExecute(const string& indentation, |
1144 | const string& num_outputs_expr) { |
1145 | const string return_prefix = |
1146 | strings::StrCat(indentation, "_result = _execute.execute(" ); |
1147 | const string return_args = strings::StrCat( |
1148 | "b\"" , op_def_.name(), "\", " , num_outputs_expr, |
1149 | ", inputs=_inputs_flat, attrs=_attrs, ctx=ctx, name=name)" ); |
1150 | strings::StrAppend(&result_, |
1151 | // Wrap the arguments, and indent to the (. |
1152 | WordWrap(return_prefix, return_args, kRightMargin), "\n" ); |
1153 | } |
1154 | |
// Emits the `except (TypeError, ValueError)` handler that routes the call
// through the fallback dispatch list (_dispatch.dispatch) and re-raises if
// no dispatcher handled it. No-op for non-VISIBLE ops.
void GenEagerPythonOp::AddFallbackDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;

  strings::StrAppend(&result_, prefix, "except (TypeError, ValueError):\n");
  strings::StrAppend(&result_, prefix, "  _result = _dispatch.dispatch(\n");
  // AddBodyNoReturn appends "name=value" kwargs for the dict(...) call.
  AddBodyNoReturn(strings::StrCat(prefix, "        ", function_name_,
                                  ", "
                                  "(), dict("));
  strings::StrAppend(&result_, prefix, "      )\n");
  strings::StrAppend(&result_, prefix,
                     "  if _result is not "
                     "_dispatch.OpDispatcher.NOT_SUPPORTED:\n");
  strings::StrAppend(&result_, prefix, "    return _result\n");
  strings::StrAppend(&result_, prefix, "  raise\n");
}
1170 | |
1171 | void GenEagerPythonOp::AddTypeBasedDispatcherAlias() { |
1172 | // It's possible for the name of a parameter to be the same as the name of |
1173 | // an op, in which case the parameter shadows the op's function. To avoid |
1174 | // this, we add a private variable with the dispatcher, and access that |
1175 | // directly. |
1176 | if (api_def_.visibility() == ApiDef::VISIBLE) { |
1177 | strings::StrAppend(&result_, "_dispatcher_for_" , function_name_, " = " , |
1178 | function_name_, "._tf_type_based_dispatcher.Dispatch\n" ); |
1179 | } |
1180 | } |
// Emits a call to the type-based dispatcher (through the private alias
// created by AddTypeBasedDispatcherAlias) and an early return when it
// handled the call. No-op for non-VISIBLE ops.
void GenEagerPythonOp::AddTypeBasedDispatch(const string& prefix) {
  if (api_def_.visibility() != ApiDef::VISIBLE) return;
  // Build "(arg1, arg2, ..., name,), None": the dispatcher's positional-args
  // tuple plus a None placeholder for kwargs.
  std::string args("(");
  for (const auto& name : param_names_) {
    strings::StrAppend(&args, name.GetRenameTo(), ", ");
  }
  strings::StrAppend(&args, "name,), None");

  strings::StrAppend(
      &result_, prefix, "_result = ", "_dispatcher_for_", function_name_, "(\n",
      WordWrap(strings::StrCat(prefix, "    "), args, kRightMargin), ")\n");
  strings::StrAppend(&result_, prefix, "if _result is not NotImplemented:\n",
                     prefix, "  return _result\n");
}
1195 | |
1196 | void GenEagerPythonOp::AddRawOpExport(const string& parameters) { |
1197 | // Example: |
1198 | // |
1199 | // Identity = tf_export("raw_ops.Identity")(_ops._to_raw_op(identity)) |
1200 | const string raw_function_name = |
1201 | python_op_gen_internal::AvoidPythonReserved(op_def_.name()); |
1202 | strings::StrAppend(&result_, raw_function_name, " = tf_export(\"raw_ops." , |
1203 | raw_function_name, "\")" , "(_ops.to_raw_op(" , |
1204 | function_name_, "))\n" ); |
1205 | } |
1206 | |
// Generates the full Python wrapper module source for `ops`.
//
// `hidden_ops` lists op names to emit with a leading underscore;
// `source_file_name`, when non-empty, is mentioned in the module header so
// readers can trace generated code back to its C++ origin;
// `type_annotate_ops` names the ops that should receive type annotations.
string GetPythonOpsImpl(
    const OpList& ops, const ApiDefMap& api_defs,
    const std::vector<string>& hidden_ops, const string& source_file_name = "",
    const std::unordered_set<string> type_annotate_ops = {}) {
  string result;
  // Header
  // TODO(josh11b): Mention the library for which wrappers are being generated.
  strings::StrAppend(&result, R"("""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
)");

  // Mention the original source file so someone tracing back through
  // generated Python code will know where to look next.
  if (!source_file_name.empty()) {
    strings::StrAppend(&result, "Original C++ source file: ");
    strings::StrAppend(&result, source_file_name);
    strings::StrAppend(&result, "\n");
  }

  // Imports shared by every generated wrapper module.
  strings::StrAppend(&result, R"("""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar
)");

  for (const auto& op_def : ops.op()) {
    const auto* api_def = api_defs.GetApiDef(op_def.name());

    // SKIP ops get no wrapper at all.
    if (api_def->visibility() == ApiDef::SKIP) {
      continue;
    }
    // An op is hidden if either its ApiDef visibility is HIDDEN
    // or it is in the hidden_ops list.
    bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
    bool hidden_by_api_def = is_hidden;
    if (!is_hidden) {
      for (const string& hidden : hidden_ops) {
        if (op_def.name() == hidden) {
          is_hidden = true;
          break;
        }
      }
    }

    string function_name;
    python_op_gen_internal::GenerateLowerCaseOpName(op_def.name(),
                                                    &function_name);
    bool is_reserved = python_op_gen_internal::IsPythonReserved(function_name);

    // Prefix an op with underscore if the op is listed in hidden_ops or
    // name is reserved or it is of the exceptions in IsOpWithUnderscorePrefix.
    // Do not add underscores to ops set to HIDDEN in ApiDef otherwise.
    // TODO(annarev): don't prefix with underscores even if op is in hidden_ops.
    if (is_hidden) {
      if (!hidden_by_api_def || is_reserved ||
          python_op_gen_internal::IsOpWithUnderscorePrefix(function_name)) {
        function_name = strings::StrCat("_", function_name);
      }
    } else if (is_reserved) {
      // When users create custom python wrappers, they may link in the
      // default op registry by accident, and because they can't
      // enumerate all 'hidden' symbols, this guard is to prevent
      // instantiating a python reserved word in their wrapper.
      continue;
    }

    auto iter = type_annotate_ops.find(op_def.name());
    bool add_type_annotations = iter != type_annotate_ops.end();

    strings::StrAppend(&result,
                       GetEagerPythonOp(op_def, *api_def, function_name,
                                        add_type_annotations));
  }

  return result;
}
1298 | |
1299 | } // namespace |
1300 | |
// Public entry point: returns the generated Python wrapper source for `ops`.
// Thin delegator to the file-local GetPythonOpsImpl; see that function for
// parameter semantics.
string GetPythonOps(const OpList& ops, const ApiDefMap& api_defs,
                    const std::vector<string>& hidden_ops,
                    const string& source_file_name,
                    const std::unordered_set<string> type_annotate_ops) {
  return GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name,
                          type_annotate_ops);
}
1308 | |
1309 | void PrintPythonOps(const OpList& ops, const ApiDefMap& api_defs, |
1310 | const std::vector<string>& hidden_ops, |
1311 | const string& source_file_name, |
1312 | const std::unordered_set<string> type_annotate_ops) { |
1313 | printf("%s" , GetPythonOpsImpl(ops, api_defs, hidden_ops, source_file_name, |
1314 | type_annotate_ops) |
1315 | .c_str()); |
1316 | } |
1317 | |
1318 | string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) { |
1319 | OpList ops; |
1320 | ops.ParseFromArray(op_list_buf, op_list_len); |
1321 | |
1322 | ApiDefMap api_def_map(ops); |
1323 | return GetPythonOpsImpl(ops, api_def_map, {}); |
1324 | } |
1325 | |
1326 | string GetArgAnnotation( |
1327 | const OpDef::ArgDef& arg, |
1328 | const std::unordered_map<string, string>& type_annotations) { |
1329 | if (!arg.type_attr().empty()) { |
1330 | // Get the correct TypeVar if arg maps to an attr |
1331 | return "_ops.Tensor[" + type_annotations.at(arg.type_attr()) + "]" ; |
1332 | } else { |
1333 | // Get the dtype of the Tensor |
1334 | const string py_dtype = |
1335 | python_op_gen_internal::DataTypeToPython(arg.type(), "_dtypes." ); |
1336 | return "_ops.Tensor[" + dtype_type.at(py_dtype) + "]" ; |
1337 | } |
1338 | |
1339 | return "Any" ; |
1340 | } |
1341 | |
1342 | } // namespace tensorflow |
1343 | |