1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stdio.h> |
17 | |
18 | #include <set> |
19 | |
20 | #include "tensorflow/core/platform/logging.h" |
21 | #include "tensorflow/core/platform/protobuf.h" |
22 | #include "tensorflow/core/platform/types.h" |
23 | #include "tensorflow/tools/proto_text/gen_proto_text_functions_lib.h" |
24 | #include "tensorflow/tsl/platform/protobuf_compiler.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | namespace { |
29 | class CrashOnErrorCollector |
30 | : public tensorflow::protobuf::compiler::MultiFileErrorCollector { |
31 | public: |
32 | ~CrashOnErrorCollector() override {} |
33 | |
34 | void AddError(const string& filename, int line, int column, |
35 | const string& message) override { |
36 | LOG(FATAL) << "Unexpected error at " << filename << "@" << line << ":" |
37 | << column << " - " << message; |
38 | } |
39 | }; |
40 | |
// Prefix prepended to the relative-path command-line argument to form the
// output-relative path, and also passed to GetProtoTextFunctionCode(). It is
// empty, so the relative path is used exactly as given.
static const char kTensorFlowHeaderPrefix[] = "";

// Package-relative path of the marker file that must appear among the source
// arguments. Whatever precedes this suffix in that argument is taken as the
// root directory under which the build tool placed the source protos.
static const char kPlaceholderFile[] =
    "tensorflow/tools/proto_text/placeholder.txt";
45 | |
46 | bool IsPlaceholderFile(const char* s) { |
47 | string ph(kPlaceholderFile); |
48 | string str(s); |
49 | return str.size() >= strlen(kPlaceholderFile) && |
50 | ph == str.substr(str.size() - ph.size()); |
51 | } |
52 | |
53 | } // namespace |
54 | |
55 | // Main program to take input protos and write output pb_text source files that |
56 | // contain generated proto text input and output functions. |
57 | // |
58 | // Main expects: |
59 | // - First argument is output path |
60 | // - Second argument is the relative path of the protos to the root. E.g., |
61 | // for protos built by a rule in tensorflow/core, this will be |
62 | // tensorflow/core. |
63 | // - Then any number of source proto file names, plus one source name must be |
64 | // placeholder.txt from this gen tool's package. placeholder.txt is |
65 | // ignored for proto resolution, but is used to determine the root at which |
66 | // the build tool has placed the source proto files. |
67 | // |
68 | // Note that this code doesn't use tensorflow's command line parsing, because of |
69 | // circular dependencies between libraries if that were done. |
70 | // |
71 | // This is meant to be invoked by a genrule. See BUILD for more information. |
72 | int MainImpl(int argc, char** argv) { |
73 | if (argc < 4) { |
74 | LOG(ERROR) << "Pass output path, relative path, and at least proto file" ; |
75 | return -1; |
76 | } |
77 | |
78 | const string output_root = argv[1]; |
79 | const string output_relative_path = kTensorFlowHeaderPrefix + string(argv[2]); |
80 | |
81 | string src_relative_path; |
82 | bool has_placeholder = false; |
83 | for (int i = 3; i < argc; ++i) { |
84 | if (IsPlaceholderFile(argv[i])) { |
85 | const string s(argv[i]); |
86 | src_relative_path = s.substr(0, s.size() - strlen(kPlaceholderFile)); |
87 | has_placeholder = true; |
88 | } |
89 | } |
90 | if (!has_placeholder) { |
91 | LOG(ERROR) << kPlaceholderFile << " must be passed" ; |
92 | return -1; |
93 | } |
94 | |
95 | tensorflow::protobuf::compiler::DiskSourceTree source_tree; |
96 | |
97 | source_tree.MapPath("" , src_relative_path.empty() ? "." : src_relative_path); |
98 | CrashOnErrorCollector crash_on_error; |
99 | tensorflow::protobuf::compiler::Importer importer(&source_tree, |
100 | &crash_on_error); |
101 | |
102 | for (int i = 3; i < argc; i++) { |
103 | if (IsPlaceholderFile(argv[i])) continue; |
104 | const string proto_path = string(argv[i]).substr(src_relative_path.size()); |
105 | |
106 | const tensorflow::protobuf::FileDescriptor* fd = |
107 | importer.Import(proto_path); |
108 | |
109 | const int index = proto_path.find_last_of('.'); |
110 | string proto_path_no_suffix = proto_path.substr(0, index); |
111 | |
112 | proto_path_no_suffix = |
113 | proto_path_no_suffix.substr(output_relative_path.size()); |
114 | |
115 | const auto code = |
116 | tensorflow::GetProtoTextFunctionCode(*fd, kTensorFlowHeaderPrefix); |
117 | |
118 | // Three passes, one for each output file. |
119 | for (int pass = 0; pass < 3; ++pass) { |
120 | string suffix; |
121 | string data; |
122 | if (pass == 0) { |
123 | suffix = ".pb_text.h" ; |
124 | data = code.header; |
125 | } else if (pass == 1) { |
126 | suffix = ".pb_text-impl.h" ; |
127 | data = code.header_impl; |
128 | } else { |
129 | suffix = ".pb_text.cc" ; |
130 | data = code.cc; |
131 | } |
132 | |
133 | const string path = output_root + "/" + proto_path_no_suffix + suffix; |
134 | FILE* f = fopen(path.c_str(), "w" ); |
135 | if (f == nullptr) { |
136 | // We don't expect this output to be generated. It was specified in the |
137 | // list of sources solely to satisfy a proto import dependency. |
138 | continue; |
139 | } |
140 | if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) { |
141 | fclose(f); |
142 | return -1; |
143 | } |
144 | if (fclose(f) != 0) { |
145 | return -1; |
146 | } |
147 | } |
148 | } |
149 | return 0; |
150 | } |
151 | |
152 | } // namespace tensorflow |
153 | |
154 | int main(int argc, char** argv) { return tensorflow::MainImpl(argc, argv); } |
155 | |