/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/grappler_item_builder.h"

#include <algorithm>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace grappler {

namespace {

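// Fills the given tensor with deterministic placeholder values (a small
// repeating pattern for float and int64 tensors, zeros for other POD types)
// so that Grappler has concrete feed data to work with.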
void InitializeTensor(DataType type, Tensor* tensor) {
  const int period = 7;
  if (type == DT_FLOAT) {
    auto flat = tensor->flat<float>();
    // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = static_cast<float>(i % period) / 10.0f;
    }
  } else if (type == DT_INT64) {
    auto flat = tensor->flat<int64_t>();
    // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
    for (int i = 0; i < flat.size(); i++) {
      flat(i) = i % period;
    }
  } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
    // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
    // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
    // Allocator will run non-trivial constructor/destructor for a Tensor with
    // one of these types, so we should not memset its buffer.
    memset(const_cast<char*>(tensor->tensor_data().data()), 0,
           tensor->tensor_data().size());
  }
}

// Applies the same graph pruning logic to the graph as Session.Run in TF.
// If the returned status is not OK, item state may be inconsistent.
Status PruneGraph(GrapplerItem* item) {
  ModelPruner pruner;
  GraphDef pruned_graph;
  Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
  TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
  item->graph = std::move(pruned_graph);
  return OkStatus();
}

// Replace any unknown dimensions in a shape with
// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
                              const TensorShapeProto& shape_pb_in,
                              TensorShapeProto* shape_pb_out,
                              TensorShape* shape_out) {
  std::vector<int32> dims;
  for (const auto& dim_proto : shape_pb_in.dim()) {
    if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
        dim_proto.size() == -1) {
      dims.push_back(cfg.placeholder_unknown_output_shape_dim);
      shape_pb_out->add_dim()->set_size(
          cfg.placeholder_unknown_output_shape_dim);
    } else {
      dims.push_back(std::max<int32>(1, dim_proto.size()));
      shape_pb_out->add_dim()->set_size(dim_proto.size());
    }
  }
  return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
}

// Replace unknown dimensions in a Placeholder's shape if
// cfg.placeholder_unknown_output_shape_dim is set or
// the Placeholder node has _output_shapes.
// Otherwise keep it intact to stay compatible with shape annotations
// (b/134092018).
Status UpdatePlaceholderShape(
    const ItemConfig& cfg,
    const std::unordered_set<string>& signature_feed_nodes,
    GrapplerItem* new_item, NodeDef* node) {
  if (node->attr().count("dtype") == 0) {
    return errors::Internal("Unknown type for placeholder ", node->name(),
                            ", skipping this input");
  }
  DataType type = node->attr().at("dtype").type();

  // TODO(andiryxu): Consider the case where both
  // cfg.placeholder_unknown_output_shape_dim >= 0 and _output_shapes are
  // present.
  if (node->attr().count("shape") == 0) {
    return errors::Internal("Unknown shape for placeholder ", node->name(),
                            ", skipping this input");
  }

  // Replace all unknown dimensions in the placeholder's tensorshape proto
  // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
  // from it. We do this because in newer protos, the input placeholder
  // shape is not empty if the shape is partially defined.
  TensorShape shape;
  TensorShapeProto shape_proto;
  Status make_shape_status = ReplaceUnknownShapeDim(
      cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
  if (!make_shape_status.ok()) {
    return errors::Internal("Invalid shape for placeholder ", node->name(),
                            ": ", make_shape_status, ", skipping this input");
  }

  // Some placeholder nodes have a mismatch between the node
  // attribute "shape" and a different node attribute "_output_shapes".
  // Specifically, a shape with shape.dims() == 0 could indicate either
  // a scalar or an unknown shape. In those cases, we check _output_shapes
  // for additional information.
  // This case is observed in the bnmt graphs. We have not observed any
  // cases with more than one _output_shapes entry, so limit the handling
  // to cases where there is exactly one _output_shapes.
  // We only do this if cfg.placeholder_unknown_output_shape_dim has
  // been set, to avoid crashing non-BNMT graphs.
  // TODO(andiryxu): Investigate if this is a bug in the BNMT graph.
  if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
      (node->attr().count("_output_shapes") == 1)) {
    const auto& output_shapes =
        node->attr().at("_output_shapes").list().shape(0);

    if (output_shapes.dim_size() != 0) {
      shape.Clear();
      shape_proto.clear_dim();

      for (const auto& dim : output_shapes.dim()) {
        auto size = dim.size();
        if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
        shape.AddDim(size);
        shape_proto.add_dim()->set_size(size);
      }
    }
  }

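  // Build a fake feed tensor with the resolved type and shape;
  // InitializeTensor fills it with deterministic placeholder values.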
  Tensor fake_input(type, shape);
  InitializeTensor(type, &fake_input);

  if (cfg.feed_nodes.empty()) {
    // No specific feed nodes were given. Assume all placeholders are fed.
    if (signature_feed_nodes.count(node->name()) == 0) {
      new_item->feed.emplace_back(node->name(), fake_input);
    }
  } else if (cfg.feed_nodes.count(node->name()) > 0) {
    // If specific feed nodes were given, only update their tensors.
    auto it = std::find_if(new_item->feed.begin(), new_item->feed.end(),
                           [&node](std::pair<string, Tensor>& f) {
                             return f.first == node->name();
                           });
    DCHECK(it != new_item->feed.end());
    it->second = fake_input;
  }

  // Set the shape of the node in the graph. This is needed for statically
  // inferring shapes, and is a no-op when dynamically inferring shapes, as
  // the Placeholder shape will match the shape passed from new_item->feed.
  // Only replace the node shape with a known shape; keep an unknown shape
  // intact (b/134092018).
  if (!shape_proto.dim().empty())
    *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;

  return OkStatus();
}

}  // namespace

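// Applies the pre-optimizations that the TF runtime would normally run before
// execution (optional function inlining and L1 graph optimizations) to
// graph_def_arg and writes the result to output_graph_def.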
Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
                             GraphDef* output_graph_def,
                             const ItemConfig& cfg) {
  // This is a temporary change that optimizes the graph in the context of a
  // single-GPU machine. Down the line, we may want to make
  // grappler_item_builder aware of the cluster type (e.g. single CPU, multiple
  // GPUs, etc.) being simulated, in order to get the correct session options
  // and environment and to perform the correct optimizations.

  // Return the input as is if no graph-modifying config is set.
  if (!cfg.apply_optimizations && !cfg.inline_functions &&
      !cfg.erase_noinline_attributes) {
    if (output_graph_def != &graph_def_arg) {
      *output_graph_def = graph_def_arg;
    }
    return OkStatus();
  }

  // Session options used for the optimizer passes below.
  SessionOptions options;

  // Make a local copy of the graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // The TF optimizer doesn't inline functions with the "_noinline"
    // attribute, so go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only the CPU device is used, so instead of calling
  // DeviceFactory::AddDevices() with a dummy session config, which would
  // conflict with user-defined options and create unwanted devices, call
  // cpu_factory->CreateDevices() to get CPU-only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  Device* cpu_device = devices[0].get();
  auto dvc_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
  }
  optimizer_opts->set_do_function_inlining(cfg.inline_functions);

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      graph_ctor_opts, std::move(graph_def), graphptr.get()));

  // Optimize the graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr,
                     tensorflow::GraphOptimizer::Options());
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the
  // optimizer. Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}

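// Builds a GrapplerItem from a MetaGraphDef: determines the feed, fetch and
// init nodes, creates fake feed tensors, substitutes asset file paths into
// Const nodes, and then runs the runtime pre-optimizations and optional
// pruning. Returns nullptr on failure.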
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
  if (id.empty()) {
    LOG(ERROR) << "id must be non-empty.";
    return nullptr;
  }
  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
  new_item->id = id;
  new_item->graph = meta_graph.graph_def();

  // Fill in feed nodes from config, if any provided.
  for (const auto& feed_node : cfg.feed_nodes) {
    const string feed_name = NodeName(feed_node);
    new_item->feed.emplace_back(feed_name, Tensor());
  }
  for (const auto& fetch_node : cfg.fetch_nodes) {
    new_item->fetch.emplace_back(NodeName(fetch_node));
  }

  // Attempt to detect the fetch node(s) if they were not set explicitly.
  if (new_item->fetch.empty() &&
      meta_graph.collection_def().count("train_op") > 0) {
    const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
    if (nodes.has_node_list()) {
      for (const auto& node : nodes.node_list().value()) {
        new_item->fetch.push_back(NodeName(node));
      }
    }
  }

  // Detect feed and fetch nodes from signature defs. Signatures may share the
  // same inputs or outputs.
  std::unordered_set<string> signature_feed_nodes;
  std::unordered_set<string> signature_fetch_nodes;
  for (const auto& name_and_signature : meta_graph.signature_def()) {
    for (const auto& name_and_input : name_and_signature.second.inputs()) {
      const TensorInfo& input = name_and_input.second;
      if (input.has_coo_sparse()) {
        // Define the shapes following the comment on CooSparse.
        // TODO(yuefengz): we probably want to use different dim values for the
        // three tensors of a SparseTensor.
        int64_t dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
        TensorShape shape_1d({dim});
        TensorShape shape_2d({dim, dim});

        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().values_tensor_name()))) {
          Tensor value_tensor(input.dtype(), shape_1d);
          InitializeTensor(input.dtype(), &value_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().indices_tensor_name()))) {
          Tensor indices_tensor(DT_INT64, shape_2d);
          InitializeTensor(input.dtype(), &indices_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().indices_tensor_name()),
              indices_tensor);
        }
        if (gtl::InsertIfNotPresent(
                &signature_feed_nodes,
                NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
          Tensor dense_shape_tensor(DT_INT64, shape_1d);
          InitializeTensor(input.dtype(), &dense_shape_tensor);
          new_item->feed.emplace_back(
              NodeName(input.coo_sparse().dense_shape_tensor_name()),
              dense_shape_tensor);
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_feed_nodes,
                                    NodeName(input.name()))) {
          TensorShape shape;
          TensorShapeProto shape_proto;
          Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
                                            &shape_proto, &shape);
          if (!s.ok()) {
            LOG(ERROR) << "Invalid shape for signature input " << input.name()
                       << ": " << s << ", skipping this input";
            return nullptr;
          }

          Tensor fake_input(input.dtype(), shape);
          InitializeTensor(input.dtype(), &fake_input);
          new_item->feed.emplace_back(NodeName(input.name()), fake_input);
        }
      }
    }
    for (const auto& name_and_output : name_and_signature.second.outputs()) {
      const TensorInfo& output = name_and_output.second;
      if (output.has_coo_sparse()) {
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().values_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().values_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().indices_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().indices_tensor_name()));
        }
        if (gtl::InsertIfNotPresent(
                &signature_fetch_nodes,
                NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
          new_item->fetch.push_back(
              NodeName(output.coo_sparse().dense_shape_tensor_name()));
        }
      } else {
        if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
                                    NodeName(output.name()))) {
          new_item->fetch.push_back(NodeName(output.name()));
        }
      }
    }
  }

  for (const auto& feed : new_item->feed) {
    if (feed.first.empty()) {
      LOG(ERROR) << "Invalid feed node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use feed node " << feed.first;
    }
  }

  for (const auto& fetch : new_item->fetch) {
    if (fetch.empty()) {
      LOG(ERROR) << "Invalid fetch node name, skipping this input";
      return nullptr;
    } else {
      VLOG(1) << "Will use fetch node " << fetch;
    }
  }

  if (new_item->fetch.empty()) {
    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
    return nullptr;
  }

  // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
  // They are difficult to handle because they may not intend to initialize all
  // the variables that are required to run the fetch nodes; we may have to run
  // the restore op first.

  // Try to find initializers from variables and tables as init ops.
  for (const string& var_collection :
       {"variables", "local_variables", "model_variables",
        "trainable_variables"}) {
    if (meta_graph.collection_def().count(var_collection) == 0) {
      continue;
    }
    const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
    for (const auto& raw_var : vars.bytes_list().value()) {
      VariableDef var;
      var.ParseFromString(raw_var);
      if (!var.initializer_name().empty()) {
        new_item->init_ops.push_back(NodeName(var.initializer_name()));
      }
    }
  }

  if (meta_graph.collection_def().count("table_initializer") > 0) {
    const CollectionDef& inits =
        meta_graph.collection_def().at("table_initializer");
    if (inits.has_node_list()) {
      for (const auto& node : inits.node_list().value()) {
        new_item->init_ops.push_back(NodeName(node));
        // Tables are initialized from files, which can take a long time. Add
        // 30 minutes to the initialization time for each table to avoid
        // timing out.
        // TODO(bsteiner): adjust the timeout based on the file size.
        new_item->expected_init_time += 30 * 60;
      }
    }
  }

  // We keep the mapping from asset nodes to asset files. These would normally
  // be used as feeds, but since asset nodes are usually constant nodes, we
  // fill the values of these constant nodes with their actual asset file
  // paths.
  std::unordered_map<string, string> asset_node_to_value;

  // Asset files may have moved to a different directory, so we assemble their
  // new paths if assets_directory_override is set. We also make sure we can
  // still access these asset files.
  if (!cfg.assets_directory_override.empty()) {
    if (meta_graph.collection_def().count("saved_model_assets") > 0) {
      const CollectionDef& collection =
          meta_graph.collection_def().at("saved_model_assets");
      const auto& any_assets = collection.any_list().value();
      if (!any_assets.empty()) {
        if (std::is_base_of<protobuf::Message, AssetFileDef>()) {
          for (const auto& any_asset : any_assets) {
            AssetFileDef asset_file_def;
            if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
                     .ok()) {
              LOG(ERROR) << "Failed to parse AssetFile.";
              continue;
            }
            string asset_filepath = io::JoinPath(cfg.assets_directory_override,
                                                 asset_file_def.filename());
            if (!FilesExist({asset_filepath}, nullptr)) {
              LOG(ERROR) << "Can't access one or more of the asset files "
                         << asset_filepath << ", skipping this input";
              return nullptr;
            }
            asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
                asset_filepath;
          }
        } else {
          LOG(ERROR) << "Can't parse AssetFileDef when using lite protos.";
          return nullptr;
        }
      }
    }
  } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
    const CollectionDef& file_paths =
        meta_graph.collection_def().at("asset_filepaths");
    std::vector<string> paths;
    for (const auto& raw_path : file_paths.bytes_list().value()) {
      paths.push_back(raw_path);
    }
    if (!FilesExist(paths, nullptr)) {
      LOG(ERROR) << "Can't access one or more of the asset files, skipping "
                    "this input";
      return nullptr;
    }
  }

  if (meta_graph.collection_def().count("queue_runners") > 0) {
    const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
    for (const auto& raw : vars.bytes_list().value()) {
      QueueRunnerDef queue_runner;
      if (!queue_runner.ParseFromString(raw)) {
        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
        return nullptr;
      }
      if (queue_runner.cancel_op_name().empty()) {
        LOG(ERROR) << "Queue without a cancel op, skipping this input";
        return nullptr;
      }
      new_item->queue_runners.push_back(queue_runner);
    }
  }

  // Add each node referenced in a collection to the list of nodes to keep.
  for (const auto& col : meta_graph.collection_def()) {
    const CollectionDef& collection = col.second;
    for (const string& node : collection.node_list().value()) {
      new_item->keep_ops.push_back(NodeName(node));
    }
  }

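  // Walk every node once: fix up Placeholder shapes, substitute asset file
  // paths into the corresponding Const nodes, and strip stale shape, placement
  // and colocation attributes as requested by the config.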
  for (auto& node : *new_item->graph.mutable_node()) {
    if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
      Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
                                        new_item.get(), &node);
      if (!s.ok()) return nullptr;
    } else if (IsConstant(node)) {
      auto it = asset_node_to_value.find(node.name());
      if (it != asset_node_to_value.end()) {
        auto iter = node.mutable_attr()->find("value");
        if (iter == node.attr().end()) {
          LOG(ERROR) << "Value attribute expected in const op for asset files";
          return nullptr;
        }
        if (!iter->second.has_tensor() ||
            iter->second.tensor().string_val_size() != 1) {
          LOG(INFO) << "Unexpected AttrValue proto: "
                    << iter->second.DebugString();
          return nullptr;
        }
        LOG(INFO) << "Using asset file " << it->second << " for node "
                  << node.name();
        *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
      }
    }

    // Erase the recorded result of any previous shape inference to start again
    // from scratch.
    node.mutable_attr()->erase("_output_shapes");

    // Delete user specified placement if requested.
    if (cfg.ignore_user_placement) {
      node.clear_device();
    }
    // Delete colocation constraints if requested.
    if (cfg.ignore_colocation) {
      auto attr = node.mutable_attr();
      auto it = attr->find("_class");
      if (it != attr->end()) {
        attr->erase(it);
      }
    }
  }

  if (meta_graph.collection_def().count("savers") > 0) {
    const CollectionDef& savers = meta_graph.collection_def().at("savers");
    for (const auto& raw : savers.bytes_list().value()) {
      SaverDef saver;
      // Skip bad savers since we don't need saves/restores to be able to run a
      // graph.
      if (!saver.ParseFromString(raw)) {
        continue;
      }
      if (saver.filename_tensor_name().empty()) {
        continue;
      }
      new_item->save_op = saver.save_tensor_name();
      new_item->restore_op = saver.restore_op_name();
      new_item->save_restore_loc_tensor = saver.filename_tensor_name();
      // Only use the first saver since it's not clear what to do if there's
      // more than one.
      break;
    }
  } else {
    const SaverDef& saver = meta_graph.saver_def();
    new_item->save_op = saver.save_tensor_name();
    new_item->restore_op = saver.restore_op_name();
    new_item->save_restore_loc_tensor = saver.filename_tensor_name();
  }

  // Instantiate all the missing attributes with their default values.
  Status attr_status = AddDefaultAttrsToGraphDef(
      &new_item->graph,
      FunctionLibraryDefinition(OpRegistry::Global(),
                                new_item->graph.library()),
      0, true);
  if (!attr_status.ok()) {
    LOG(ERROR) << "Failed to instantiate default attribute values: "
               << attr_status.error_message();
    return nullptr;
  }

  // Optimize the graph (function inlining, L1 optimizations, etc.).
  VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
          << new_item->graph.node_size();
  Status optimize_status =
      RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
  if (!optimize_status.ok()) {
    LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
    return nullptr;
  }
  VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
          << new_item->graph.node_size();

  if (cfg.prune_graph) {
    VLOG(1) << "Pruning graph...";
    auto status = PruneGraph(new_item.get());
    if (!status.ok()) {
      LOG(ERROR) << "Pruning failed: " << status.error_message();
      return nullptr;
    }
    VLOG(1) << "Number of nodes in graph after pruning: "
            << new_item->graph.node_size();
  }

  // Validate feed, fetch and init nodes.
  std::unordered_set<string> nodes;
  for (const auto& node : new_item->graph.node()) {
    nodes.insert(node.name());
  }
  for (const auto& feed : new_item->feed) {
    if (nodes.find(feed.first) == nodes.end()) {
      LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& fetch : new_item->fetch) {
    if (nodes.find(fetch) == nodes.end()) {
      LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
      return nullptr;
    }
  }
  for (const auto& init : new_item->init_ops) {
    if (nodes.find(init) == nodes.end()) {
      LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
      return nullptr;
    }
  }
  return new_item;
}

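// Reads a MetaGraphDef from meta_graph_file and delegates to
// GrapplerItemFromMetaGraphDef above; returns nullptr if the file cannot be
// read.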
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
    const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
  MetaGraphDef meta_graph;
  if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
    LOG(ERROR) << "Failed to read " << meta_graph_file;
    return nullptr;
  }
  return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
}

}  // end namespace grappler
}  // end namespace tensorflow