1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/tools/graph_transforms/transform_graph.h"
17
18#include "tensorflow/core/framework/function.pb.h"
19#include "tensorflow/core/lib/strings/scanner.h"
20#include "tensorflow/core/lib/strings/str_util.h"
21#include "tensorflow/core/platform/env.h"
22#include "tensorflow/core/platform/init_main.h"
23#include "tensorflow/core/platform/logging.h"
24#include "tensorflow/core/util/command_line_flags.h"
25#include "tensorflow/tools/graph_transforms/file_utils.h"
26#include "tensorflow/tools/graph_transforms/transform_utils.h"
27#if !defined(PLATFORM_WINDOWS)
28#include <pwd.h>
29#include <unistd.h>
30#endif
31
32namespace tensorflow {
33namespace graph_transforms {
34
35using tensorflow::strings::Scanner;
36
37Status ParseTransformParameters(const string& transforms_string,
38 TransformParameters* params_list) {
39 params_list->clear();
40 enum {
41 TRANSFORM_NAME,
42 TRANSFORM_PARAM_NAME,
43 TRANSFORM_PARAM_VALUE,
44 } state = TRANSFORM_NAME;
45 StringPiece remaining(transforms_string);
46 StringPiece match;
47 StringPiece transform_name;
48 StringPiece parameter_name;
49 StringPiece parameter_value;
50 TransformFuncParameters func_parameters;
51 while (!remaining.empty()) {
52 if (state == TRANSFORM_NAME) {
53 // Reset the list of parameters.
54 func_parameters.clear();
55 // Eat up any leading spaces.
56 Scanner(remaining).AnySpace().GetResult(&remaining, &match);
57 if (remaining.empty()) {
58 // Nothing remains after consuming trailing spaces.
59 // Consumed all transform parameter string without errors.
60 return OkStatus();
61 }
62 // See if we have a valid transform name.
63 const bool found_transform_name =
64 Scanner(remaining)
65 .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
66 .GetResult(&remaining, &transform_name);
67 if (!found_transform_name) {
68 return errors::InvalidArgument("Looking for transform name, but found ",
69 string(remaining).c_str());
70 }
71 if (Scanner(remaining).OneLiteral("(").GetResult(&remaining, &match)) {
72 state = TRANSFORM_PARAM_NAME;
73 } else {
74 // Add a transform with no parameters.
75 params_list->push_back({string(transform_name), func_parameters});
76 transform_name = "";
77 state = TRANSFORM_NAME;
78 }
79 } else if (state == TRANSFORM_PARAM_NAME) {
80 if (Scanner(remaining).OneLiteral(")").GetResult(&remaining, &match)) {
81 params_list->push_back({string(transform_name), func_parameters});
82 transform_name = "";
83 state = TRANSFORM_NAME;
84 } else {
85 // Eat up any leading spaces or commas.
86 Scanner(remaining).ZeroOrOneLiteral(",").GetResult(&remaining, &match);
87 Scanner(remaining).AnySpace().GetResult(&remaining, &match);
88 // See if we have a valid parameter name.
89 const bool found_parameter_name =
90 Scanner(remaining)
91 .Many(Scanner::LETTER_DIGIT_UNDERSCORE)
92 .GetResult(&remaining, &parameter_name);
93 if (!found_parameter_name) {
94 return errors::InvalidArgument(
95 "Looking for parameter name, but found ",
96 string(remaining).c_str());
97 }
98 if (Scanner(remaining).OneLiteral("=").GetResult(&remaining, &match)) {
99 state = TRANSFORM_PARAM_VALUE;
100 } else {
101 return errors::InvalidArgument("Looking for =, but found ",
102 string(remaining).c_str());
103 }
104 }
105 } else if (state == TRANSFORM_PARAM_VALUE) {
106 bool found_parameter_value;
107 // Deal with quoted values.
108 if (Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match)) {
109 found_parameter_value =
110 Scanner(remaining).ScanEscapedUntil('"').GetResult(
111 &remaining, &parameter_value);
112 if (found_parameter_value) {
113 Scanner(remaining).OneLiteral("\"").GetResult(&remaining, &match);
114 }
115 } else {
116 // See if we have a valid parameter name.
117 found_parameter_value =
118 Scanner(remaining)
119 .Many(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
120 .GetResult(&remaining, &parameter_value);
121 }
122 if (!found_parameter_value) {
123 return errors::InvalidArgument("Looking for parameter name, but found ",
124 string(remaining).c_str());
125 }
126 func_parameters[string(parameter_name)].emplace_back(parameter_value);
127 // Eat up any trailing quotes.
128 Scanner(remaining).ZeroOrOneLiteral("\"").GetResult(&remaining, &match);
129 Scanner(remaining).ZeroOrOneLiteral("'").GetResult(&remaining, &match);
130 state = TRANSFORM_PARAM_NAME;
131 }
132 }
133 return OkStatus();
134}
135
// Expands a leading tilde in `path_string`, shell style:
//   "~" or "~/rest"         -> $HOME, or the passwd home directory of the
//                              current uid when HOME is unset [+ "/rest"]
//   "~user" or "~user/rest" -> the passwd home directory of `user` [+ "/rest"]
// A '/' separator is inserted only if the home directory does not already end
// in one. Paths that do not start with '~', and paths whose home directory
// cannot be resolved, are returned unchanged. On Windows the input is always
// returned unchanged.
std::string ExpandPath(const std::string& path_string) {
#if defined(PLATFORM_WINDOWS)
  return path_string;
#else
  if (path_string.empty() || path_string.front() != '~') return path_string;

  // Position of the first '/', i.e. the end of the "~" or "~user" prefix.
  const std::string::size_type slash = path_string.find_first_of('/');
  const bool is_current_user = path_string.size() == 1 || slash == 1;

  const char* home_dir = nullptr;
  if (is_current_user) {
    home_dir = getenv("HOME");
    if (home_dir == nullptr) {
      // HOME is not set; fall back to the passwd entry of the current uid.
      const struct passwd* entry = getpwuid(getuid());
      if (entry != nullptr) home_dir = entry->pw_dir;
    }
  } else {
    const std::string user_name = (slash == std::string::npos)
                                      ? path_string.substr(1)
                                      : path_string.substr(1, slash - 1);
    const struct passwd* entry = getpwnam(user_name.c_str());
    if (entry != nullptr) home_dir = entry->pw_dir;
  }

  if (home_dir == nullptr) return path_string;

  std::string expanded(home_dir);
  if (slash == std::string::npos) return expanded;

  if (expanded.empty() || expanded.back() != '/') expanded.push_back('/');
  expanded.append(path_string, slash + 1, std::string::npos);
  return expanded;
#endif
}
183
184int ParseFlagsAndTransformGraph(int argc, char* argv[], bool init_main) {
185 string in_graph_string = "";
186 string out_graph_string = "";
187 string inputs_string = "";
188 string outputs_string = "";
189 string transforms_string = "";
190 bool output_as_text = false;
191 std::vector<Flag> flag_list = {
192 Flag("in_graph", &in_graph_string, "input graph file name"),
193 Flag("out_graph", &out_graph_string, "output graph file name"),
194 Flag("inputs", &inputs_string, "inputs"),
195 Flag("outputs", &outputs_string, "outputs"),
196 Flag("transforms", &transforms_string, "list of transforms"),
197 Flag("output_as_text", &output_as_text,
198 "whether to write the graph in text protobuf format"),
199 };
200 string usage = Flags::Usage(argv[0], flag_list);
201 usage += "\nTransforms are:\n";
202 TransformRegistry* transform_registry = GetTransformRegistry();
203 for (const auto& pair : *transform_registry) {
204 usage += pair.first + "\n";
205 }
206
207 const bool parse_result = Flags::Parse(&argc, argv, flag_list);
208 // We need to call this to set up global state for TensorFlow.
209 if (init_main) {
210 port::InitMain(argv[0], &argc, &argv);
211 }
212 if (!parse_result) {
213 LOG(ERROR) << usage;
214 return -1;
215 }
216 if (argc > 1) {
217 LOG(ERROR) << "Unknown argument " << argv[1] << ".\n" << usage;
218 return -1;
219 }
220 if (in_graph_string.empty()) {
221 LOG(ERROR) << "in_graph graph can't be empty.\n" << usage;
222 return -1;
223 }
224 if (out_graph_string.empty()) {
225 LOG(ERROR) << "out_graph graph can't be empty.\n" << usage;
226 return -1;
227 }
228 if (transforms_string.empty()) {
229 LOG(ERROR) << "You must specify at least one transform.\n" << usage;
230 return -1;
231 }
232
233 string in_graph = ExpandPath(in_graph_string);
234 string out_graph = ExpandPath(out_graph_string);
235
236 std::vector<string> inputs = str_util::Split(inputs_string, ',');
237 std::vector<string> outputs = str_util::Split(outputs_string, ',');
238 TransformParameters transform_params;
239 Status parse_status =
240 ParseTransformParameters(transforms_string, &transform_params);
241 if (!parse_status.ok()) {
242 LOG(ERROR) << "Failed to parse --transform argument, error was "
243 << parse_status.error_message();
244 return -1;
245 }
246 if (transform_params.empty()) {
247 LOG(ERROR) << "You must specify at least one transform.\n" << usage;
248 return -1;
249 }
250
251 GraphDef graph_def;
252 Status load_status = LoadTextOrBinaryGraphFile(in_graph, &graph_def);
253 if (!load_status.ok()) {
254 LOG(ERROR) << "Loading graph '" << in_graph_string << "' failed with "
255 << load_status.error_message();
256 LOG(ERROR) << usage;
257 return -1;
258 }
259
260 Status transform_result =
261 TransformGraph(inputs, outputs, transform_params, &graph_def);
262
263 if (!transform_result.ok()) {
264 LOG(ERROR) << transform_result.error_message();
265 LOG(ERROR) << usage;
266 return -1;
267 }
268
269 Status save_status;
270 if (output_as_text) {
271 save_status = WriteTextProto(Env::Default(), out_graph, graph_def);
272 } else {
273 save_status = WriteBinaryProto(Env::Default(), out_graph, graph_def);
274 }
275 if (!save_status.ok()) {
276 LOG(ERROR) << "Saving graph '" << out_graph_string << "' failed with "
277 << save_status.error_message();
278 return -1;
279 }
280
281 return 0;
282}
283
284Status ShouldIgnoreErrors(const TransformFuncParameters& transform_params,
285 bool* ignore_errors) {
286 *ignore_errors = false;
287 if (transform_params.count("ignore_errors") &&
288 (!transform_params.at("ignore_errors").empty())) {
289 const string& ignore_errors_string =
290 absl::AsciiStrToLower(transform_params.at("ignore_errors").at(0));
291 if (ignore_errors_string == "true") {
292 *ignore_errors = true;
293 } else if (ignore_errors_string == "false") {
294 *ignore_errors = false;
295 } else {
296 return errors::InvalidArgument(
297 "ignore_errors should be true or false, found ",
298 ignore_errors_string);
299 }
300 }
301 return OkStatus();
302}
303
304Status TransformGraph(const std::vector<string>& inputs,
305 const std::vector<string>& outputs,
306 const TransformParameters& transform_params,
307 GraphDef* graph_def) {
308 TransformRegistry* transform_registry = GetTransformRegistry();
309 for (const auto& transform_info : transform_params) {
310 const string& transform_name = transform_info.first;
311 if (transform_name.empty()) {
312 continue;
313 }
314 if (!transform_registry->count(transform_name)) {
315 return errors::InvalidArgument("Transform '", transform_name,
316 "' not recognized.");
317 }
318 LOG(INFO) << "Applying " << transform_name;
319 const TransformFunc& transform_func =
320 transform_registry->at(transform_name);
321 TransformFuncContext context;
322 context.input_names = inputs;
323 context.output_names = outputs;
324 context.params = transform_info.second;
325 bool ignore_errors;
326 TF_RETURN_IF_ERROR(
327 ShouldIgnoreErrors(transform_info.second, &ignore_errors));
328 GraphDef transformed_graph_def;
329 Status transform_result =
330 transform_func(*graph_def, context, &transformed_graph_def);
331 if (!transform_result.ok()) {
332 if (ignore_errors) {
333 LOG(ERROR) << transform_name << ": Ignoring error "
334 << transform_result.error_message();
335 transformed_graph_def = *graph_def;
336 } else {
337 return transform_result;
338 }
339 }
340 // Copy over the library from the original input graph.
341 *transformed_graph_def.mutable_library() = graph_def->library();
342 TF_RETURN_IF_ERROR(IsGraphValid(transformed_graph_def));
343
344 *graph_def = transformed_graph_def;
345 }
346 return OkStatus();
347}
348} // namespace graph_transforms
349} // namespace tensorflow
350