1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/op_gen_lib.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/proto_utils.h"
30 | |
31 | namespace tensorflow { |
32 | |
33 | string WordWrap(StringPiece prefix, StringPiece str, int width) { |
34 | const string indent_next_line = "\n" + Spaces(prefix.size()); |
35 | width -= prefix.size(); |
36 | string result; |
37 | strings::StrAppend(&result, prefix); |
38 | |
39 | while (!str.empty()) { |
40 | if (static_cast<int>(str.size()) <= width) { |
41 | // Remaining text fits on one line. |
42 | strings::StrAppend(&result, str); |
43 | break; |
44 | } |
45 | auto space = str.rfind(' ', width); |
46 | if (space == StringPiece::npos) { |
47 | // Rather make a too-long line and break at a space. |
48 | space = str.find(' '); |
49 | if (space == StringPiece::npos) { |
50 | strings::StrAppend(&result, str); |
51 | break; |
52 | } |
53 | } |
54 | // Breaking at character at position <space>. |
55 | StringPiece to_append = str.substr(0, space); |
56 | str.remove_prefix(space + 1); |
57 | // Remove spaces at break. |
58 | while (str_util::EndsWith(to_append, " " )) { |
59 | to_append.remove_suffix(1); |
60 | } |
61 | while (absl::ConsumePrefix(&str, " " )) { |
62 | } |
63 | |
64 | // Go on to the next line. |
65 | strings::StrAppend(&result, to_append); |
66 | if (!str.empty()) strings::StrAppend(&result, indent_next_line); |
67 | } |
68 | |
69 | return result; |
70 | } |
71 | |
72 | bool ConsumeEquals(StringPiece* description) { |
73 | if (absl::ConsumePrefix(description, "=" )) { |
74 | while (absl::ConsumePrefix(description, |
75 | " " )) { // Also remove spaces after "=". |
76 | } |
77 | return true; |
78 | } |
79 | return false; |
80 | } |
81 | |
82 | // Split `*orig` into two pieces at the first occurrence of `split_ch`. |
83 | // Returns whether `split_ch` was found. Afterwards, `*before_split` |
84 | // contains the maximum prefix of the input `*orig` that doesn't |
85 | // contain `split_ch`, and `*orig` contains everything after the |
86 | // first `split_ch`. |
87 | static bool SplitAt(char split_ch, StringPiece* orig, |
88 | StringPiece* before_split) { |
89 | auto pos = orig->find(split_ch); |
90 | if (pos == StringPiece::npos) { |
91 | *before_split = *orig; |
92 | *orig = StringPiece(); |
93 | return false; |
94 | } else { |
95 | *before_split = orig->substr(0, pos); |
96 | orig->remove_prefix(pos + 1); |
97 | return true; |
98 | } |
99 | } |
100 | |
101 | // Does this line start with "<spaces><field>:" where "<field>" is |
102 | // in multi_line_fields? Sets *colon_pos to the position of the colon. |
103 | static bool StartsWithFieldName(StringPiece line, |
104 | const std::vector<string>& multi_line_fields) { |
105 | StringPiece up_to_colon; |
106 | if (!SplitAt(':', &line, &up_to_colon)) return false; |
107 | while (absl::ConsumePrefix(&up_to_colon, " " )) |
108 | ; // Remove leading spaces. |
109 | for (const auto& field : multi_line_fields) { |
110 | if (up_to_colon == field) { |
111 | return true; |
112 | } |
113 | } |
114 | return false; |
115 | } |
116 | |
117 | static bool ConvertLine(StringPiece line, |
118 | const std::vector<string>& multi_line_fields, |
119 | string* ml) { |
120 | // Is this a field we should convert? |
121 | if (!StartsWithFieldName(line, multi_line_fields)) { |
122 | return false; |
123 | } |
124 | // Has a matching field name, so look for "..." after the colon. |
125 | StringPiece up_to_colon; |
126 | StringPiece after_colon = line; |
127 | SplitAt(':', &after_colon, &up_to_colon); |
128 | while (absl::ConsumePrefix(&after_colon, " " )) |
129 | ; // Remove leading spaces. |
130 | if (!absl::ConsumePrefix(&after_colon, "\"" )) { |
131 | // We only convert string fields, so don't convert this line. |
132 | return false; |
133 | } |
134 | auto last_quote = after_colon.rfind('\"'); |
135 | if (last_quote == StringPiece::npos) { |
136 | // Error: we don't see the expected matching quote, abort the conversion. |
137 | return false; |
138 | } |
139 | StringPiece escaped = after_colon.substr(0, last_quote); |
140 | StringPiece suffix = after_colon.substr(last_quote + 1); |
141 | // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>' |
142 | |
143 | string unescaped; |
144 | if (!absl::CUnescape(escaped, &unescaped, nullptr)) { |
145 | // Error unescaping, abort the conversion. |
146 | return false; |
147 | } |
148 | // No more errors possible at this point. |
149 | |
150 | // Find a string to mark the end that isn't in unescaped. |
151 | string end = "END" ; |
152 | for (int s = 0; unescaped.find(end) != string::npos; ++s) { |
153 | end = strings::StrCat("END" , s); |
154 | } |
155 | |
156 | // Actually start writing the converted output. |
157 | strings::StrAppend(ml, up_to_colon, ": <<" , end, "\n" , unescaped, "\n" , end); |
158 | if (!suffix.empty()) { |
159 | // Output suffix, in case there was a trailing comment in the source. |
160 | strings::StrAppend(ml, suffix); |
161 | } |
162 | strings::StrAppend(ml, "\n" ); |
163 | return true; |
164 | } |
165 | |
166 | string PBTxtToMultiline(StringPiece pbtxt, |
167 | const std::vector<string>& multi_line_fields) { |
168 | string ml; |
169 | // Probably big enough, since the input and output are about the |
170 | // same size, but just a guess. |
171 | ml.reserve(pbtxt.size() * (17. / 16)); |
172 | StringPiece line; |
173 | while (!pbtxt.empty()) { |
174 | // Split pbtxt into its first line and everything after. |
175 | SplitAt('\n', &pbtxt, &line); |
176 | // Convert line or output it unchanged |
177 | if (!ConvertLine(line, multi_line_fields, &ml)) { |
178 | strings::StrAppend(&ml, line, "\n" ); |
179 | } |
180 | } |
181 | return ml; |
182 | } |
183 | |
184 | // Given a single line of text `line` with first : at `colon`, determine if |
185 | // there is an "<<END" expression after the colon and if so return true and set |
186 | // `*end` to everything after the "<<". |
187 | static bool FindMultiline(StringPiece line, size_t colon, string* end) { |
188 | if (colon == StringPiece::npos) return false; |
189 | line.remove_prefix(colon + 1); |
190 | while (absl::ConsumePrefix(&line, " " )) { |
191 | } |
192 | if (absl::ConsumePrefix(&line, "<<" )) { |
193 | *end = string(line); |
194 | return true; |
195 | } |
196 | return false; |
197 | } |
198 | |
199 | string PBTxtFromMultiline(StringPiece multiline_pbtxt) { |
200 | string pbtxt; |
201 | // Probably big enough, since the input and output are about the |
202 | // same size, but just a guess. |
203 | pbtxt.reserve(multiline_pbtxt.size() * (33. / 32)); |
204 | StringPiece line; |
205 | while (!multiline_pbtxt.empty()) { |
206 | // Split multiline_pbtxt into its first line and everything after. |
207 | if (!SplitAt('\n', &multiline_pbtxt, &line)) { |
208 | strings::StrAppend(&pbtxt, line); |
209 | break; |
210 | } |
211 | |
212 | string end; |
213 | auto colon = line.find(':'); |
214 | if (!FindMultiline(line, colon, &end)) { |
215 | // Normal case: not a multi-line string, just output the line as-is. |
216 | strings::StrAppend(&pbtxt, line, "\n" ); |
217 | continue; |
218 | } |
219 | |
220 | // Multi-line case: |
221 | // something: <<END |
222 | // xx |
223 | // yy |
224 | // END |
225 | // Should be converted to: |
226 | // something: "xx\nyy" |
227 | |
228 | // Output everything up to the colon (" something:"). |
229 | strings::StrAppend(&pbtxt, line.substr(0, colon + 1)); |
230 | |
231 | // Add every line to unescaped until we see the "END" string. |
232 | string unescaped; |
233 | bool first = true; |
234 | while (!multiline_pbtxt.empty()) { |
235 | SplitAt('\n', &multiline_pbtxt, &line); |
236 | if (absl::ConsumePrefix(&line, end)) break; |
237 | if (first) { |
238 | first = false; |
239 | } else { |
240 | unescaped.push_back('\n'); |
241 | } |
242 | strings::StrAppend(&unescaped, line); |
243 | line = StringPiece(); |
244 | } |
245 | |
246 | // Escape what we extracted and then output it in quotes. |
247 | strings::StrAppend(&pbtxt, " \"" , absl::CEscape(unescaped), "\"" , line, |
248 | "\n" ); |
249 | } |
250 | return pbtxt; |
251 | } |
252 | |
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. `*s` is left unchanged when `from` does not occur
// in it or when `from` is empty (an empty pattern matches at every position
// and would otherwise never let the scan advance).
static void StringReplace(const std::string& from, const std::string& to,
                          std::string* s) {
  if (from.empty()) return;

  std::string replaced;
  replaced.reserve(s->size());
  std::string::size_type pos = 0;
  for (;;) {
    const std::string::size_type found = s->find(from, pos);
    if (found == std::string::npos) break;
    // Copy the text between the previous match and this one, then the
    // replacement.
    replaced.append(*s, pos, found - pos);
    replaced.append(to);
    pos = found + from.size();
  }
  // `from` is non-empty, so `pos` only stays 0 when nothing matched.
  if (pos == 0) return;
  // Copy whatever follows the last match (possibly nothing).
  replaced.append(*s, pos, std::string::npos);
  *s = std::move(replaced);
}
273 | |
274 | static void RenameInDocs(const string& from, const string& to, |
275 | ApiDef* api_def) { |
276 | const string from_quoted = strings::StrCat("`" , from, "`" ); |
277 | const string to_quoted = strings::StrCat("`" , to, "`" ); |
278 | for (int i = 0; i < api_def->in_arg_size(); ++i) { |
279 | if (!api_def->in_arg(i).description().empty()) { |
280 | StringReplace(from_quoted, to_quoted, |
281 | api_def->mutable_in_arg(i)->mutable_description()); |
282 | } |
283 | } |
284 | for (int i = 0; i < api_def->out_arg_size(); ++i) { |
285 | if (!api_def->out_arg(i).description().empty()) { |
286 | StringReplace(from_quoted, to_quoted, |
287 | api_def->mutable_out_arg(i)->mutable_description()); |
288 | } |
289 | } |
290 | for (int i = 0; i < api_def->attr_size(); ++i) { |
291 | if (!api_def->attr(i).description().empty()) { |
292 | StringReplace(from_quoted, to_quoted, |
293 | api_def->mutable_attr(i)->mutable_description()); |
294 | } |
295 | } |
296 | if (!api_def->summary().empty()) { |
297 | StringReplace(from_quoted, to_quoted, api_def->mutable_summary()); |
298 | } |
299 | if (!api_def->description().empty()) { |
300 | StringReplace(from_quoted, to_quoted, api_def->mutable_description()); |
301 | } |
302 | } |
303 | |
304 | namespace { |
305 | |
306 | // Initializes given ApiDef with data in OpDef. |
307 | void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) { |
308 | api_def->set_graph_op_name(op_def.name()); |
309 | api_def->set_visibility(ApiDef::VISIBLE); |
310 | |
311 | auto* endpoint = api_def->add_endpoint(); |
312 | endpoint->set_name(op_def.name()); |
313 | |
314 | for (const auto& op_in_arg : op_def.input_arg()) { |
315 | auto* api_in_arg = api_def->add_in_arg(); |
316 | api_in_arg->set_name(op_in_arg.name()); |
317 | api_in_arg->set_rename_to(op_in_arg.name()); |
318 | api_in_arg->set_description(op_in_arg.description()); |
319 | |
320 | *api_def->add_arg_order() = op_in_arg.name(); |
321 | } |
322 | for (const auto& op_out_arg : op_def.output_arg()) { |
323 | auto* api_out_arg = api_def->add_out_arg(); |
324 | api_out_arg->set_name(op_out_arg.name()); |
325 | api_out_arg->set_rename_to(op_out_arg.name()); |
326 | api_out_arg->set_description(op_out_arg.description()); |
327 | } |
328 | for (const auto& op_attr : op_def.attr()) { |
329 | auto* api_attr = api_def->add_attr(); |
330 | api_attr->set_name(op_attr.name()); |
331 | api_attr->set_rename_to(op_attr.name()); |
332 | if (op_attr.has_default_value()) { |
333 | *api_attr->mutable_default_value() = op_attr.default_value(); |
334 | } |
335 | api_attr->set_description(op_attr.description()); |
336 | } |
337 | api_def->set_summary(op_def.summary()); |
338 | api_def->set_description(op_def.description()); |
339 | } |
340 | |
341 | // Updates base_arg based on overrides in new_arg. |
342 | void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) { |
343 | if (!new_arg.rename_to().empty()) { |
344 | base_arg->set_rename_to(new_arg.rename_to()); |
345 | } |
346 | if (!new_arg.description().empty()) { |
347 | base_arg->set_description(new_arg.description()); |
348 | } |
349 | } |
350 | |
351 | // Updates base_attr based on overrides in new_attr. |
352 | void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) { |
353 | if (!new_attr.rename_to().empty()) { |
354 | base_attr->set_rename_to(new_attr.rename_to()); |
355 | } |
356 | if (new_attr.has_default_value()) { |
357 | *base_attr->mutable_default_value() = new_attr.default_value(); |
358 | } |
359 | if (!new_attr.description().empty()) { |
360 | base_attr->set_description(new_attr.description()); |
361 | } |
362 | } |
363 | |
364 | // Updates base_api_def based on overrides in new_api_def. |
365 | Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { |
366 | // Merge visibility |
367 | if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) { |
368 | base_api_def->set_visibility(new_api_def.visibility()); |
369 | } |
370 | // Merge endpoints |
371 | if (new_api_def.endpoint_size() > 0) { |
372 | base_api_def->clear_endpoint(); |
373 | std::copy( |
374 | new_api_def.endpoint().begin(), new_api_def.endpoint().end(), |
375 | protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint())); |
376 | } |
377 | // Merge args |
378 | for (const auto& new_arg : new_api_def.in_arg()) { |
379 | bool found_base_arg = false; |
380 | for (int i = 0; i < base_api_def->in_arg_size(); ++i) { |
381 | auto* base_arg = base_api_def->mutable_in_arg(i); |
382 | if (base_arg->name() == new_arg.name()) { |
383 | MergeArg(base_arg, new_arg); |
384 | found_base_arg = true; |
385 | break; |
386 | } |
387 | } |
388 | if (!found_base_arg) { |
389 | return errors::FailedPrecondition("Argument " , new_arg.name(), |
390 | " not defined in base api for " , |
391 | base_api_def->graph_op_name()); |
392 | } |
393 | } |
394 | for (const auto& new_arg : new_api_def.out_arg()) { |
395 | bool found_base_arg = false; |
396 | for (int i = 0; i < base_api_def->out_arg_size(); ++i) { |
397 | auto* base_arg = base_api_def->mutable_out_arg(i); |
398 | if (base_arg->name() == new_arg.name()) { |
399 | MergeArg(base_arg, new_arg); |
400 | found_base_arg = true; |
401 | break; |
402 | } |
403 | } |
404 | if (!found_base_arg) { |
405 | return errors::FailedPrecondition("Argument " , new_arg.name(), |
406 | " not defined in base api for " , |
407 | base_api_def->graph_op_name()); |
408 | } |
409 | } |
410 | // Merge arg order |
411 | if (new_api_def.arg_order_size() > 0) { |
412 | // Validate that new arg_order is correct. |
413 | if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) { |
414 | return errors::FailedPrecondition( |
415 | "Invalid number of arguments " , new_api_def.arg_order_size(), " for " , |
416 | base_api_def->graph_op_name(), |
417 | ". Expected: " , base_api_def->arg_order_size()); |
418 | } |
419 | if (!std::is_permutation(new_api_def.arg_order().begin(), |
420 | new_api_def.arg_order().end(), |
421 | base_api_def->arg_order().begin())) { |
422 | return errors::FailedPrecondition( |
423 | "Invalid arg_order: " , absl::StrJoin(new_api_def.arg_order(), ", " ), |
424 | " for " , base_api_def->graph_op_name(), |
425 | ". All elements in arg_order override must match base arg_order: " , |
426 | absl::StrJoin(base_api_def->arg_order(), ", " )); |
427 | } |
428 | |
429 | base_api_def->clear_arg_order(); |
430 | std::copy( |
431 | new_api_def.arg_order().begin(), new_api_def.arg_order().end(), |
432 | protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order())); |
433 | } |
434 | // Merge attributes |
435 | for (const auto& new_attr : new_api_def.attr()) { |
436 | bool found_base_attr = false; |
437 | for (int i = 0; i < base_api_def->attr_size(); ++i) { |
438 | auto* base_attr = base_api_def->mutable_attr(i); |
439 | if (base_attr->name() == new_attr.name()) { |
440 | MergeAttr(base_attr, new_attr); |
441 | found_base_attr = true; |
442 | break; |
443 | } |
444 | } |
445 | if (!found_base_attr) { |
446 | return errors::FailedPrecondition("Attribute " , new_attr.name(), |
447 | " not defined in base api for " , |
448 | base_api_def->graph_op_name()); |
449 | } |
450 | } |
451 | // Merge summary |
452 | if (!new_api_def.summary().empty()) { |
453 | base_api_def->set_summary(new_api_def.summary()); |
454 | } |
455 | // Merge description |
456 | auto description = new_api_def.description().empty() |
457 | ? base_api_def->description() |
458 | : new_api_def.description(); |
459 | |
460 | if (!new_api_def.description_prefix().empty()) { |
461 | description = |
462 | strings::StrCat(new_api_def.description_prefix(), "\n" , description); |
463 | } |
464 | if (!new_api_def.description_suffix().empty()) { |
465 | description = |
466 | strings::StrCat(description, "\n" , new_api_def.description_suffix()); |
467 | } |
468 | base_api_def->set_description(description); |
469 | return OkStatus(); |
470 | } |
471 | } // namespace |
472 | |
473 | ApiDefMap::ApiDefMap(const OpList& op_list) { |
474 | for (const auto& op : op_list.op()) { |
475 | ApiDef api_def; |
476 | InitApiDefFromOpDef(op, &api_def); |
477 | map_[op.name()] = api_def; |
478 | } |
479 | } |
480 | |
481 | ApiDefMap::~ApiDefMap() {} |
482 | |
483 | Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) { |
484 | for (const auto& filename : filenames) { |
485 | TF_RETURN_IF_ERROR(LoadFile(env, filename)); |
486 | } |
487 | return OkStatus(); |
488 | } |
489 | |
490 | Status ApiDefMap::LoadFile(Env* env, const string& filename) { |
491 | if (filename.empty()) return OkStatus(); |
492 | string contents; |
493 | TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); |
494 | Status status = LoadApiDef(contents); |
495 | if (!status.ok()) { |
496 | // Return failed status annotated with filename to aid in debugging. |
497 | return errors::CreateWithUpdatedMessage( |
498 | status, strings::StrCat("Error parsing ApiDef file " , filename, ": " , |
499 | status.error_message())); |
500 | } |
501 | return OkStatus(); |
502 | } |
503 | |
504 | Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { |
505 | const string contents = PBTxtFromMultiline(api_def_file_contents); |
506 | ApiDefs api_defs; |
507 | TF_RETURN_IF_ERROR( |
508 | proto_utils::ParseTextFormatFromString(contents, &api_defs)); |
509 | for (const auto& api_def : api_defs.op()) { |
510 | // Check if the op definition is loaded. If op definition is not |
511 | // loaded, then we just skip this ApiDef. |
512 | if (map_.find(api_def.graph_op_name()) != map_.end()) { |
513 | // Overwrite current api def with data in api_def. |
514 | TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def)); |
515 | } |
516 | } |
517 | return OkStatus(); |
518 | } |
519 | |
520 | void ApiDefMap::UpdateDocs() { |
521 | for (auto& name_and_api_def : map_) { |
522 | auto& api_def = name_and_api_def.second; |
523 | CHECK_GT(api_def.endpoint_size(), 0); |
524 | const string canonical_name = api_def.endpoint(0).name(); |
525 | if (api_def.graph_op_name() != canonical_name) { |
526 | RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def); |
527 | } |
528 | for (const auto& in_arg : api_def.in_arg()) { |
529 | if (in_arg.name() != in_arg.rename_to()) { |
530 | RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def); |
531 | } |
532 | } |
533 | for (const auto& out_arg : api_def.out_arg()) { |
534 | if (out_arg.name() != out_arg.rename_to()) { |
535 | RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def); |
536 | } |
537 | } |
538 | for (const auto& attr : api_def.attr()) { |
539 | if (attr.name() != attr.rename_to()) { |
540 | RenameInDocs(attr.name(), attr.rename_to(), &api_def); |
541 | } |
542 | } |
543 | } |
544 | } |
545 | |
546 | const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const { |
547 | return gtl::FindOrNull(map_, name); |
548 | } |
549 | } // namespace tensorflow |
550 | |