1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/tools/graph_transforms/transform_graph.h" |
17 | |
18 | #include "tensorflow/core/framework/function.pb.h" |
19 | #include "tensorflow/core/lib/strings/scanner.h" |
20 | #include "tensorflow/core/lib/strings/str_util.h" |
21 | #include "tensorflow/core/platform/env.h" |
22 | #include "tensorflow/core/platform/init_main.h" |
23 | #include "tensorflow/core/platform/logging.h" |
24 | #include "tensorflow/core/util/command_line_flags.h" |
25 | #include "tensorflow/tools/graph_transforms/file_utils.h" |
26 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
27 | #if !defined(PLATFORM_WINDOWS) |
28 | #include <pwd.h> |
29 | #include <unistd.h> |
30 | #endif |
31 | |
32 | namespace tensorflow { |
33 | namespace graph_transforms { |
34 | |
35 | using tensorflow::strings::Scanner; |
36 | |
37 | Status ParseTransformParameters(const string& transforms_string, |
38 | TransformParameters* params_list) { |
39 | params_list->clear(); |
40 | enum { |
41 | TRANSFORM_NAME, |
42 | TRANSFORM_PARAM_NAME, |
43 | TRANSFORM_PARAM_VALUE, |
44 | } state = TRANSFORM_NAME; |
45 | StringPiece remaining(transforms_string); |
46 | StringPiece match; |
47 | StringPiece transform_name; |
48 | StringPiece parameter_name; |
49 | StringPiece parameter_value; |
50 | TransformFuncParameters func_parameters; |
51 | while (!remaining.empty()) { |
52 | if (state == TRANSFORM_NAME) { |
53 | // Reset the list of parameters. |
54 | func_parameters.clear(); |
55 | // Eat up any leading spaces. |
56 | Scanner(remaining).AnySpace().GetResult(&remaining, &match); |
57 | if (remaining.empty()) { |
58 | // Nothing remains after consuming trailing spaces. |
59 | // Consumed all transform parameter string without errors. |
60 | return OkStatus(); |
61 | } |
62 | // See if we have a valid transform name. |
63 | const bool found_transform_name = |
64 | Scanner(remaining) |
65 | .Many(Scanner::LETTER_DIGIT_UNDERSCORE) |
66 | .GetResult(&remaining, &transform_name); |
67 | if (!found_transform_name) { |
68 | return errors::InvalidArgument("Looking for transform name, but found " , |
69 | string(remaining).c_str()); |
70 | } |
71 | if (Scanner(remaining).OneLiteral("(" ).GetResult(&remaining, &match)) { |
72 | state = TRANSFORM_PARAM_NAME; |
73 | } else { |
74 | // Add a transform with no parameters. |
75 | params_list->push_back({string(transform_name), func_parameters}); |
76 | transform_name = "" ; |
77 | state = TRANSFORM_NAME; |
78 | } |
79 | } else if (state == TRANSFORM_PARAM_NAME) { |
80 | if (Scanner(remaining).OneLiteral(")" ).GetResult(&remaining, &match)) { |
81 | params_list->push_back({string(transform_name), func_parameters}); |
82 | transform_name = "" ; |
83 | state = TRANSFORM_NAME; |
84 | } else { |
85 | // Eat up any leading spaces or commas. |
86 | Scanner(remaining).ZeroOrOneLiteral("," ).GetResult(&remaining, &match); |
87 | Scanner(remaining).AnySpace().GetResult(&remaining, &match); |
88 | // See if we have a valid parameter name. |
89 | const bool found_parameter_name = |
90 | Scanner(remaining) |
91 | .Many(Scanner::LETTER_DIGIT_UNDERSCORE) |
92 | .GetResult(&remaining, ¶meter_name); |
93 | if (!found_parameter_name) { |
94 | return errors::InvalidArgument( |
95 | "Looking for parameter name, but found " , |
96 | string(remaining).c_str()); |
97 | } |
98 | if (Scanner(remaining).OneLiteral("=" ).GetResult(&remaining, &match)) { |
99 | state = TRANSFORM_PARAM_VALUE; |
100 | } else { |
101 | return errors::InvalidArgument("Looking for =, but found " , |
102 | string(remaining).c_str()); |
103 | } |
104 | } |
105 | } else if (state == TRANSFORM_PARAM_VALUE) { |
106 | bool found_parameter_value; |
107 | // Deal with quoted values. |
108 | if (Scanner(remaining).OneLiteral("\"" ).GetResult(&remaining, &match)) { |
109 | found_parameter_value = |
110 | Scanner(remaining).ScanEscapedUntil('"').GetResult( |
111 | &remaining, ¶meter_value); |
112 | if (found_parameter_value) { |
113 | Scanner(remaining).OneLiteral("\"" ).GetResult(&remaining, &match); |
114 | } |
115 | } else { |
116 | // See if we have a valid parameter name. |
117 | found_parameter_value = |
118 | Scanner(remaining) |
119 | .Many(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) |
120 | .GetResult(&remaining, ¶meter_value); |
121 | } |
122 | if (!found_parameter_value) { |
123 | return errors::InvalidArgument("Looking for parameter name, but found " , |
124 | string(remaining).c_str()); |
125 | } |
126 | func_parameters[string(parameter_name)].emplace_back(parameter_value); |
127 | // Eat up any trailing quotes. |
128 | Scanner(remaining).ZeroOrOneLiteral("\"" ).GetResult(&remaining, &match); |
129 | Scanner(remaining).ZeroOrOneLiteral("'" ).GetResult(&remaining, &match); |
130 | state = TRANSFORM_PARAM_NAME; |
131 | } |
132 | } |
133 | return OkStatus(); |
134 | } |
135 | |
// Performs shell-style tilde expansion on the leading component of a path.
//
//   "~" or "~/rest"         -> $HOME (or, when HOME is unset, the password
//                              database home directory of the calling uid),
//                              followed by "/rest" if present.
//   "~user" or "~user/rest" -> the password database home directory of
//                              `user`, followed by "/rest" if present.
//
// A '/' is inserted between the home directory and the rest unless the home
// directory already ends with one. Paths that do not start with '~', and paths
// whose home directory cannot be determined, are returned unchanged. When
// PLATFORM_WINDOWS is defined no expansion is attempted.
std::string ExpandPath(const std::string& path_string) {
#if defined(PLATFORM_WINDOWS)
  return path_string;
#else
  const bool has_tilde = !path_string.empty() && path_string.front() == '~';
  if (!has_tilde) {
    return path_string;
  }

  // The user name sits between the '~' and the first '/' (or the end of the
  // string); an empty name means "the current user".
  const std::string::size_type slash_pos = path_string.find('/');
  const std::string user_name = path_string.substr(
      1, slash_pos == std::string::npos ? std::string::npos : slash_pos - 1);

  const char* home_dir = nullptr;
  if (user_name.empty()) {
    home_dir = getenv("HOME");
    if (home_dir == nullptr) {
      // HOME is unset: fall back to the password entry for our own uid.
      const struct passwd* entry = getpwuid(getuid());
      if (entry != nullptr) {
        home_dir = entry->pw_dir;
      }
    }
  } else {
    const struct passwd* entry = getpwnam(user_name.c_str());
    if (entry != nullptr) {
      home_dir = entry->pw_dir;
    }
  }
  if (home_dir == nullptr) {
    return path_string;
  }

  std::string expanded(home_dir);
  if (slash_pos != std::string::npos) {
    if (expanded.empty() || expanded.back() != '/') {
      expanded.push_back('/');
    }
    expanded.append(path_string, slash_pos + 1, std::string::npos);
  }
  return expanded;
#endif
}
183 | |
184 | int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) { |
185 | string in_graph_string = "" ; |
186 | string out_graph_string = "" ; |
187 | string inputs_string = "" ; |
188 | string outputs_string = "" ; |
189 | string transforms_string = "" ; |
190 | bool output_as_text = false; |
191 | std::vector<Flag> flag_list = { |
192 | Flag("in_graph" , &in_graph_string, "input graph file name" ), |
193 | Flag("out_graph" , &out_graph_string, "output graph file name" ), |
194 | Flag("inputs" , &inputs_string, "inputs" ), |
195 | Flag("outputs" , &outputs_string, "outputs" ), |
196 | Flag("transforms" , &transforms_string, "list of transforms" ), |
197 | Flag("output_as_text" , &output_as_text, |
198 | "whether to write the graph in text protobuf format" ), |
199 | }; |
200 | string usage = Flags::Usage(argv[0], flag_list); |
201 | usage += "\nTransforms are:\n" ; |
202 | TransformRegistry* transform_registry = GetTransformRegistry(); |
203 | for (const auto& pair : *transform_registry) { |
204 | usage += pair.first + "\n" ; |
205 | } |
206 | |
207 | const bool parse_result = Flags::Parse(&argc, argv, flag_list); |
208 | // We need to call this to set up global state for TensorFlow. |
209 | if (init_main) { |
210 | port::InitMain(argv[0], &argc, &argv); |
211 | } |
212 | if (!parse_result) { |
213 | LOG(ERROR) << usage; |
214 | return -1; |
215 | } |
216 | if (argc > 1) { |
217 | LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage; |
218 | return -1; |
219 | } |
220 | if (in_graph_string.empty()) { |
221 | LOG(ERROR) << "in_graph graph can't be empty.\n" << usage; |
222 | return -1; |
223 | } |
224 | if (out_graph_string.empty()) { |
225 | LOG(ERROR) << "out_graph graph can't be empty.\n" << usage; |
226 | return -1; |
227 | } |
228 | if (transforms_string.empty()) { |
229 | LOG(ERROR) << "You must specify at least one transform.\n" << usage; |
230 | return -1; |
231 | } |
232 | |
233 | string in_graph = ExpandPath(in_graph_string); |
234 | string out_graph = ExpandPath(out_graph_string); |
235 | |
236 | std::vector<string> inputs = str_util::Split(inputs_string, ','); |
237 | std::vector<string> outputs = str_util::Split(outputs_string, ','); |
238 | TransformParameters transform_params; |
239 | Status parse_status = |
240 | ParseTransformParameters(transforms_string, &transform_params); |
241 | if (!parse_status.ok()) { |
242 | LOG(ERROR) << "Failed to parse --transform argument, error was " |
243 | << parse_status.error_message(); |
244 | return -1; |
245 | } |
246 | if (transform_params.empty()) { |
247 | LOG(ERROR) << "You must specify at least one transform.\n" << usage; |
248 | return -1; |
249 | } |
250 | |
251 | GraphDef graph_def; |
252 | Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def); |
253 | if (!load_status.ok()) { |
254 | LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with " |
255 | << load_status.error_message(); |
256 | LOG(ERROR) << usage; |
257 | return -1; |
258 | } |
259 | |
260 | Status transform_result = |
261 | TransformGraph(inputs, outputs, transform_params, &graph_def); |
262 | |
263 | if (!transform_result.ok()) { |
264 | LOG(ERROR) << transform_result.error_message(); |
265 | LOG(ERROR) << usage; |
266 | return -1; |
267 | } |
268 | |
269 | Status save_status; |
270 | if (output_as_text) { |
271 | save_status = WriteTextProto(Env::Default(), out_graph, graph_def); |
272 | } else { |
273 | save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def); |
274 | } |
275 | if (!save_status.ok()) { |
276 | LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with " |
277 | << save_status.error_message(); |
278 | return -1; |
279 | } |
280 | |
281 | return 0; |
282 | } |
283 | |
284 | Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params, |
285 | bool* ignore_errors) { |
286 | *ignore_errors = false; |
287 | if (transform_params.count("ignore_errors" ) && |
288 | (!transform_params.at("ignore_errors" ).empty())) { |
289 | const string& ignore_errors_string = |
290 | absl::AsciiStrToLower(transform_params.at("ignore_errors" ).at(0)); |
291 | if (ignore_errors_string == "true" ) { |
292 | *ignore_errors = true; |
293 | } else if (ignore_errors_string == "false" ) { |
294 | *ignore_errors = false; |
295 | } else { |
296 | return errors::InvalidArgument( |
297 | "ignore_errors should be true or false, found " , |
298 | ignore_errors_string); |
299 | } |
300 | } |
301 | return OkStatus(); |
302 | } |
303 | |
304 | Status TransformGraph(const std::vector<string>& inputs, |
305 | const std::vector<string>& outputs, |
306 | const TransformParameters& transform_params, |
307 | GraphDef* graph_def) { |
308 | TransformRegistry* transform_registry = GetTransformRegistry(); |
309 | for (const auto& transform_info : transform_params) { |
310 | const string& transform_name = transform_info.first; |
311 | if (transform_name.empty()) { |
312 | continue; |
313 | } |
314 | if (!transform_registry->count(transform_name)) { |
315 | return errors::InvalidArgument("Transform '" , transform_name, |
316 | "' not recognized." ); |
317 | } |
318 | LOG(INFO) << "Applying " << transform_name; |
319 | const TransformFunc& transform_func = |
320 | transform_registry->at(transform_name); |
321 | TransformFuncContext context; |
322 | context.input_names = inputs; |
323 | context.output_names = outputs; |
324 | context.params = transform_info.second; |
325 | bool ignore_errors; |
326 | TF_RETURN_IF_ERROR( |
327 | ShouldIgnoreErrors(transform_info.second, &ignore_errors)); |
328 | GraphDef transformed_graph_def; |
329 | Status transform_result = |
330 | transform_func(*graph_def, context, &transformed_graph_def); |
331 | if (!transform_result.ok()) { |
332 | if (ignore_errors) { |
333 | LOG(ERROR) << transform_name << ": Ignoring error " |
334 | << transform_result.error_message(); |
335 | transformed_graph_def = *graph_def; |
336 | } else { |
337 | return transform_result; |
338 | } |
339 | } |
340 | // Copy over the library from the original input graph. |
341 | *transformed_graph_def.mutable_library() = graph_def->library(); |
342 | TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def)); |
343 | |
344 | *graph_def = transformed_graph_def; |
345 | } |
346 | return OkStatus(); |
347 | } |
348 | } // namespace graph_transforms |
349 | } // namespace tensorflow |
350 | |