1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" |
17 | |
18 | #include <algorithm> |
19 | #include <set> |
20 | #include <unordered_set> |
21 | |
22 | #include "tensorflow/core/platform/logging.h" |
23 | #include "tensorflow/core/platform/macros.h" |
24 | #include "tensorflow/core/platform/types.h" |
25 | |
26 | using ::tensorflow::protobuf::Descriptor; |
27 | using ::tensorflow::protobuf::EnumDescriptor; |
28 | using ::tensorflow::protobuf::FieldDescriptor; |
29 | using ::tensorflow::protobuf::FieldOptions; |
30 | using ::tensorflow::protobuf::FileDescriptor; |
31 | |
32 | namespace tensorflow { |
33 | |
34 | namespace { |
35 | |
// Streams every argument, in order, into one buffer and returns the
// accumulated text. Any type with an operator<< for std::ostream is accepted.
template <typename... Args>
string StrCat(const Args&... args) {
  std::ostringstream buffer;
  // A braced initializer evaluates its elements left to right, so the
  // arguments are written in call order. The leading 0 keeps the array
  // non-empty when the pack is empty.
  const int expand[] = {0, ((buffer << args), 0)...};
  (void)expand;
  return buffer.str();
}
42 | |
// Appends the concatenation of <args> to <*to_append> and returns a copy of
// the resulting string.
template <typename... Args>
string StrAppend(string* to_append, const Args&... args) {
  to_append->append(StrCat(args...));
  return *to_append;
}
48 | |
49 | // Class used to generate the code for proto text functions. One of these should |
50 | // be created for each FileDescriptor whose code should be generated. |
51 | // |
52 | // This class has a notion of the current output Section. The Print, Nested, |
53 | // and Unnest functions apply their operations to the current output section, |
54 | // which can be toggled with SetOutput. |
55 | // |
56 | // Note that on the generated code, various pieces are not optimized - for |
57 | // example: map input and output, Cord input and output, comparisons against |
58 | // the field names (it's a loop over all names), and tracking of has_seen. |
59 | class Generator { |
60 | public: |
61 | explicit Generator(const string& ) |
62 | : tf_header_prefix_(tf_header_prefix), |
63 | header_(&code_.header), |
64 | header_impl_(&code_.header_impl), |
65 | cc_(&code_.cc) {} |
66 | |
67 | void Generate(const FileDescriptor& fd); |
68 | |
69 | // The generated code; valid after Generate has been called. |
70 | ProtoTextFunctionCode code() const { return code_; } |
71 | |
72 | private: |
73 | struct Section { |
74 | explicit Section(string* str) : str(str) {} |
75 | string* str; |
76 | string indent; |
77 | }; |
78 | |
79 | // Switches the currently active section to <section>. |
80 | Generator& SetOutput(Section* section) { |
81 | cur_ = section; |
82 | return *this; |
83 | } |
84 | |
85 | // Increases indent level. Returns <*this>, to allow chaining. |
86 | Generator& Nest() { |
87 | StrAppend(&cur_->indent, " " ); |
88 | return *this; |
89 | } |
90 | |
91 | // Decreases indent level. Returns <*this>, to allow chaining. |
92 | Generator& Unnest() { |
93 | cur_->indent = cur_->indent.substr(0, cur_->indent.size() - 2); |
94 | return *this; |
95 | } |
96 | |
97 | // Appends the concatenated args, with a trailing newline. Returns <*this>, to |
98 | // allow chaining. |
99 | template <typename... Args> |
100 | Generator& Print(Args... args) { |
101 | StrAppend(cur_->str, cur_->indent, args..., "\n" ); |
102 | return *this; |
103 | } |
104 | |
105 | // Appends the print code for a single field's value. |
106 | // If <omit_default> is true, then the emitted code will not print zero-valued |
107 | // values. |
108 | // <field_expr> is code that when emitted yields the field's value. |
109 | void AppendFieldValueAppend(const FieldDescriptor& field, |
110 | const bool omit_default, |
111 | const string& field_expr); |
112 | |
113 | // Appends the print code for as single field. |
114 | void AppendFieldAppend(const FieldDescriptor& field); |
115 | |
116 | // Appends the print code for a message. May change which section is currently |
117 | // active. |
118 | void AppendDebugStringFunctions(const Descriptor& md); |
119 | |
120 | // Appends the print and parse functions for an enum. May change which |
121 | // section is currently active. |
122 | void AppendEnumFunctions(const EnumDescriptor& enum_d); |
123 | |
124 | // Appends the parse functions for a message. May change which section is |
125 | // currently active. |
126 | void AppendParseMessageFunction(const Descriptor& md); |
127 | |
128 | // Appends all functions for a message and its nested message and enum types. |
129 | // May change which section is currently active. |
130 | void AppendMessageFunctions(const Descriptor& md); |
131 | |
132 | // Appends lines to open or close namespace declarations. |
133 | void AddNamespaceToCurrentSection(const string& package, bool open); |
134 | |
135 | // Appends the given headers as sorted #include lines. |
136 | void AddHeadersToCurrentSection(const std::vector<string>& ); |
137 | |
138 | // When adding #includes for tensorflow headers, prefix them with this. |
139 | const string ; |
140 | ProtoTextFunctionCode code_; |
141 | Section* cur_ = nullptr; |
142 | Section ; |
143 | Section ; |
144 | Section cc_; |
145 | |
146 | std::unordered_set<string> map_append_signatures_included_; |
147 | |
148 | TF_DISALLOW_COPY_AND_ASSIGN(Generator); |
149 | }; |
150 | |
151 | // Returns the prefix needed to reference objects defined in <fd>. E.g. |
152 | // "::tensorflow::test". |
153 | string GetPackageReferencePrefix(const FileDescriptor* fd) { |
154 | string result = "::" ; |
155 | const string& package = fd->package(); |
156 | for (size_t i = 0; i < package.size(); ++i) { |
157 | if (package[i] == '.') { |
158 | result += "::" ; |
159 | } else { |
160 | result += package[i]; |
161 | } |
162 | } |
163 | result += "::" ; |
164 | return result; |
165 | } |
166 | |
167 | // Returns the name of the class generated by proto to represent <d>. |
168 | string GetClassName(const Descriptor& d) { |
169 | if (d.containing_type() == nullptr) return d.name(); |
170 | return StrCat(GetClassName(*d.containing_type()), "_" , d.name()); |
171 | } |
172 | |
173 | // Returns the name of the class generated by proto to represent <ed>. |
174 | string GetClassName(const EnumDescriptor& ed) { |
175 | if (ed.containing_type() == nullptr) return ed.name(); |
176 | return StrCat(GetClassName(*ed.containing_type()), "_" , ed.name()); |
177 | } |
178 | |
179 | // Returns the qualified name that refers to the class generated by proto to |
180 | // represent <d>. |
181 | string GetQualifiedName(const Descriptor& d) { |
182 | return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); |
183 | } |
184 | |
185 | // Returns the qualified name that refers to the class generated by proto to |
186 | // represent <ed>. |
187 | string GetQualifiedName(const EnumDescriptor& d) { |
188 | return StrCat(GetPackageReferencePrefix(d.file()), GetClassName(d)); |
189 | } |
190 | |
191 | // Returns the qualified name that refers to the generated |
192 | // AppendProtoDebugString function for <d>. |
193 | string GetQualifiedAppendFn(const Descriptor& d) { |
194 | return StrCat(GetPackageReferencePrefix(d.file()), |
195 | "internal::AppendProtoDebugString" ); |
196 | } |
197 | |
198 | // Returns the name of the generated function that returns an enum value's |
199 | // string value. |
200 | string GetEnumNameFn(const EnumDescriptor& enum_d) { |
201 | return StrCat("EnumName_" , GetClassName(enum_d)); |
202 | } |
203 | |
204 | // Returns the qualified name of the function returned by GetEnumNameFn(). |
205 | string GetQualifiedEnumNameFn(const EnumDescriptor& enum_d) { |
206 | return StrCat(GetPackageReferencePrefix(enum_d.file()), |
207 | GetEnumNameFn(enum_d)); |
208 | } |
209 | |
210 | // Returns the name of a generated header file, either the public api (if impl |
211 | // is false) or the internal implementation header (if impl is true). |
212 | string (const FileDescriptor& fd, bool impl) { |
213 | const int dot_index = fd.name().find_last_of('.'); |
214 | return fd.name().substr(0, dot_index) + |
215 | (impl ? ".pb_text-impl.h" : ".pb_text.h" ); |
216 | } |
217 | |
218 | // Returns the name of the header generated by the proto library for <fd>. |
219 | string (const FileDescriptor& fd) { |
220 | const int dot_index = fd.name().find_last_of('.'); |
221 | return fd.name().substr(0, dot_index) + ".pb.h" ; |
222 | } |
223 | |
224 | // Returns the C++ class name for the given proto field. |
225 | string GetCppClass(const FieldDescriptor& d) { |
226 | string cpp_class = d.cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE |
227 | ? GetQualifiedName(*d.message_type()) |
228 | : d.cpp_type_name(); |
229 | |
230 | // In open-source TensorFlow, the definition of int64 varies across |
231 | // platforms. The following line, which is manipulated during internal- |
232 | // external sync'ing, takes care of the variability. |
233 | if (cpp_class == "int64" ) { |
234 | cpp_class = kProtobufInt64Typename; |
235 | } |
236 | |
237 | return cpp_class; |
238 | } |
239 | |
240 | // Returns the string that can be used for a header guard for the generated |
241 | // headers for <fd>, either for the public api (if impl is false) or the |
242 | // internal implementation header (if impl is true). |
243 | string (const FileDescriptor& fd, bool impl) { |
244 | string s = fd.name(); |
245 | std::replace(s.begin(), s.end(), '/', '_'); |
246 | std::replace(s.begin(), s.end(), '.', '_'); |
247 | return s + (impl ? "_IMPL_H_" : "_H_" ); |
248 | } |
249 | |
250 | void Generator::AppendFieldValueAppend(const FieldDescriptor& field, |
251 | const bool omit_default, |
252 | const string& field_expr) { |
253 | SetOutput(&cc_); |
254 | switch (field.cpp_type()) { |
255 | case FieldDescriptor::CPPTYPE_INT32: |
256 | case FieldDescriptor::CPPTYPE_INT64: |
257 | case FieldDescriptor::CPPTYPE_UINT32: |
258 | case FieldDescriptor::CPPTYPE_UINT64: |
259 | case FieldDescriptor::CPPTYPE_DOUBLE: |
260 | case FieldDescriptor::CPPTYPE_FLOAT: |
261 | Print("o->" , omit_default ? "AppendNumericIfNotZero" : "AppendNumeric" , |
262 | "(\"" , field.name(), "\", " , field_expr, ");" ); |
263 | break; |
264 | case FieldDescriptor::CPPTYPE_BOOL: |
265 | Print("o->" , omit_default ? "AppendBoolIfTrue" : "AppendBool" , "(\"" , |
266 | field.name(), "\", " , field_expr, ");" ); |
267 | break; |
268 | case FieldDescriptor::CPPTYPE_STRING: { |
269 | const auto ctype = field.options().ctype(); |
270 | CHECK(ctype == FieldOptions::CORD || ctype == FieldOptions::STRING) |
271 | << "Unsupported ctype " << ctype; |
272 | |
273 | Print("o->" , omit_default ? "AppendStringIfNotEmpty" : "AppendString" , |
274 | "(\"" , field.name(), "\", ProtobufStringToString(" , field_expr, |
275 | "));" ); |
276 | break; |
277 | } |
278 | case FieldDescriptor::CPPTYPE_ENUM: |
279 | if (omit_default) { |
280 | Print("if (" , field_expr, " != 0) {" ).Nest(); |
281 | } |
282 | Print("const char* enum_name = " , |
283 | GetQualifiedEnumNameFn(*field.enum_type()), "(" , field_expr, ");" ); |
284 | Print("if (enum_name[0]) {" ).Nest(); |
285 | Print("o->AppendEnumName(\"" , field.name(), "\", enum_name);" ); |
286 | Unnest().Print("} else {" ).Nest(); |
287 | Print("o->AppendNumeric(\"" , field.name(), "\", " , field_expr, ");" ); |
288 | Unnest().Print("}" ); |
289 | if (omit_default) { |
290 | Unnest().Print("}" ); |
291 | } |
292 | break; |
293 | case FieldDescriptor::CPPTYPE_MESSAGE: |
294 | CHECK(!field.message_type()->options().map_entry()); |
295 | if (omit_default) { |
296 | Print("if (msg.has_" , field.name(), "()) {" ).Nest(); |
297 | } |
298 | Print("o->OpenNestedMessage(\"" , field.name(), "\");" ); |
299 | Print(GetQualifiedAppendFn(*field.message_type()), "(o, " , field_expr, |
300 | ");" ); |
301 | Print("o->CloseNestedMessage();" ); |
302 | if (omit_default) { |
303 | Unnest().Print("}" ); |
304 | } |
305 | break; |
306 | } |
307 | } |
308 | |
309 | void Generator::AppendFieldAppend(const FieldDescriptor& field) { |
310 | const string& name = field.name(); |
311 | |
312 | if (field.is_map()) { |
313 | Print("{" ).Nest(); |
314 | const auto& key_type = *field.message_type()->FindFieldByName("key" ); |
315 | const auto& value_type = *field.message_type()->FindFieldByName("value" ); |
316 | |
317 | Print("std::vector<" , key_type.cpp_type_name(), "> keys;" ); |
318 | Print("for (const auto& e : msg." , name, "()) keys.push_back(e.first);" ); |
319 | Print("std::stable_sort(keys.begin(), keys.end());" ); |
320 | Print("for (const auto& key : keys) {" ).Nest(); |
321 | Print("o->OpenNestedMessage(\"" , name, "\");" ); |
322 | AppendFieldValueAppend(key_type, false /* omit_default */, "key" ); |
323 | AppendFieldValueAppend(value_type, false /* omit_default */, |
324 | StrCat("msg." , name, "().at(key)" )); |
325 | Print("o->CloseNestedMessage();" ); |
326 | Unnest().Print("}" ); |
327 | |
328 | Unnest().Print("}" ); |
329 | } else if (field.is_repeated()) { |
330 | Print("for (int i = 0; i < msg." , name, "_size(); ++i) {" ); |
331 | Nest(); |
332 | AppendFieldValueAppend(field, false /* omit_default */, |
333 | "msg." + name + "(i)" ); |
334 | Unnest().Print("}" ); |
335 | } else { |
336 | const auto* oneof = field.containing_oneof(); |
337 | if (oneof != nullptr) { |
338 | string camel_name = field.camelcase_name(); |
339 | camel_name[0] = toupper(camel_name[0]); |
340 | Print("if (msg." , oneof->name(), "_case() == " , |
341 | GetQualifiedName(*oneof->containing_type()), "::k" , camel_name, |
342 | ") {" ); |
343 | Nest(); |
344 | AppendFieldValueAppend(field, false /* omit_default */, |
345 | "msg." + name + "()" ); |
346 | Unnest(); |
347 | Print("}" ); |
348 | } else { |
349 | AppendFieldValueAppend(field, true /* omit_default */, |
350 | "msg." + name + "()" ); |
351 | } |
352 | } |
353 | } |
354 | |
355 | void Generator::AppendEnumFunctions(const EnumDescriptor& enum_d) { |
356 | const string sig = StrCat("const char* " , GetEnumNameFn(enum_d), "(\n " , |
357 | GetQualifiedName(enum_d), " value)" ); |
358 | SetOutput(&header_); |
359 | Print().Print("// Enum text output for " , string(enum_d.full_name())); |
360 | Print(sig, ";" ); |
361 | |
362 | SetOutput(&cc_); |
363 | Print().Print(sig, " {" ); |
364 | Nest().Print("switch (value) {" ).Nest(); |
365 | for (int i = 0; i < enum_d.value_count(); ++i) { |
366 | const auto& value = *enum_d.value(i); |
367 | Print("case " , value.number(), ": return \"" , value.name(), "\";" ); |
368 | } |
369 | Print("default: return \"\";" ); |
370 | Unnest().Print("}" ); |
371 | Unnest().Print("}" ); |
372 | } |
373 | |
// Emits the text-format parse functions for <md>.
//
// For an ordinary message this produces:
//   * ProtoParseFromString(s, msg): declared in the public header, defined in
//     the .cc file; it forwards to internal::ProtoParseFromScanner.
//   * internal::ProtoParseFromScanner(scanner, nested, close_curly, msg):
//     declared in the implementation header, defined in the .cc file.
//
// For a map-entry message (md.options().map_entry() is true) only a
// ProtoParseFromScanner overload is produced. It takes a
// ::tensorflow::protobuf::Map<key, value>* instead of a message, is placed in
// an unnamed namespace inside `internal`, and is emitted at most once per
// distinct signature.
//
// May change which section is currently active.
void Generator::AppendParseMessageFunction(const Descriptor& md) {
  const bool map_append = (md.options().map_entry());
  string sig;
  if (!map_append) {
    // Public entry point: parse a whole string into <msg>.
    sig = StrCat("bool ProtoParseFromString(\n const string& s,\n ",
                 GetQualifiedName(md), "* msg)");
    SetOutput(&header_).Print(sig, "\n TF_MUST_USE_RESULT;");

    SetOutput(&cc_);
    Print().Print(sig, " {").Nest();
    Print("msg->Clear();");
    Print("Scanner scanner(s);");
    Print("if (!internal::ProtoParseFromScanner(",
          "&scanner, false, false, msg)) return false;");
    Print("scanner.Eos();");
    Print("return scanner.GetResult();");
    Unnest().Print("}");
  }

  // Parse from scanner - the real work here.
  sig = StrCat("bool ProtoParseFromScanner(",
               "\n ::tensorflow::strings::Scanner* scanner, bool nested, "
               "bool close_curly,\n ");
  const FieldDescriptor* key_type = nullptr;
  const FieldDescriptor* value_type = nullptr;
  if (map_append) {
    // The last parameter is the map being filled rather than a message.
    key_type = md.FindFieldByName("key");
    value_type = md.FindFieldByName("value");
    StrAppend(&sig, "::tensorflow::protobuf::Map<", GetCppClass(*key_type),
              ", ", GetCppClass(*value_type), ">* map)");
  } else {
    StrAppend(&sig, GetQualifiedName(md), "* msg)");
  }

  if (!map_append_signatures_included_.insert(sig).second) {
    // signature for function to append to a map of this type has
    // already been defined in this .cc file. Don't define it again.
    return;
  }

  if (!map_append) {
    SetOutput(&header_impl_).Print(sig, ";");
  }

  SetOutput(&cc_);
  Print().Print("namespace internal {");
  if (map_append) {
    // Map helpers get internal linkage in the generated file.
    Print("namespace {");
  }
  Print().Print(sig, " {").Nest();
  if (map_append) {
    // Key and value are accumulated in locals and inserted into the map only
    // once the closing delimiter of the entry is reached.
    Print(GetCppClass(*key_type), " map_key;");
    Print("bool set_map_key = false;");
    Print(GetCppClass(*value_type), " map_value;");
    Print("bool set_map_value = false;");
  }
  // One flag per field, indexed by the field's position in <md>.
  Print("std::vector<bool> has_seen(", md.field_count(), ", false);");
  Print("while(true) {").Nest();
  Print("ProtoSpaceAndComments(scanner);");

  // Emit success case
  Print("if (nested && (scanner->Peek() == (close_curly ? '}' : '>'))) {")
      .Nest();
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  if (map_append) {
    Print("if (!set_map_key || !set_map_value) return false;");
    Print("(*map)[map_key] = map_value;");
  }
  Print("return true;");
  Unnest().Print("}");

  // Top-level parse ends successfully at end of input; otherwise read the next
  // field identifier and an optional ':' separator.
  Print("if (!nested && scanner->empty()) { return true; }");
  Print("scanner->RestartCapture()");
  Print(" .Many(Scanner::LETTER_DIGIT_UNDERSCORE)");
  Print(" .StopCapture();");
  Print("StringPiece identifier;");
  Print("if (!scanner->GetResult(nullptr, &identifier)) return false;");
  Print("bool parsed_colon = false;");
  Print("(void)parsed_colon;");  // Avoid "set but not used" compiler warning
  Print("ProtoSpaceAndComments(scanner);");
  Print("if (scanner->Peek() == ':') {");
  Nest().Print("parsed_colon = true;");
  Print("scanner->One(Scanner::ALL);");
  Print("ProtoSpaceAndComments(scanner);");
  Unnest().Print("}");
  // Emit one "if (identifier == ...)" branch per field.
  // NOTE(review): no trailing else is emitted, so an identifier that matches
  // no field appears to fall through to the next loop iteration - confirm that
  // this is intended.
  for (int i = 0; i < md.field_count(); ++i) {
    const FieldDescriptor* field = md.field(i);
    const string& field_name = field->name();
    // <mutable_value_expr> is the generated expression passed as an output
    // pointer; <set_value_prefix> is the generated prefix that, followed by
    // "(value);", stores a scalar value.
    string mutable_value_expr;
    string set_value_prefix;
    if (map_append) {
      mutable_value_expr = StrCat("&map_", field_name);
      set_value_prefix = StrCat("map_", field_name, " = ");
    } else if (field->is_repeated()) {
      if (field->is_map()) {
        mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
        set_value_prefix =
            "UNREACHABLE";  // generator will never use this value.
      } else {
        mutable_value_expr = StrCat("msg->add_", field_name, "()");
        set_value_prefix = StrCat("msg->add_", field_name);
      }
    } else {
      mutable_value_expr = StrCat("msg->mutable_", field_name, "()");
      set_value_prefix = StrCat("msg->set_", field_name);
    }

    Print(i == 0 ? "" : "else ", "if (identifier == \"", field_name, "\") {");
    Nest();

    if (field->is_repeated()) {
      CHECK(!map_append);

      // Check to see if this is an array assignment, like a: [1, 2, 3]
      Print("const bool is_list = (scanner->Peek() == '[');");
      Print("do {");
      // Skip the opening '[' or the ',' that precedes the next element.
      Nest().Print("if (is_list) {");
      Nest().Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      Unnest().Print("}");
    } else if (field->containing_oneof() != nullptr) {
      CHECK(!map_append);

      // Detect duplicate oneof value.
      const string oneof_name = field->containing_oneof()->name();
      Print("if (msg->", oneof_name, "_case() != 0) return false;");
    }

    if (!field->is_repeated() && !map_append) {
      // Reject a second occurrence of a non-repeated field.
      Print("if (has_seen[", i, "]) return false;");
      Print("has_seen[", i, "] = true;");
    }
    if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
      // Nested messages may be delimited by either {...} or <...>.
      Print("const char open_char = scanner->Peek();");
      Print("if (open_char != '{' && open_char != '<') return false;");
      Print("scanner->One(Scanner::ALL);");
      Print("ProtoSpaceAndComments(scanner);");
      if (field->is_map()) {
        // Unqualified call: resolves to the map overload emitted in this file.
        Print("if (!ProtoParseFromScanner(");
      } else {
        Print("if (!", GetPackageReferencePrefix(field->message_type()->file()),
              "internal::ProtoParseFromScanner(");
      }
      Print(" scanner, true, open_char == '{', ", mutable_value_expr,
            ")) return false;");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_STRING) {
      Print("string str_value;");
      Print(
          "if (!parsed_colon || "
          "!::tensorflow::strings::ProtoParseStringLiteralFromScanner(");
      Print(" scanner, &str_value)) return false;");
      Print("SetProtobufStringSwapAllowed(&str_value, ", mutable_value_expr,
            ");");
    } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) {
      Print("StringPiece value;");
      Print(
          "if (!parsed_colon || "
          "!scanner->RestartCapture().Many("
          "Scanner::LETTER_DIGIT_DASH_UNDERSCORE)."
          "GetResult(nullptr, &value)) return false;");
      const auto* enum_d = field->enum_type();
      // Top-level enum values are referenced through the package scope; values
      // of nested enums through the "<EnumClass>_" prefix.
      string value_prefix;
      if (enum_d->containing_type() == nullptr) {
        value_prefix = GetPackageReferencePrefix(enum_d->file());
      } else {
        value_prefix = StrCat(GetQualifiedName(*enum_d), "_");
      }

      // Emit an if / else-if chain comparing against every value name.
      for (int enum_i = 0; enum_i < enum_d->value_count(); ++enum_i) {
        const auto* value_d = enum_d->value(enum_i);
        const string& value_name = value_d->name();
        string condition = StrCat("value == \"", value_name, "\"");

        Print(enum_i == 0 ? "" : "} else ", "if (", condition, ") {");
        Nest();
        Print(set_value_prefix, "(", value_prefix, value_name, ");");
        Unnest();
      }
      Print("} else {");
      Nest();
      // Proto3 allows all numeric values.
      Print("int32 int_value;");
      Print("if (strings::SafeStringToNumeric(value, &int_value)) {");
      Nest();
      Print(set_value_prefix, "(static_cast<", GetQualifiedName(*enum_d),
            ">(int_value));");
      Unnest();
      Print("} else {").Nest().Print("return false;").Unnest().Print("}");
      Unnest().Print("}");
    } else {
      // Remaining scalar types: numbers and bools.
      Print(field->cpp_type_name(), " value;");
      switch (field->cpp_type()) {
        case FieldDescriptor::CPPTYPE_INT32:
        case FieldDescriptor::CPPTYPE_INT64:
        case FieldDescriptor::CPPTYPE_UINT32:
        case FieldDescriptor::CPPTYPE_UINT64:
        case FieldDescriptor::CPPTYPE_DOUBLE:
        case FieldDescriptor::CPPTYPE_FLOAT:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseNumericFromScanner(",
              "scanner, &value)) return false;");
          break;
        case FieldDescriptor::CPPTYPE_BOOL:
          Print(
              "if (!parsed_colon || "
              "!::tensorflow::strings::ProtoParseBoolFromScanner(",
              "scanner, &value)) return false;");
          break;
        default:
          LOG(FATAL) << "handled earlier";
      }
      Print(set_value_prefix, "(value);");
    }

    if (field->is_repeated()) {
      // Close the do-loop opened above; a list must end with ']'.
      Unnest().Print("} while (is_list && scanner->Peek() == ',');");
      Print(
          "if (is_list && "
          "!scanner->OneLiteral(\"]\").GetResult()) return false;");
    }
    if (map_append) {
      Print("set_map_", field_name, " = true;");
    }
    Unnest().Print("}");
  }
  // Close the while loop and then the function body. The third Unnest() runs
  // when the indent has already been unwound by the two calls above.
  Unnest().Print("}");
  Unnest().Print("}");
  Unnest().Print();
  if (map_append) {
    Print("} // namespace");
  }
  Print("} // namespace internal");
}
611 | |
612 | void Generator::AppendDebugStringFunctions(const Descriptor& md) { |
613 | SetOutput(&header_impl_).Print(); |
614 | SetOutput(&header_).Print().Print("// Message-text conversion for " , |
615 | string(md.full_name())); |
616 | |
617 | // Append the two debug string functions for <md>. |
618 | for (int short_pass = 0; short_pass < 2; ++short_pass) { |
619 | const bool short_debug = (short_pass == 1); |
620 | |
621 | // Make the Get functions. |
622 | const string sig = StrCat( |
623 | "string " , short_debug ? "ProtoShortDebugString" : "ProtoDebugString" , |
624 | "(\n const " , GetQualifiedName(md), "& msg)" ); |
625 | SetOutput(&header_).Print(sig, ";" ); |
626 | |
627 | SetOutput(&cc_); |
628 | Print().Print(sig, " {" ).Nest(); |
629 | Print("string s;" ); |
630 | Print("::tensorflow::strings::ProtoTextOutput o(&s, " , |
631 | short_debug ? "true" : "false" , ");" ); |
632 | Print("internal::AppendProtoDebugString(&o, msg);" ); |
633 | Print("o.CloseTopMessage();" ); |
634 | Print("return s;" ); |
635 | Unnest().Print("}" ); |
636 | } |
637 | |
638 | // Make the Append function. |
639 | const string sig = |
640 | StrCat("void AppendProtoDebugString(\n" , |
641 | " ::tensorflow::strings::ProtoTextOutput* o,\n const " , |
642 | GetQualifiedName(md), "& msg)" ); |
643 | SetOutput(&header_impl_).Print(sig, ";" ); |
644 | SetOutput(&cc_); |
645 | Print().Print("namespace internal {" ).Print(); |
646 | Print(sig, " {" ).Nest(); |
647 | std::vector<const FieldDescriptor*> fields; |
648 | fields.reserve(md.field_count()); |
649 | for (int i = 0; i < md.field_count(); ++i) { |
650 | fields.push_back(md.field(i)); |
651 | } |
652 | std::sort(fields.begin(), fields.end(), |
653 | [](const FieldDescriptor* left, const FieldDescriptor* right) { |
654 | return left->number() < right->number(); |
655 | }); |
656 | |
657 | for (const FieldDescriptor* field : fields) { |
658 | SetOutput(&cc_); |
659 | AppendFieldAppend(*field); |
660 | } |
661 | Unnest().Print("}" ).Print().Print("} // namespace internal" ); |
662 | } |
663 | |
664 | void Generator::AppendMessageFunctions(const Descriptor& md) { |
665 | if (md.options().map_entry()) { |
666 | // The 'map entry' Message is not a user-visible message type. Only its |
667 | // parse function is created (and that actually parsed the whole Map, not |
668 | // just the map entry). Printing of a map is done in the code generated for |
669 | // the containing message. |
670 | AppendParseMessageFunction(md); |
671 | return; |
672 | } |
673 | |
674 | // Recurse before adding the main message function, so that internal |
675 | // map_append functions are available before they are needed. |
676 | for (int i = 0; i < md.enum_type_count(); ++i) { |
677 | AppendEnumFunctions(*md.enum_type(i)); |
678 | } |
679 | for (int i = 0; i < md.nested_type_count(); ++i) { |
680 | AppendMessageFunctions(*md.nested_type(i)); |
681 | } |
682 | |
683 | AppendDebugStringFunctions(md); |
684 | AppendParseMessageFunction(md); |
685 | } |
686 | |
687 | void Generator::AddNamespaceToCurrentSection(const string& package, bool open) { |
688 | Print(); |
689 | std::vector<string> parts = {"" }; |
690 | for (size_t i = 0; i < package.size(); ++i) { |
691 | if (package[i] == '.') { |
692 | parts.resize(parts.size() + 1); |
693 | } else { |
694 | parts.back() += package[i]; |
695 | } |
696 | } |
697 | if (open) { |
698 | for (const auto& p : parts) { |
699 | Print("namespace " , p, " {" ); |
700 | } |
701 | } else { |
702 | for (auto it = parts.rbegin(); it != parts.rend(); ++it) { |
703 | Print("} // namespace " , *it); |
704 | } |
705 | } |
706 | } |
707 | |
708 | void Generator::(const std::vector<string>& ) { |
709 | std::vector<string> sorted = headers; |
710 | std::sort(sorted.begin(), sorted.end()); |
711 | for (const auto& h : sorted) { |
712 | Print("#include \"" , h, "\"" ); |
713 | } |
714 | } |
715 | |
// Adds to <all_fd> and <all_d> every descriptor that is recursively reachable
// from the given file descriptor. Declared ahead of its definition because it
// and GetAllFileDescriptorsFromMessage call each other.
718 | void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, |
719 | std::set<const FileDescriptor*>* all_fd, |
720 | std::set<const Descriptor*>* all_d); |
721 | |
722 | // Adds to <all_fd> and <all_d> with all descriptors recursively |
723 | // reachable from the given descriptor. |
724 | void GetAllFileDescriptorsFromMessage(const Descriptor* d, |
725 | std::set<const FileDescriptor*>* all_fd, |
726 | std::set<const Descriptor*>* all_d) { |
727 | if (!all_d->insert(d).second) return; |
728 | GetAllFileDescriptorsFromFile(d->file(), all_fd, all_d); |
729 | for (int i = 0; i < d->field_count(); ++i) { |
730 | auto* f = d->field(i); |
731 | switch (f->cpp_type()) { |
732 | case FieldDescriptor::CPPTYPE_INT32: |
733 | case FieldDescriptor::CPPTYPE_INT64: |
734 | case FieldDescriptor::CPPTYPE_UINT32: |
735 | case FieldDescriptor::CPPTYPE_UINT64: |
736 | case FieldDescriptor::CPPTYPE_DOUBLE: |
737 | case FieldDescriptor::CPPTYPE_FLOAT: |
738 | case FieldDescriptor::CPPTYPE_BOOL: |
739 | case FieldDescriptor::CPPTYPE_STRING: |
740 | break; |
741 | case FieldDescriptor::CPPTYPE_MESSAGE: |
742 | GetAllFileDescriptorsFromMessage(f->message_type(), all_fd, all_d); |
743 | break; |
744 | case FieldDescriptor::CPPTYPE_ENUM: |
745 | GetAllFileDescriptorsFromFile(f->enum_type()->file(), all_fd, all_d); |
746 | break; |
747 | } |
748 | } |
749 | for (int i = 0; i < d->nested_type_count(); ++i) { |
750 | GetAllFileDescriptorsFromMessage(d->nested_type(i), all_fd, all_d); |
751 | } |
752 | } |
753 | |
754 | void GetAllFileDescriptorsFromFile(const FileDescriptor* fd, |
755 | std::set<const FileDescriptor*>* all_fd, |
756 | std::set<const Descriptor*>* all_d) { |
757 | if (!all_fd->insert(fd).second) return; |
758 | for (int i = 0; i < fd->message_type_count(); ++i) { |
759 | GetAllFileDescriptorsFromMessage(fd->message_type(i), all_fd, all_d); |
760 | } |
761 | } |
762 | |
763 | void Generator::Generate(const FileDescriptor& fd) { |
764 | // This does not emit code with proper proto2 semantics (e.g. it doesn't check |
765 | // 'has' fields on non-messages), so check that only proto3 is passed. |
766 | CHECK_EQ(fd.syntax(), FileDescriptor::SYNTAX_PROTO3) << fd.name(); |
767 | |
768 | const string package = fd.package(); |
769 | std::set<const FileDescriptor*> all_fd; |
770 | std::set<const Descriptor*> all_d; |
771 | GetAllFileDescriptorsFromFile(&fd, &all_fd, &all_d); |
772 | |
773 | std::vector<string> ; |
774 | |
775 | // Add header to header file. |
776 | SetOutput(&header_); |
777 | Print("// GENERATED FILE - DO NOT MODIFY" ); |
778 | Print("#ifndef " , GetHeaderGuard(fd, false /* impl */)); |
779 | Print("#define " , GetHeaderGuard(fd, false /* impl */)); |
780 | Print(); |
781 | headers = { |
782 | GetProtoHeaderName(fd), |
783 | StrCat(tf_header_prefix_, "tensorflow/core/platform/macros.h" ), |
784 | StrCat(tf_header_prefix_, "tensorflow/core/platform/protobuf.h" ), |
785 | StrCat(tf_header_prefix_, "tensorflow/core/platform/types.h" ), |
786 | }; |
787 | for (const auto& h : headers) { |
788 | Print("#include \"" , h, "\"" ); |
789 | } |
790 | AddNamespaceToCurrentSection(package, true /* is_open */); |
791 | |
792 | // Add header to impl file. |
793 | SetOutput(&header_impl_); |
794 | Print("// GENERATED FILE - DO NOT MODIFY" ); |
795 | Print("#ifndef " , GetHeaderGuard(fd, true /* impl */)); |
796 | Print("#define " , GetHeaderGuard(fd, true /* impl */)); |
797 | Print(); |
798 | headers = { |
799 | GetProtoTextHeaderName(fd, false /* impl */), |
800 | StrCat(tf_header_prefix_, |
801 | "tensorflow/core/lib/strings/proto_text_util.h" ), |
802 | StrCat(tf_header_prefix_, "tensorflow/core/lib/strings/scanner.h" ), |
803 | }; |
804 | for (const FileDescriptor* d : all_fd) { |
805 | if (d != &fd) { |
806 | headers.push_back(GetProtoTextHeaderName(*d, true /* impl */)); |
807 | } |
808 | headers.push_back(GetProtoHeaderName(*d)); |
809 | } |
810 | AddHeadersToCurrentSection(headers); |
811 | AddNamespaceToCurrentSection(package, true /* is_open */); |
812 | SetOutput(&header_impl_).Print().Print("namespace internal {" ); |
813 | |
814 | // Add header to cc file. |
815 | SetOutput(&cc_); |
816 | Print("// GENERATED FILE - DO NOT MODIFY" ); |
817 | Print(); |
818 | Print("#include <algorithm>" ); // for `std::stable_sort()` |
819 | Print(); |
820 | headers = {GetProtoTextHeaderName(fd, true /* impl */)}; |
821 | AddHeadersToCurrentSection(headers); |
822 | Print(); |
823 | Print("using ::tensorflow::strings::ProtoSpaceAndComments;" ); |
824 | Print("using ::tensorflow::strings::Scanner;" ); |
825 | Print("using ::tensorflow::strings::StrCat;" ); |
826 | AddNamespaceToCurrentSection(package, true /* is_open */); |
827 | |
828 | // Add declarations and definitions. |
829 | for (int i = 0; i < fd.enum_type_count(); ++i) { |
830 | AppendEnumFunctions(*fd.enum_type(i)); |
831 | } |
832 | for (int i = 0; i < fd.message_type_count(); ++i) { |
833 | AppendMessageFunctions(*fd.message_type(i)); |
834 | } |
835 | |
836 | // Add footer to header file. |
837 | SetOutput(&header_); |
838 | AddNamespaceToCurrentSection(package, false /* is_open */); |
839 | Print().Print("#endif // " , GetHeaderGuard(fd, false /* impl */)); |
840 | |
841 | // Add footer to header impl file. |
842 | SetOutput(&header_impl_).Print().Print("} // namespace internal" ); |
843 | AddNamespaceToCurrentSection(package, false /* is_open */); |
844 | Print().Print("#endif // " , GetHeaderGuard(fd, true /* impl */)); |
845 | |
846 | // Add footer to cc file. |
847 | SetOutput(&cc_); |
848 | AddNamespaceToCurrentSection(package, false /* is_open */); |
849 | } |
850 | |
851 | } // namespace |
852 | |
853 | ProtoTextFunctionCode GetProtoTextFunctionCode(const FileDescriptor& fd, |
854 | const string& ) { |
855 | Generator gen(tf_header_prefix); |
856 | gen.Generate(fd); |
857 | return gen.code(); |
858 | } |
859 | |
860 | } // namespace tensorflow |
861 | |