1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/op_gen_lib.h"
17
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
20
21#include "absl/strings/escaping.h"
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/lib/gtl/map_util.h"
25#include "tensorflow/core/lib/strings/str_util.h"
26#include "tensorflow/core/lib/strings/strcat.h"
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/protobuf.h"
29#include "tensorflow/core/util/proto/proto_utils.h"
30
31namespace tensorflow {
32
33string WordWrap(StringPiece prefix, StringPiece str, int width) {
34 const string indent_next_line = "\n" + Spaces(prefix.size());
35 width -= prefix.size();
36 string result;
37 strings::StrAppend(&result, prefix);
38
39 while (!str.empty()) {
40 if (static_cast<int>(str.size()) <= width) {
41 // Remaining text fits on one line.
42 strings::StrAppend(&result, str);
43 break;
44 }
45 auto space = str.rfind(' ', width);
46 if (space == StringPiece::npos) {
47 // Rather make a too-long line and break at a space.
48 space = str.find(' ');
49 if (space == StringPiece::npos) {
50 strings::StrAppend(&result, str);
51 break;
52 }
53 }
54 // Breaking at character at position <space>.
55 StringPiece to_append = str.substr(0, space);
56 str.remove_prefix(space + 1);
57 // Remove spaces at break.
58 while (str_util::EndsWith(to_append, " ")) {
59 to_append.remove_suffix(1);
60 }
61 while (absl::ConsumePrefix(&str, " ")) {
62 }
63
64 // Go on to the next line.
65 strings::StrAppend(&result, to_append);
66 if (!str.empty()) strings::StrAppend(&result, indent_next_line);
67 }
68
69 return result;
70}
71
72bool ConsumeEquals(StringPiece* description) {
73 if (absl::ConsumePrefix(description, "=")) {
74 while (absl::ConsumePrefix(description,
75 " ")) { // Also remove spaces after "=".
76 }
77 return true;
78 }
79 return false;
80}
81
82// Split `*orig` into two pieces at the first occurrence of `split_ch`.
83// Returns whether `split_ch` was found. Afterwards, `*before_split`
84// contains the maximum prefix of the input `*orig` that doesn't
85// contain `split_ch`, and `*orig` contains everything after the
86// first `split_ch`.
87static bool SplitAt(char split_ch, StringPiece* orig,
88 StringPiece* before_split) {
89 auto pos = orig->find(split_ch);
90 if (pos == StringPiece::npos) {
91 *before_split = *orig;
92 *orig = StringPiece();
93 return false;
94 } else {
95 *before_split = orig->substr(0, pos);
96 orig->remove_prefix(pos + 1);
97 return true;
98 }
99}
100
101// Does this line start with "<spaces><field>:" where "<field>" is
102// in multi_line_fields? Sets *colon_pos to the position of the colon.
103static bool StartsWithFieldName(StringPiece line,
104 const std::vector<string>& multi_line_fields) {
105 StringPiece up_to_colon;
106 if (!SplitAt(':', &line, &up_to_colon)) return false;
107 while (absl::ConsumePrefix(&up_to_colon, " "))
108 ; // Remove leading spaces.
109 for (const auto& field : multi_line_fields) {
110 if (up_to_colon == field) {
111 return true;
112 }
113 }
114 return false;
115}
116
117static bool ConvertLine(StringPiece line,
118 const std::vector<string>& multi_line_fields,
119 string* ml) {
120 // Is this a field we should convert?
121 if (!StartsWithFieldName(line, multi_line_fields)) {
122 return false;
123 }
124 // Has a matching field name, so look for "..." after the colon.
125 StringPiece up_to_colon;
126 StringPiece after_colon = line;
127 SplitAt(':', &after_colon, &up_to_colon);
128 while (absl::ConsumePrefix(&after_colon, " "))
129 ; // Remove leading spaces.
130 if (!absl::ConsumePrefix(&after_colon, "\"")) {
131 // We only convert string fields, so don't convert this line.
132 return false;
133 }
134 auto last_quote = after_colon.rfind('\"');
135 if (last_quote == StringPiece::npos) {
136 // Error: we don't see the expected matching quote, abort the conversion.
137 return false;
138 }
139 StringPiece escaped = after_colon.substr(0, last_quote);
140 StringPiece suffix = after_colon.substr(last_quote + 1);
141 // We've now parsed line into '<up_to_colon>: "<escaped>"<suffix>'
142
143 string unescaped;
144 if (!absl::CUnescape(escaped, &unescaped, nullptr)) {
145 // Error unescaping, abort the conversion.
146 return false;
147 }
148 // No more errors possible at this point.
149
150 // Find a string to mark the end that isn't in unescaped.
151 string end = "END";
152 for (int s = 0; unescaped.find(end) != string::npos; ++s) {
153 end = strings::StrCat("END", s);
154 }
155
156 // Actually start writing the converted output.
157 strings::StrAppend(ml, up_to_colon, ": <<", end, "\n", unescaped, "\n", end);
158 if (!suffix.empty()) {
159 // Output suffix, in case there was a trailing comment in the source.
160 strings::StrAppend(ml, suffix);
161 }
162 strings::StrAppend(ml, "\n");
163 return true;
164}
165
166string PBTxtToMultiline(StringPiece pbtxt,
167 const std::vector<string>& multi_line_fields) {
168 string ml;
169 // Probably big enough, since the input and output are about the
170 // same size, but just a guess.
171 ml.reserve(pbtxt.size() * (17. / 16));
172 StringPiece line;
173 while (!pbtxt.empty()) {
174 // Split pbtxt into its first line and everything after.
175 SplitAt('\n', &pbtxt, &line);
176 // Convert line or output it unchanged
177 if (!ConvertLine(line, multi_line_fields, &ml)) {
178 strings::StrAppend(&ml, line, "\n");
179 }
180 }
181 return ml;
182}
183
184// Given a single line of text `line` with first : at `colon`, determine if
185// there is an "<<END" expression after the colon and if so return true and set
186// `*end` to everything after the "<<".
187static bool FindMultiline(StringPiece line, size_t colon, string* end) {
188 if (colon == StringPiece::npos) return false;
189 line.remove_prefix(colon + 1);
190 while (absl::ConsumePrefix(&line, " ")) {
191 }
192 if (absl::ConsumePrefix(&line, "<<")) {
193 *end = string(line);
194 return true;
195 }
196 return false;
197}
198
199string PBTxtFromMultiline(StringPiece multiline_pbtxt) {
200 string pbtxt;
201 // Probably big enough, since the input and output are about the
202 // same size, but just a guess.
203 pbtxt.reserve(multiline_pbtxt.size() * (33. / 32));
204 StringPiece line;
205 while (!multiline_pbtxt.empty()) {
206 // Split multiline_pbtxt into its first line and everything after.
207 if (!SplitAt('\n', &multiline_pbtxt, &line)) {
208 strings::StrAppend(&pbtxt, line);
209 break;
210 }
211
212 string end;
213 auto colon = line.find(':');
214 if (!FindMultiline(line, colon, &end)) {
215 // Normal case: not a multi-line string, just output the line as-is.
216 strings::StrAppend(&pbtxt, line, "\n");
217 continue;
218 }
219
220 // Multi-line case:
221 // something: <<END
222 // xx
223 // yy
224 // END
225 // Should be converted to:
226 // something: "xx\nyy"
227
228 // Output everything up to the colon (" something:").
229 strings::StrAppend(&pbtxt, line.substr(0, colon + 1));
230
231 // Add every line to unescaped until we see the "END" string.
232 string unescaped;
233 bool first = true;
234 while (!multiline_pbtxt.empty()) {
235 SplitAt('\n', &multiline_pbtxt, &line);
236 if (absl::ConsumePrefix(&line, end)) break;
237 if (first) {
238 first = false;
239 } else {
240 unescaped.push_back('\n');
241 }
242 strings::StrAppend(&unescaped, line);
243 line = StringPiece();
244 }
245
246 // Escape what we extracted and then output it in quotes.
247 strings::StrAppend(&pbtxt, " \"", absl::CEscape(unescaped), "\"", line,
248 "\n");
249 }
250 return pbtxt;
251}
252
// Replaces every non-overlapping occurrence of `from` in `*s` with `to`,
// scanning left to right. An empty `from` leaves `*s` unchanged; without
// that guard the search position would never advance.
static void StringReplace(const string& from, const string& to, string* s) {
  if (from.empty()) return;
  string result;
  result.reserve(s->size());
  string::size_type pos = 0;
  for (auto found = s->find(from); found != string::npos;
       found = s->find(from, pos)) {
    result.append(*s, pos, found - pos);  // Text before this match.
    result.append(to);
    pos = found + from.size();
  }
  result.append(*s, pos, string::npos);  // Text after the last match.
  *s = std::move(result);
}
273
274static void RenameInDocs(const string& from, const string& to,
275 ApiDef* api_def) {
276 const string from_quoted = strings::StrCat("`", from, "`");
277 const string to_quoted = strings::StrCat("`", to, "`");
278 for (int i = 0; i < api_def->in_arg_size(); ++i) {
279 if (!api_def->in_arg(i).description().empty()) {
280 StringReplace(from_quoted, to_quoted,
281 api_def->mutable_in_arg(i)->mutable_description());
282 }
283 }
284 for (int i = 0; i < api_def->out_arg_size(); ++i) {
285 if (!api_def->out_arg(i).description().empty()) {
286 StringReplace(from_quoted, to_quoted,
287 api_def->mutable_out_arg(i)->mutable_description());
288 }
289 }
290 for (int i = 0; i < api_def->attr_size(); ++i) {
291 if (!api_def->attr(i).description().empty()) {
292 StringReplace(from_quoted, to_quoted,
293 api_def->mutable_attr(i)->mutable_description());
294 }
295 }
296 if (!api_def->summary().empty()) {
297 StringReplace(from_quoted, to_quoted, api_def->mutable_summary());
298 }
299 if (!api_def->description().empty()) {
300 StringReplace(from_quoted, to_quoted, api_def->mutable_description());
301 }
302}
303
304namespace {
305
306// Initializes given ApiDef with data in OpDef.
307void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
308 api_def->set_graph_op_name(op_def.name());
309 api_def->set_visibility(ApiDef::VISIBLE);
310
311 auto* endpoint = api_def->add_endpoint();
312 endpoint->set_name(op_def.name());
313
314 for (const auto& op_in_arg : op_def.input_arg()) {
315 auto* api_in_arg = api_def->add_in_arg();
316 api_in_arg->set_name(op_in_arg.name());
317 api_in_arg->set_rename_to(op_in_arg.name());
318 api_in_arg->set_description(op_in_arg.description());
319
320 *api_def->add_arg_order() = op_in_arg.name();
321 }
322 for (const auto& op_out_arg : op_def.output_arg()) {
323 auto* api_out_arg = api_def->add_out_arg();
324 api_out_arg->set_name(op_out_arg.name());
325 api_out_arg->set_rename_to(op_out_arg.name());
326 api_out_arg->set_description(op_out_arg.description());
327 }
328 for (const auto& op_attr : op_def.attr()) {
329 auto* api_attr = api_def->add_attr();
330 api_attr->set_name(op_attr.name());
331 api_attr->set_rename_to(op_attr.name());
332 if (op_attr.has_default_value()) {
333 *api_attr->mutable_default_value() = op_attr.default_value();
334 }
335 api_attr->set_description(op_attr.description());
336 }
337 api_def->set_summary(op_def.summary());
338 api_def->set_description(op_def.description());
339}
340
341// Updates base_arg based on overrides in new_arg.
342void MergeArg(ApiDef::Arg* base_arg, const ApiDef::Arg& new_arg) {
343 if (!new_arg.rename_to().empty()) {
344 base_arg->set_rename_to(new_arg.rename_to());
345 }
346 if (!new_arg.description().empty()) {
347 base_arg->set_description(new_arg.description());
348 }
349}
350
351// Updates base_attr based on overrides in new_attr.
352void MergeAttr(ApiDef::Attr* base_attr, const ApiDef::Attr& new_attr) {
353 if (!new_attr.rename_to().empty()) {
354 base_attr->set_rename_to(new_attr.rename_to());
355 }
356 if (new_attr.has_default_value()) {
357 *base_attr->mutable_default_value() = new_attr.default_value();
358 }
359 if (!new_attr.description().empty()) {
360 base_attr->set_description(new_attr.description());
361 }
362}
363
364// Updates base_api_def based on overrides in new_api_def.
365Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
366 // Merge visibility
367 if (new_api_def.visibility() != ApiDef::DEFAULT_VISIBILITY) {
368 base_api_def->set_visibility(new_api_def.visibility());
369 }
370 // Merge endpoints
371 if (new_api_def.endpoint_size() > 0) {
372 base_api_def->clear_endpoint();
373 std::copy(
374 new_api_def.endpoint().begin(), new_api_def.endpoint().end(),
375 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_endpoint()));
376 }
377 // Merge args
378 for (const auto& new_arg : new_api_def.in_arg()) {
379 bool found_base_arg = false;
380 for (int i = 0; i < base_api_def->in_arg_size(); ++i) {
381 auto* base_arg = base_api_def->mutable_in_arg(i);
382 if (base_arg->name() == new_arg.name()) {
383 MergeArg(base_arg, new_arg);
384 found_base_arg = true;
385 break;
386 }
387 }
388 if (!found_base_arg) {
389 return errors::FailedPrecondition("Argument ", new_arg.name(),
390 " not defined in base api for ",
391 base_api_def->graph_op_name());
392 }
393 }
394 for (const auto& new_arg : new_api_def.out_arg()) {
395 bool found_base_arg = false;
396 for (int i = 0; i < base_api_def->out_arg_size(); ++i) {
397 auto* base_arg = base_api_def->mutable_out_arg(i);
398 if (base_arg->name() == new_arg.name()) {
399 MergeArg(base_arg, new_arg);
400 found_base_arg = true;
401 break;
402 }
403 }
404 if (!found_base_arg) {
405 return errors::FailedPrecondition("Argument ", new_arg.name(),
406 " not defined in base api for ",
407 base_api_def->graph_op_name());
408 }
409 }
410 // Merge arg order
411 if (new_api_def.arg_order_size() > 0) {
412 // Validate that new arg_order is correct.
413 if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
414 return errors::FailedPrecondition(
415 "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
416 base_api_def->graph_op_name(),
417 ". Expected: ", base_api_def->arg_order_size());
418 }
419 if (!std::is_permutation(new_api_def.arg_order().begin(),
420 new_api_def.arg_order().end(),
421 base_api_def->arg_order().begin())) {
422 return errors::FailedPrecondition(
423 "Invalid arg_order: ", absl::StrJoin(new_api_def.arg_order(), ", "),
424 " for ", base_api_def->graph_op_name(),
425 ". All elements in arg_order override must match base arg_order: ",
426 absl::StrJoin(base_api_def->arg_order(), ", "));
427 }
428
429 base_api_def->clear_arg_order();
430 std::copy(
431 new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
432 protobuf::RepeatedFieldBackInserter(base_api_def->mutable_arg_order()));
433 }
434 // Merge attributes
435 for (const auto& new_attr : new_api_def.attr()) {
436 bool found_base_attr = false;
437 for (int i = 0; i < base_api_def->attr_size(); ++i) {
438 auto* base_attr = base_api_def->mutable_attr(i);
439 if (base_attr->name() == new_attr.name()) {
440 MergeAttr(base_attr, new_attr);
441 found_base_attr = true;
442 break;
443 }
444 }
445 if (!found_base_attr) {
446 return errors::FailedPrecondition("Attribute ", new_attr.name(),
447 " not defined in base api for ",
448 base_api_def->graph_op_name());
449 }
450 }
451 // Merge summary
452 if (!new_api_def.summary().empty()) {
453 base_api_def->set_summary(new_api_def.summary());
454 }
455 // Merge description
456 auto description = new_api_def.description().empty()
457 ? base_api_def->description()
458 : new_api_def.description();
459
460 if (!new_api_def.description_prefix().empty()) {
461 description =
462 strings::StrCat(new_api_def.description_prefix(), "\n", description);
463 }
464 if (!new_api_def.description_suffix().empty()) {
465 description =
466 strings::StrCat(description, "\n", new_api_def.description_suffix());
467 }
468 base_api_def->set_description(description);
469 return OkStatus();
470}
471} // namespace
472
473ApiDefMap::ApiDefMap(const OpList& op_list) {
474 for (const auto& op : op_list.op()) {
475 ApiDef api_def;
476 InitApiDefFromOpDef(op, &api_def);
477 map_[op.name()] = api_def;
478 }
479}
480
481ApiDefMap::~ApiDefMap() {}
482
483Status ApiDefMap::LoadFileList(Env* env, const std::vector<string>& filenames) {
484 for (const auto& filename : filenames) {
485 TF_RETURN_IF_ERROR(LoadFile(env, filename));
486 }
487 return OkStatus();
488}
489
490Status ApiDefMap::LoadFile(Env* env, const string& filename) {
491 if (filename.empty()) return OkStatus();
492 string contents;
493 TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents));
494 Status status = LoadApiDef(contents);
495 if (!status.ok()) {
496 // Return failed status annotated with filename to aid in debugging.
497 return errors::CreateWithUpdatedMessage(
498 status, strings::StrCat("Error parsing ApiDef file ", filename, ": ",
499 status.error_message()));
500 }
501 return OkStatus();
502}
503
504Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) {
505 const string contents = PBTxtFromMultiline(api_def_file_contents);
506 ApiDefs api_defs;
507 TF_RETURN_IF_ERROR(
508 proto_utils::ParseTextFormatFromString(contents, &api_defs));
509 for (const auto& api_def : api_defs.op()) {
510 // Check if the op definition is loaded. If op definition is not
511 // loaded, then we just skip this ApiDef.
512 if (map_.find(api_def.graph_op_name()) != map_.end()) {
513 // Overwrite current api def with data in api_def.
514 TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def));
515 }
516 }
517 return OkStatus();
518}
519
520void ApiDefMap::UpdateDocs() {
521 for (auto& name_and_api_def : map_) {
522 auto& api_def = name_and_api_def.second;
523 CHECK_GT(api_def.endpoint_size(), 0);
524 const string canonical_name = api_def.endpoint(0).name();
525 if (api_def.graph_op_name() != canonical_name) {
526 RenameInDocs(api_def.graph_op_name(), canonical_name, &api_def);
527 }
528 for (const auto& in_arg : api_def.in_arg()) {
529 if (in_arg.name() != in_arg.rename_to()) {
530 RenameInDocs(in_arg.name(), in_arg.rename_to(), &api_def);
531 }
532 }
533 for (const auto& out_arg : api_def.out_arg()) {
534 if (out_arg.name() != out_arg.rename_to()) {
535 RenameInDocs(out_arg.name(), out_arg.rename_to(), &api_def);
536 }
537 }
538 for (const auto& attr : api_def.attr()) {
539 if (attr.name() != attr.rename_to()) {
540 RenameInDocs(attr.name(), attr.rename_to(), &api_def);
541 }
542 }
543 }
544}
545
546const tensorflow::ApiDef* ApiDefMap::GetApiDef(const string& name) const {
547 return gtl::FindOrNull(map_, name);
548}
549} // namespace tensorflow
550