1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h"
17
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <set>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
21
22#include "tensorflow/core/platform/logging.h"
23#include "tensorflow/core/platform/macros.h"
24#include "tensorflow/core/platform/types.h"
25
26using ::tensorflow::protobuf::Descriptor;
27using ::tensorflow::protobuf::EnumDescriptor;
28using ::tensorflow::protobuf::FieldDescriptor;
29using ::tensorflow::protobuf::FieldOptions;
30using ::tensorflow::protobuf::FileDescriptor;
31
32namespace tensorflow {
33
34namespace {
35
// Returns the concatenation of the operator<< renderings of <args>, in order.
// Used by the generator to build code fragments out of strings, string
// literals and integers.
template <typename... Args>
string StrCat(const Args&... args) {
  std::ostringstream s;
  // Streams each argument left to right: elements of a braced initializer
  // list are evaluated in order. The leading 0 keeps the array non-empty when
  // there are no arguments. Unlike a temporary std::vector<int>, this needs
  // no heap allocation.
  const int expand[] = {0, ((void)(s << args), 0)...};
  (void)expand;
  return s.str();
}
42
// Appends the concatenation of <args> (see StrCat) to <*to_append> and
// returns a reference to the extended string.
//
// The result is returned by reference. Print() appends every generated line
// to a section's whole output buffer through this function, so returning by
// value copied the entire generated file on every call and made code
// generation quadratic in output size. Callers that want a copy can still
// write `string s = StrAppend(...)`.
template <typename... Args>
string& StrAppend(string* to_append, const Args&... args) {
  to_append->append(StrCat(args...));
  return *to_append;
}
48
49// Class used to generate the code for proto text functions. One of these should
50// be created for each FileDescriptor whose code should be generated.
51//
52// This class has a notion of the current output Section. The Print, Nested,
53// and Unnest functions apply their operations to the current output section,
54// which can be toggled with SetOutput.
55//
56// Note that on the generated code, various pieces are not optimized - for
57// example: map input and output, Cord input and output, comparisons against
58// the field names (it's a loop over all names), and tracking of has_seen.
59class Generator {
60 public:
61 explicit Generator(const string& tf_header_prefix)
62 : tf_header_prefix_(tf_header_prefix),
63 header_(&code_.header),
64 header_impl_(&code_.header_impl),
65 cc_(&code_.cc) {}
66
67 void Generate(const FileDescriptor& fd);
68
69 // The generated code; valid after Generate has been called.
70 ProtoTextFunctionCode code() const { return code_; }
71
72 private:
73 struct Section {
74 explicit Section(string* str) : str(str) {}
75 string* str;
76 string indent;
77 };
78
79 // Switches the currently active section to <section>.
80 Generator& SetOutput(Section* section) {
81 cur_ = section;
82 return *this;
83 }
84
85 // Increases indent level. Returns <*this>, to allow chaining.
86 Generator& Nest() {
87 StrAppend(&cur_->indent, " ");
88 return *this;
89 }
90
91 // Decreases indent level. Returns <*this>, to allow chaining.
92 Generator& Unnest() {
93 cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2);
94 return *this;
95 }
96
97 // Appends the concatenated args, with a trailing newline. Returns <*this>, to
98 // allow chaining.
99 template <typename... Args>
100 Generator& Print(Args... args) {
101 StrAppend(cur_->str, cur_->indent, args..., "\n");
102 return *this;
103 }
104
105 // Appends the print code for a single field's value.
106 // If <omit_default> is true, then the emitted code will not print zero-valued
107 // values.
108 // <field_expr> is code that when emitted yields the field's value.
109 void AppendFieldValueAppend(const FieldDescriptor& field,
110 const bool omit_default,
111 const string& field_expr);
112
113 // Appends the print code for as single field.
114 void AppendFieldAppend(const FieldDescriptor& field);
115
116 // Appends the print code for a message. May change which section is currently
117 // active.
118 void AppendDebugStringFunctions(const Descriptor& md);
119
120 // Appends the print and parse functions for an enum. May change which
121 // section is currently active.
122 void AppendEnumFunctions(const EnumDescriptor& enum_d);
123
124 // Appends the parse functions for a message. May change which section is
125 // currently active.
126 void AppendParseMessageFunction(const Descriptor& md);
127
128 // Appends all functions for a message and its nested message and enum types.
129 // May change which section is currently active.
130 void AppendMessageFunctions(const Descriptor& md);
131
132 // Appends lines to open or close namespace declarations.
133 void AddNamespaceToCurrentSection(const string& package, bool open);
134
135 // Appends the given headers as sorted #include lines.
136 void AddHeadersToCurrentSection(const std::vector<string>& headers);
137
138 // When adding #includes for tensorflow headers, prefix them with this.
139 const string tf_header_prefix_;
140 ProtoTextFunctionCode code_;
141 Section* cur_ = nullptr;
142 Section header_;
143 Section header_impl_;
144 Section cc_;
145
146 std::unordered_set<string> map_append_signatures_included_;
147
148 TF_DISALLOW_COPY_AND_ASSIGN(Generator);
149};
150
151// Returns the prefix needed to reference objects defined in <fd>. E.g.
152// "::tensorflow::test".
153string GetPackageReferencePrefix(const FileDescriptor* fd) {
154 string result = "::";
155 const string& package = fd->package();
156 for (size_t i = 0; i < package.size(); ++i) {
157 if (package[i] == '.') {
158 result += "::";
159 } else {
160 result += package[i];
161 }
162 }
163 result += "::";
164 return result;
165}
166
167// Returns the name of the class generated by proto to represent <d>.
168string GetClassName(const Descriptor& d) {
169 if (d.containing_type() == nullptr) return d.name();
170 return StrCat(GetClassName(*d.containing_type()), "_", d.name());
171}
172
173// Returns the name of the class generated by proto to represent <ed>.
174string GetClassName(const EnumDescriptor& ed) {
175 if (ed.containing_type() == nullptr) return ed.name();
176 return StrCat(GetClassName(*ed.containing_type()), "_", ed.name());
177}
178
179// Returns the qualified name that refers to the class generated by proto to
180// represent <d>.
181string GetQualifiedName(const Descriptor& d) {
182 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
183}
184
185// Returns the qualified name that refers to the class generated by proto to
186// represent <ed>.
187string GetQualifiedName(const EnumDescriptor& d) {
188 return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d));
189}
190
191// Returns the qualified name that refers to the generated
192// AppendProtoDebugString function for <d>.
193string GetQualifiedAppendFn(const Descriptor& d) {
194 return StrCat(GetPackageReferencePrefix(d.file()),
195 "internal::AppendProtoDebugString");
196}
197
198// Returns the name of the generated function that returns an enum value's
199// string value.
200string GetEnumNameFn(const EnumDescriptor& enum_d) {
201 return StrCat("EnumName_", GetClassName(enum_d));
202}
203
204// Returns the qualified name of the function returned by GetEnumNameFn().
205string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) {
206 return StrCat(GetPackageReferencePrefix(enum_d.file()),
207 GetEnumNameFn(enum_d));
208}
209
210// Returns the name of a generated header file, either the public api (if impl
211// is false) or the internal implementation header (if impl is true).
212string GetProtoTextHeaderName(const FileDescriptor& fd, bool impl) {
213 const int dot_index = fd.name().find_last_of('.');
214 return fd.name().substr(0, dot_index) +
215 (impl ? ".pb_text-impl.h" : ".pb_text.h");
216}
217
218// Returns the name of the header generated by the proto library for <fd>.
219string GetProtoHeaderName(const FileDescriptor& fd) {
220 const int dot_index = fd.name().find_last_of('.');
221 return fd.name().substr(0, dot_index) + ".pb.h";
222}
223
224// Returns the C++ class name for the given proto field.
225string GetCppClass(const FieldDescriptor& d) {
226 string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE
227 ? GetQualifiedName(*d.message_type())
228 : d.cpp_type_name();
229
230 // In open-source TensorFlow, the definition of int64 varies across
231 // platforms. The following line, which is manipulated during internal-
232 // external sync'ing, takes care of the variability.
233 if (cpp_class == "int64") {
234 cpp_class = kProtobufInt64Typename;
235 }
236
237 return cpp_class;
238}
239
240// Returns the string that can be used for a header guard for the generated
241// headers for <fd>, either for the public api (if impl is false) or the
242// internal implementation header (if impl is true).
243string GetHeaderGuard(const FileDescriptor& fd, bool impl) {
244 string s = fd.name();
245 std::replace(s.begin(), s.end(), '/', '_');
246 std::replace(s.begin(), s.end(), '.', '_');
247 return s + (impl ? "_IMPL_H_" : "_H_");
248}
249
250void Generator::AppendFieldValueAppend(const FieldDescriptor& field,
251 const bool omit_default,
252 const string& field_expr) {
253 SetOutput(&cc_);
254 switch (field.cpp_type()) {
255 case FieldDescriptor::CPPTYPE_INT32:
256 case FieldDescriptor::CPPTYPE_INT64:
257 case FieldDescriptor::CPPTYPE_UINT32:
258 case FieldDescriptor::CPPTYPE_UINT64:
259 case FieldDescriptor::CPPTYPE_DOUBLE:
260 case FieldDescriptor::CPPTYPE_FLOAT:
261 Print("o->", omit_default ? "AppendNumericIfNotZero" : "AppendNumeric",
262 "(\"", field.name(), "\", ", field_expr, ");");
263 break;
264 case FieldDescriptor::CPPTYPE_BOOL:
265 Print("o->", omit_default ? "AppendBoolIfTrue" : "AppendBool", "(\"",
266 field.name(), "\", ", field_expr, ");");
267 break;
268 case FieldDescriptor::CPPTYPE_STRING: {
269 const auto ctype = field.options().ctype();
270 CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING)
271 << "Unsupported ctype " << ctype;
272
273 Print("o->", omit_default ? "AppendStringIfNotEmpty" : "AppendString",
274 "(\"", field.name(), "\", ProtobufStringToString(", field_expr,
275 "));");
276 break;
277 }
278 case FieldDescriptor::CPPTYPE_ENUM:
279 if (omit_default) {
280 Print("if (", field_expr, " != 0) {").Nest();
281 }
282 Print("const char* enum_name = ",
283 GetQualifiedEnumNameFn(*field.enum_type()), "(", field_expr, ");");
284 Print("if (enum_name[0]) {").Nest();
285 Print("o->AppendEnumName(\"", field.name(), "\", enum_name);");
286 Unnest().Print("} else {").Nest();
287 Print("o->AppendNumeric(\"", field.name(), "\", ", field_expr, ");");
288 Unnest().Print("}");
289 if (omit_default) {
290 Unnest().Print("}");
291 }
292 break;
293 case FieldDescriptor::CPPTYPE_MESSAGE:
294 CHECK(!field.message_type()->options().map_entry());
295 if (omit_default) {
296 Print("if (msg.has_", field.name(), "()) {").Nest();
297 }
298 Print("o->OpenNestedMessage(\"", field.name(), "\");");
299 Print(GetQualifiedAppendFn(*field.message_type()), "(o, ", field_expr,
300 ");");
301 Print("o->CloseNestedMessage();");
302 if (omit_default) {
303 Unnest().Print("}");
304 }
305 break;
306 }
307}
308
309void Generator::AppendFieldAppend(const FieldDescriptor& field) {
310 const string& name = field.name();
311
312 if (field.is_map()) {
313 Print("{").Nest();
314 const auto& key_type = *field.message_type()->FindFieldByName("key");
315 const auto& value_type = *field.message_type()->FindFieldByName("value");
316
317 Print("std::vector<", key_type.cpp_type_name(), "> keys;");
318 Print("for (const auto& e : msg.", name, "()) keys.push_back(e.first);");
319 Print("std::stable_sort(keys.begin(), keys.end());");
320 Print("for (const auto& key : keys) {").Nest();
321 Print("o->OpenNestedMessage(\"", name, "\");");
322 AppendFieldValueAppend(key_type, false /* omit_default */, "key");
323 AppendFieldValueAppend(value_type, false /* omit_default */,
324 StrCat("msg.", name, "().at(key)"));
325 Print("o->CloseNestedMessage();");
326 Unnest().Print("}");
327
328 Unnest().Print("}");
329 } else if (field.is_repeated()) {
330 Print("for (int i = 0; i < msg.", name, "_size(); ++i) {");
331 Nest();
332 AppendFieldValueAppend(field, false /* omit_default */,
333 "msg." + name + "(i)");
334 Unnest().Print("}");
335 } else {
336 const auto* oneof = field.containing_oneof();
337 if (oneof != nullptr) {
338 string camel_name = field.camelcase_name();
339 camel_name[0] = toupper(camel_name[0]);
340 Print("if (msg.", oneof->name(), "_case() == ",
341 GetQualifiedName(*oneof->containing_type()), "::k", camel_name,
342 ") {");
343 Nest();
344 AppendFieldValueAppend(field, false /* omit_default */,
345 "msg." + name + "()");
346 Unnest();
347 Print("}");
348 } else {
349 AppendFieldValueAppend(field, true /* omit_default */,
350 "msg." + name + "()");
351 }
352 }
353}
354
355void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) {
356 const string sig = StrCat("const char* ", GetEnumNameFn(enum_d), "(\n ",
357 GetQualifiedName(enum_d), " value)");
358 SetOutput(&header_);
359 Print().Print("// Enum text output for ", string(enum_d.full_name()));
360 Print(sig, ";");
361
362 SetOutput(&cc_);
363 Print().Print(sig, " {");
364 Nest().Print("switch (value) {").Nest();
365 for (int i = 0; i < enum_d.value_count(); ++i) {
366 const auto& value = *enum_d.value(i);
367 Print("case ", value.number(), ": return \"", value.name(), "\";");
368 }
369 Print("default: return \"\";");
370 Unnest().Print("}");
371 Unnest().Print("}");
372}
373
// Emits the text-format parse functions for message <md>.
//
// For an ordinary message this emits:
//   * `bool ProtoParseFromString(const string& s, Msg* msg)`, declared in
//     header_ and defined in cc_. It clears <msg>, parses <s> via
//     ProtoParseFromScanner, and requires the scanner to reach end of input;
//   * `internal::ProtoParseFromScanner(scanner, nested, close_curly, msg)`,
//     declared in header_impl_ and defined in cc_, which does the real work.
//
// For a map-entry message (md.options().map_entry()) only the
// ProtoParseFromScanner overload is emitted. It takes a
// `::tensorflow::protobuf::Map<K, V>*`, is defined in an anonymous namespace
// inside `internal` in cc_, parses one `key`/`value` entry, and inserts it
// into the map.
//
// In the generated ProtoParseFromScanner, <nested> is true when parsing the
// body of a nested message, and <close_curly> selects '}' (true) or '>'
// (false) as the closing delimiter.
//
// Each distinct ProtoParseFromScanner signature is emitted at most once per
// .cc file (tracked in map_append_signatures_included_). Several map fields
// with the same key and value types therefore share one function.
void Generator::AppendParseMessageFunction(const Descriptor& md) {
  const bool map_append = (md.options().map_entry());
  string sig;
  if (!map_append) {
    sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ",
                 GetQualifiedName(md), "* msg)");
    SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;");

    SetOutput(&cc_);
    Print().Print(sig, " {").Nest();
    Print("msg->Clear();");
    Print("Scanner scanner(s);");
    Print("if (!internal::ProtoParseFromScanner(",
          "&scanner, false, false, msg)) return false;");
    Print("scanner.Eos();");
    Print("return scanner.GetResult();");
    Unnest().Print("}");
  }

  // Parse from scanner - the real work here.
  sig = StrCat("bool ProtoParseFromScanner(",
               "\n ::tensorflow::strings::Scanner* scanner, bool nested, "
               "bool close_curly,\n ");
  const FieldDescriptor* key_type = nullptr;
  const FieldDescriptor* value_type = nullptr;
  if (map_append) {
    key_type = md.FindFieldByName("key");
    value_type = md.FindFieldByName("value");
    StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type),
              ", ", GetCppClass(*value_type), ">* map)");
  } else {
    StrAppend(&sig, GetQualifiedName(md), "* msg)");
  }

  if (!map_append_signatures_included_.insert(sig).second) {
    // This signature was already defined in this .cc file (for map entries:
    // another map with the same key and value types). Don't define it again.
    return;
  }

  if (!map_append) {
    SetOutput(&header_impl_).Print(sig, ";");
  }

  SetOutput(&cc_);
  Print().Print("namespace internal {");
  if (map_append) {
    Print("namespace {");
  }
  Print().Print(sig, " {").Nest();
  if (map_append) {
    // The entry being parsed; inserted into the map only once both parts
    // have been seen.
    Print(GetCppClass(*key_type), " map_key;");
    Print("bool set_map_key = false;");
    Print(GetCppClass(*value_type), " map_value;");
    Print("bool set_map_value = false;");
  }
  // has_seen[i] records that singular field i was already parsed, so that a
  // repeated occurrence can be rejected.
  Print("std::vector<bool> has_seen(", md.field_count(), ", false);");
  Print("while(true) {").Nest();
  Print("ProtoSpaceAndComments(scanner);");

  // Emit success case
  Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {")
      .Nest();
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  if (map_append) {
    Print("if (!set_map_key || !set_map_value) return false;");
    Print("(*map)[map_key] = map_value;");
  }
  Print("return true;");
  Unnest().Print("}");

  // Top level: end of input is success. Otherwise read the next field name
  // and an optional ':' after it.
  Print("if (!nested && scanner->empty()) { return true; }");
  Print("scanner->RestartCapture()");
  Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)");
  Print(" .StopCapture();");
  Print("StringPiece identifier;");
  Print("if (!scanner->GetResult(nullptr, &identifier)) return false;");
  Print("bool parsed_colon = false;");
  Print("(void)parsed_colon;"); // Avoid "set but not used" compiler warning
  Print("ProtoSpaceAndComments(scanner);");
  Print("if (scanner->Peek() == ':') {");
  Nest().Print("parsed_colon = true;");
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  Unnest().Print("}");
  // One `if (identifier == "<field>")` / `else if` branch per field.
  // NOTE(review): there is no trailing `else`, so an identifier that matches
  // no field is not rejected here and the loop simply continues with the
  // next token. Confirm that this leniency is intended.
  for (int i = 0; i < md.field_count(); ++i) {
    const FieldDescriptor* field = md.field(i);
    const string& field_name = field->name();
    // Generated-code expressions for storing a parsed value:
    //   mutable_value_expr - pointer to the storage to parse into (used for
    //                        messages, maps and strings);
    //   set_value_prefix   - callee (or, for map entries, an assignment
    //                        prefix) that is followed by "(value)" for
    //                        scalars and enums.
    string mutable_value_expr;
    string set_value_prefix;
    if (map_append) {
      mutable_value_expr = StrCat("&map_", field_name);
      set_value_prefix = StrCat("map_", field_name, " = ");
    } else if (field->is_repeated()) {
      if (field->is_map()) {
        mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
        set_value_prefix =
            "UNREACHABLE"; // generator will never use this value.
      } else {
        mutable_value_expr = StrCat("msg->add_", field_name, "()");
        set_value_prefix = StrCat("msg->add_", field_name);
      }
    } else {
      mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
      set_value_prefix = StrCat("msg->set_", field_name);
    }

    Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {");
    Nest();

    if (field->is_repeated()) {
      CHECK(!map_append);

      // Check to see if this is an array assignment, like a: [1, 2, 3]
      Print("const bool is_list = (scanner->Peek() == '[');");
      Print("do {");
      // [ or , // skip
      Nest().Print("if (is_list) {");
      Nest().Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      Unnest().Print("}");
    } else if (field->containing_oneof() != nullptr) {
      CHECK(!map_append);

      // Detect duplicate oneof value.
      const string oneof_name = field->containing_oneof()->name();
      Print("if (msg->", oneof_name, "_case() != 0) return false;");
    }

    if (!field->is_repeated() && !map_append) {
      // Reject a singular (non-repeated) field that appears more than once.
      Print("if (has_seen[", i, "]) return false;");
      Print("has_seen[", i, "] = true;");
    }
    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      // A message value is wrapped in {...} or <...>; the opening character
      // decides which closing character the nested parse expects.
      Print("const char open_char = scanner->Peek();");
      Print("if (open_char != '{' && open_char != '<') return false;");
      Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      if (field->is_map()) {
        // Calls the map-entry overload emitted (in an anonymous namespace) by
        // the map entry's own call to this function.
        Print("if (!ProtoParseFromScanner(");
      } else {
        Print("if (!", GetPackageReferencePrefix(field->message_type()->file()),
              "internal::ProtoParseFromScanner(");
      }
      Print(" scanner, true, open_char == '{', ", mutable_value_expr,
            ")) return false;");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
      Print("string str_value;");
      Print(
          "if (!parsed_colon || "
          "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
      Print(" scanner, &str_value)) return false;");
      Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr,
            ");");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
      Print("StringPiece value;");
      Print(
          "if (!parsed_colon || "
          "!scanner->RestartCapture().Many("
          "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)."
          "GetResult(nullptr, &value)) return false;");
      // Generated enumerator constants are "::pkg::NAME" for top-level enums
      // and "::pkg::Enclosing_Enum_NAME" for enums nested in a message.
      const auto* enum_d = field->enum_type();
      string value_prefix;
      if (enum_d->containing_type() == nullptr) {
        value_prefix = GetPackageReferencePrefix(enum_d->file());
      } else {
        value_prefix = StrCat(GetQualifiedName(*enum_d), "_");
      }

      for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
        const auto* value_d = enum_d->value(enum_i);
        const string& value_name = value_d->name();
        string condition = StrCat("value == \"", value_name, "\"");

        Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
        Nest();
        Print(set_value_prefix, "(", value_prefix, value_name, ");");
        Unnest();
      }
      Print("} else {");
      Nest();
      // Proto3 allows all numeric values.
      Print("int32 int_value;");
      Print("if (strings::SafeStringToNumeric(value, &int_value)) {");
      Nest();
      Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d),
            ">(int_value));");
      Unnest();
      Print("} else {").Nest().Print("return false;").Unnest().Print("}");
      Unnest().Print("}");
    } else {
      // Numeric and bool scalars.
      Print(field->cpp_type_name(), " value;");
      switch (field->cpp_type()) {
        case FieldDescriptor::CPPTYPE_INT32:
        case FieldDescriptor::CPPTYPE_INT64:
        case FieldDescriptor::CPPTYPE_UINT32:
        case FieldDescriptor::CPPTYPE_UINT64:
        case FieldDescriptor::CPPTYPE_DOUBLE:
        case FieldDescriptor::CPPTYPE_FLOAT:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseNumericFromScanner(",
              "scanner, &value)) return false;");
          break;
        case FieldDescriptor::CPPTYPE_BOOL:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseBoolFromScanner(",
              "scanner, &value)) return false;");
          break;
        default:
          LOG(FATAL) << "handled earlier";
      }
      Print(set_value_prefix, "(value);");
    }

    if (field->is_repeated()) {
      // Close the do { ... } opened above; list elements are ','-separated
      // and the list must end with ']'.
      Unnest().Print("} while (is_list && scanner->Peek() == ',');");
      Print(
          "if (is_list && "
          "!scanner->OneLiteral(\"]\").GetResult()) return false;");
    }
    if (map_append) {
      Print("set_map_", field_name, " = true;");
    }
    Unnest().Print("}");
  }
  // Close the while loop and the function body.
  Unnest().Print("}");
  Unnest().Print("}");
  Unnest().Print();
  if (map_append) {
    Print("} // namespace");
  }
  Print("} // namespace internal");
}
611
612void Generator::AppendDebugStringFunctions(const Descriptor& md) {
613 SetOutput(&header_impl_).Print();
614 SetOutput(&header_).Print().Print("// Message-text conversion for ",
615 string(md.full_name()));
616
617 // Append the two debug string functions for <md>.
618 for (int short_pass = 0; short_pass < 2; ++short_pass) {
619 const bool short_debug = (short_pass == 1);
620
621 // Make the Get functions.
622 const string sig = StrCat(
623 "string ", short_debug ? "ProtoShortDebugString" : "ProtoDebugString",
624 "(\n const ", GetQualifiedName(md), "& msg)");
625 SetOutput(&header_).Print(sig, ";");
626
627 SetOutput(&cc_);
628 Print().Print(sig, " {").Nest();
629 Print("string s;");
630 Print("::tensorflow::strings::ProtoTextOutput o(&s, ",
631 short_debug ? "true" : "false", ");");
632 Print("internal::AppendProtoDebugString(&o, msg);");
633 Print("o.CloseTopMessage();");
634 Print("return s;");
635 Unnest().Print("}");
636 }
637
638 // Make the Append function.
639 const string sig =
640 StrCat("void AppendProtoDebugString(\n",
641 " ::tensorflow::strings::ProtoTextOutput* o,\n const ",
642 GetQualifiedName(md), "& msg)");
643 SetOutput(&header_impl_).Print(sig, ";");
644 SetOutput(&cc_);
645 Print().Print("namespace internal {").Print();
646 Print(sig, " {").Nest();
647 std::vector<const FieldDescriptor*> fields;
648 fields.reserve(md.field_count());
649 for (int i = 0; i < md.field_count(); ++i) {
650 fields.push_back(md.field(i));
651 }
652 std::sort(fields.begin(), fields.end(),
653 [](const FieldDescriptor* left, const FieldDescriptor* right) {
654 return left->number() < right->number();
655 });
656
657 for (const FieldDescriptor* field : fields) {
658 SetOutput(&cc_);
659 AppendFieldAppend(*field);
660 }
661 Unnest().Print("}").Print().Print("} // namespace internal");
662}
663
664void Generator::AppendMessageFunctions(const Descriptor& md) {
665 if (md.options().map_entry()) {
666 // The 'map entry' Message is not a user-visible message type. Only its
667 // parse function is created (and that actually parsed the whole Map, not
668 // just the map entry). Printing of a map is done in the code generated for
669 // the containing message.
670 AppendParseMessageFunction(md);
671 return;
672 }
673
674 // Recurse before adding the main message function, so that internal
675 // map_append functions are available before they are needed.
676 for (int i = 0; i < md.enum_type_count(); ++i) {
677 AppendEnumFunctions(*md.enum_type(i));
678 }
679 for (int i = 0; i < md.nested_type_count(); ++i) {
680 AppendMessageFunctions(*md.nested_type(i));
681 }
682
683 AppendDebugStringFunctions(md);
684 AppendParseMessageFunction(md);
685}
686
687void Generator::AddNamespaceToCurrentSection(const string& package, bool open) {
688 Print();
689 std::vector<string> parts = {""};
690 for (size_t i = 0; i < package.size(); ++i) {
691 if (package[i] == '.') {
692 parts.resize(parts.size() + 1);
693 } else {
694 parts.back() += package[i];
695 }
696 }
697 if (open) {
698 for (const auto& p : parts) {
699 Print("namespace ", p, " {");
700 }
701 } else {
702 for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
703 Print("} // namespace ", *it);
704 }
705 }
706}
707
708void Generator::AddHeadersToCurrentSection(const std::vector<string>& headers) {
709 std::vector<string> sorted = headers;
710 std::sort(sorted.begin(), sorted.end());
711 for (const auto& h : sorted) {
712 Print("#include \"", h, "\"");
713 }
714}
715
// Adds <fd> to <all_fd> and, if it was not already present, adds every file
// and message reachable from fd's top-level messages (through their fields and
// nested types) to <all_fd> and <all_d>. Declared ahead of its definition
// because it and GetAllFileDescriptorsFromMessage call each other.
void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
                                   std::set<const FileDescriptor*>* all_fd,
                                   std::set<const Descriptor*>* all_d);
721
722// Adds to <all_fd> and <all_d> with all descriptors recursively
723// reachable from the given descriptor.
724void GetAllFileDescriptorsFromMessage(const Descriptor* d,
725 std::set<const FileDescriptor*>* all_fd,
726 std::set<const Descriptor*>* all_d) {
727 if (!all_d->insert(d).second) return;
728 GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d);
729 for (int i = 0; i < d->field_count(); ++i) {
730 auto* f = d->field(i);
731 switch (f->cpp_type()) {
732 case FieldDescriptor::CPPTYPE_INT32:
733 case FieldDescriptor::CPPTYPE_INT64:
734 case FieldDescriptor::CPPTYPE_UINT32:
735 case FieldDescriptor::CPPTYPE_UINT64:
736 case FieldDescriptor::CPPTYPE_DOUBLE:
737 case FieldDescriptor::CPPTYPE_FLOAT:
738 case FieldDescriptor::CPPTYPE_BOOL:
739 case FieldDescriptor::CPPTYPE_STRING:
740 break;
741 case FieldDescriptor::CPPTYPE_MESSAGE:
742 GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d);
743 break;
744 case FieldDescriptor::CPPTYPE_ENUM:
745 GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d);
746 break;
747 }
748 }
749 for (int i = 0; i < d->nested_type_count(); ++i) {
750 GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d);
751 }
752}
753
754void GetAllFileDescriptorsFromFile(const FileDescriptor* fd,
755 std::set<const FileDescriptor*>* all_fd,
756 std::set<const Descriptor*>* all_d) {
757 if (!all_fd->insert(fd).second) return;
758 for (int i = 0; i < fd->message_type_count(); ++i) {
759 GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d);
760 }
761}
762
// Generates all code for <fd> into code_, filling three sections:
//   * header_: public declarations (ProtoDebugString, ProtoShortDebugString,
//     ProtoParseFromString, EnumName_*) inside fd's package namespaces;
//   * header_impl_: internal:: declarations, plus #includes of the proto and
//     pb_text impl headers of every file reachable from <fd>;
//   * cc_: all definitions.
// CHECK-fails unless <fd> uses proto3 syntax. The order of the calls below is
// the order of the emitted text.
void Generator::Generate(const FileDescriptor& fd) {
  // This does not emit code with proper proto2 semantics (e.g. it doesn't check
  // 'has' fields on non-messages), so check that only proto3 is passed.
  CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name();

  const string package = fd.package();
  // Every file (and message) transitively reachable from <fd>. The files are
  // used below to #include their headers from the impl header.
  std::set<const FileDescriptor*> all_fd;
  std::set<const Descriptor*> all_d;
  GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d);

  std::vector<string> headers;

  // Add header to header file.
  SetOutput(&header_);
  Print("// GENERATED FILE - DO NOT MODIFY");
  Print("#ifndef ", GetHeaderGuard(fd, false /* impl */));
  Print("#define ", GetHeaderGuard(fd, false /* impl */));
  Print();
  headers = {
      GetProtoHeaderName(fd),
      StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h"),
      StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h"),
      StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h"),
  };
  // Unlike the other sections, these includes are emitted in list order, not
  // sorted.
  for (const auto& h : headers) {
    Print("#include \"", h, "\"");
  }
  AddNamespaceToCurrentSection(package, true /* is_open */);

  // Add header to impl file.
  SetOutput(&header_impl_);
  Print("// GENERATED FILE - DO NOT MODIFY");
  Print("#ifndef ", GetHeaderGuard(fd, true /* impl */));
  Print("#define ", GetHeaderGuard(fd, true /* impl */));
  Print();
  headers = {
      GetProtoTextHeaderName(fd, false /* impl */),
      StrCat(tf_header_prefix_,
             "tensorflow/core/lib/strings/proto_text_util.h"),
      StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h"),
  };
  // For every reachable file: its proto header and, for files other than
  // <fd> itself, its pb_text impl header. The impl header declares the
  // internal:: functions that the generated code calls for fields of that
  // file's types.
  for (const FileDescriptor* d : all_fd) {
    if (d != &fd) {
      headers.push_back(GetProtoTextHeaderName(*d, true /* impl */));
    }
    headers.push_back(GetProtoHeaderName(*d));
  }
  AddHeadersToCurrentSection(headers);
  AddNamespaceToCurrentSection(package, true /* is_open */);
  SetOutput(&header_impl_).Print().Print("namespace internal {");

  // Add header to cc file.
  SetOutput(&cc_);
  Print("// GENERATED FILE - DO NOT MODIFY");
  Print();
  Print("#include <algorithm>"); // for `std::stable_sort()`
  Print();
  headers = {GetProtoTextHeaderName(fd, true /* impl */)};
  AddHeadersToCurrentSection(headers);
  Print();
  Print("using ::tensorflow::strings::ProtoSpaceAndComments;");
  Print("using ::tensorflow::strings::Scanner;");
  Print("using ::tensorflow::strings::StrCat;");
  AddNamespaceToCurrentSection(package, true /* is_open */);

  // Add declarations and definitions.
  for (int i = 0; i < fd.enum_type_count(); ++i) {
    AppendEnumFunctions(*fd.enum_type(i));
  }
  for (int i = 0; i < fd.message_type_count(); ++i) {
    AppendMessageFunctions(*fd.message_type(i));
  }

  // Add footer to header file.
  SetOutput(&header_);
  AddNamespaceToCurrentSection(package, false /* is_open */);
  Print().Print("#endif // ", GetHeaderGuard(fd, false /* impl */));

  // Add footer to header impl file.
  SetOutput(&header_impl_).Print().Print("} // namespace internal");
  AddNamespaceToCurrentSection(package, false /* is_open */);
  Print().Print("#endif // ", GetHeaderGuard(fd, true /* impl */));

  // Add footer to cc file.
  SetOutput(&cc_);
  AddNamespaceToCurrentSection(package, false /* is_open */);
}
850
851} // namespace
852
853ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd,
854 const string& tf_header_prefix) {
855 Generator gen(tf_header_prefix);
856 gen.Generate(fd);
857 return gen.code();
858}
859
860} // namespace tensorflow
861