1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/grappler/grappler_item_builder.h"
16
17#include <type_traits>
18#include <unordered_map>
19#include <unordered_set>
20#include <vector>
21
22#include "tensorflow/core/common_runtime/device.h"
23#include "tensorflow/core/common_runtime/device_factory.h"
24#include "tensorflow/core/common_runtime/device_mgr.h"
25#include "tensorflow/core/common_runtime/function.h"
26#include "tensorflow/core/common_runtime/graph_constructor.h"
27#include "tensorflow/core/common_runtime/graph_optimizer.h"
28#include "tensorflow/core/framework/attr_value.pb.h"
29#include "tensorflow/core/framework/function.h"
30#include "tensorflow/core/framework/function.pb.h"
31#include "tensorflow/core/framework/graph_def_util.h"
32#include "tensorflow/core/framework/node_def.pb.h"
33#include "tensorflow/core/framework/op.h"
34#include "tensorflow/core/framework/tensor.pb.h"
35#include "tensorflow/core/framework/tensor_shape.pb.h"
36#include "tensorflow/core/framework/types.pb.h"
37#include "tensorflow/core/framework/variable.pb.h"
38#include "tensorflow/core/framework/versions.pb.h"
39#include "tensorflow/core/grappler/inputs/utils.h"
40#include "tensorflow/core/grappler/op_types.h"
41#include "tensorflow/core/grappler/optimizers/model_pruner.h"
42#include "tensorflow/core/grappler/utils.h"
43#include "tensorflow/core/lib/gtl/map_util.h"
44#include "tensorflow/core/lib/io/path.h"
45#include "tensorflow/core/platform/protobuf_internal.h"
46#include "tensorflow/core/protobuf/meta_graph.pb.h"
47#include "tensorflow/core/protobuf/saver.pb.h"
48#include "tensorflow/core/public/session_options.h"
49
50namespace tensorflow {
51namespace grappler {
52
53namespace {
54
55void InitializeTensor(DataType type, Tensor* tensor) {
56 const int period = 7;
57 if (type == DT_FLOAT) {
58 auto flat = tensor->flat<float>();
59 // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
60 for (int i = 0; i < flat.size(); i++) {
61 flat(i) = static_cast<float>(i % period) / 10.0f;
62 }
63 } else if (type == DT_INT64) {
64 auto flat = tensor->flat<int64_t>();
65 // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
66 for (int i = 0; i < flat.size(); i++) {
67 flat(i) = i % period;
68 }
69 } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
70 // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
71 // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
72 // Allocator will run non-trivial constructor/destructor for a Tensor with
73 // one of these types, so we should not memset its buffer.
74 memset(const_cast<char*>(tensor->tensor_data().data()), 0,
75 tensor->tensor_data().size());
76 }
77}
78
79// Applies the same graph pruning logic to the graph as Session.Run in TF.
80// If the returned status is not OK, item state may be inconsistent.
81Status PruneGraph(GrapplerItem* item) {
82 ModelPruner pruner;
83 GraphDef pruned_graph;
84 Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
85 TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
86 item->graph = std::move(pruned_graph);
87 return OkStatus();
88}
89
90// Replace any unknown dimensions in a shape with
91// cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
92Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
93 const TensorShapeProto& shape_pb_in,
94 TensorShapeProto* shape_pb_out,
95 TensorShape* shape_out) {
96 std::vector<int32> dims;
97 for (const auto& dim_proto : shape_pb_in.dim()) {
98 if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
99 dim_proto.size() == -1) {
100 dims.push_back(cfg.placeholder_unknown_output_shape_dim);
101 shape_pb_out->add_dim()->set_size(
102 cfg.placeholder_unknown_output_shape_dim);
103 } else {
104 dims.push_back(std::max<int32>(1, dim_proto.size()));
105 shape_pb_out->add_dim()->set_size(dim_proto.size());
106 }
107 }
108 return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
109}
110
111// Replace unknown dimensions in Placeholder shape if
112// cfg.placeholder_unknown_output_shape_dim is set or
113// the Placeholder node has _output_shapes.
114// Otherwise keep it intact to keep compatible with shape annotation
115// (b/134092018).
116Status UpdatePlaceholderShape(
117 const ItemConfig& cfg,
118 const std::unordered_set<string>& signature_feed_nodes,
119 GrapplerItem* new_item, NodeDef* node) {
120 if (node->attr().count("dtype") == 0) {
121 return errors::Internal("Unknown type for placeholder ", node->name(),
122 ", skipping this input");
123 }
124 DataType type = node->attr().at("dtype").type();
125
126 // TODO(andiryxu): Consider cfg.placeholder_unknown_output_shape_dim >= 0 and
127 // _output_shapes is present case.
128 if (node->attr().count("shape") == 0) {
129 return errors::Internal("Unknown shape for placeholder ", node->name(),
130 ", skipping this input");
131 }
132
133 // Replace all unknown dimensions in the placeholder's tensorshape proto
134 // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
135 // from it. We do this because in newer protos, the input placeholder
136 // shape is not empty if the shape is partially defined.
137 TensorShape shape;
138 TensorShapeProto shape_proto;
139 Status make_shape_status = ReplaceUnknownShapeDim(
140 cfg, node->attr().at("shape").shape(), &shape_proto, &shape);
141 if (!make_shape_status.ok()) {
142 return errors::Internal("Invalid shape for placeholder ", node->name(),
143 ": ", make_shape_status, ", skipping this input");
144 }
145
146 // Some placeholder nodes have a mismatch between the node
147 // attribute "shape" and a different node attribute "_output_shapes".
148 // Specifically, a shape with shape.dims() == 0 could indicate either
149 // a scalar or an unknown shape. In those cases, we check _output_shapes
150 // for additional information.
151 // This case is observed in the bnmt graphs. Have not observed any
152 // cases where there was more than 1 _output_shapes, so limit it
153 // to cases where there is only 1 _output_shapes.
154 // We only do this if cfg.placeholder_unknown_output_shape_dim has
155 // been set to avoid crashing non-BNMT graphs.
156 // TODO(andiryxu): Investigate if this is a bug in BNMT graph.
157 if ((cfg.placeholder_unknown_output_shape_dim >= 0) && (shape.dims() == 0) &&
158 (node->attr().count("_output_shapes") == 1)) {
159 const auto& output_shapes =
160 node->attr().at("_output_shapes").list().shape(0);
161
162 if (output_shapes.dim_size() != 0) {
163 shape.Clear();
164 shape_proto.clear_dim();
165
166 for (const auto& dim : output_shapes.dim()) {
167 auto size = dim.size();
168 if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
169 shape.AddDim(size);
170 shape_proto.add_dim()->set_size(size);
171 }
172 }
173 }
174
175 Tensor fake_input(type, shape);
176 InitializeTensor(type, &fake_input);
177
178 if (cfg.feed_nodes.empty()) {
179 // No specific feed nodes were given. Assume all placeholders are fed.
180 if (signature_feed_nodes.count(node->name()) == 0) {
181 new_item->feed.emplace_back(node->name(), fake_input);
182 }
183 } else if (cfg.feed_nodes.count(node->name()) > 0) {
184 // If specific feed nodes were given, only update their tensors.
185 auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
186 [&node](std::pair<string, Tensor>& f) {
187 return f.first == node->name();
188 });
189 DCHECK(it != new_item->feed.end());
190 it->second = fake_input;
191 }
192
193 // Set the shape of the node in the graph. This is needed for statically
194 // inferring shapes and is a no-op when dynamically inferring shapes as
195 // the Placeholder shape will match the shape passed from new_item->feed.
196 // Only replace node shape with known shape. For unknown shape keep it intact
197 // (b/134092018).
198 if (!shape_proto.dim().empty())
199 *(node->mutable_attr()->at("shape").mutable_shape()) = shape_proto;
200
201 return OkStatus();
202}
203
204} // namespace
205
206Status RuntimeGraphOptimizer(const GraphDef& graph_def_arg,
207 GraphDef* output_graph_def,
208 const ItemConfig& cfg) {
209 // This is a temporary change that optimizes the graph in context of a single
210 // gpu machine. Down the line, we may want to make grappler_item_builder aware
211 // of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated
212 // in order to get the correct session options and environment, and performing
213 // the correct optimizations.
214
215 // Return input as is if no graph-modifying config is set.
216 if (!cfg.apply_optimizations && !cfg.inline_functions &&
217 !cfg.erase_noinline_attributes) {
218 if (output_graph_def != &graph_def_arg) {
219 *output_graph_def = graph_def_arg;
220 }
221 return OkStatus();
222 }
223
224 // Create a session option for a single GPU device.
225 SessionOptions options;
226
227 // Make a local copy of graph def, because we need to change some things.
228 GraphDef graph_def(graph_def_arg);
229
230 if (cfg.erase_noinline_attributes) {
231 // TF optimizer doesn't inline functions with "_noinline" attribute,
232 // so let's go over the function library and erase it.
233 for (auto& func : *graph_def.mutable_library()->mutable_function()) {
234 func.mutable_attr()->erase("_noinline");
235 }
236 }
237
238 // Instantiate all variables for function library runtime creation.
239 std::vector<std::unique_ptr<Device>> devices;
240 // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
241 // with dummy session config, which will conflict with user defined options
242 // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
243 // only devices.
244 DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
245 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
246 options, "/job:localhost/replica:0/task:0", &devices));
247 Device* cpu_device = devices[0].get();
248 auto dvc_mgr = std::make_unique<StaticDeviceMgr>(std::move(devices));
249 FunctionLibraryDefinition function_library(OpRegistry::Global(),
250 graph_def.library());
251 Env* env = Env::Default();
252
253 // Optimizer options: L1 and inlining. L1 is default.
254 OptimizerOptions* optimizer_opts =
255 options.config.mutable_graph_options()->mutable_optimizer_options();
256 if (cfg.apply_optimizations) {
257 optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L1);
258 } else {
259 optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions::L0);
260 }
261 optimizer_opts->set_do_function_inlining(cfg.inline_functions);
262
263 // Create the function library runtime.
264 std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
265 new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env, &options.config,
266 graph_def.versions().producer(),
267 &function_library, *optimizer_opts));
268 FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
269
270 // Create the GraphOptimizer to optimize the graph def.
271 GraphConstructorOptions graph_ctor_opts;
272 graph_ctor_opts.allow_internal_ops = true;
273 graph_ctor_opts.expect_device_spec = false;
274 std::unique_ptr<Graph> graphptr(new Graph(function_library));
275
276 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
277 graph_ctor_opts, std::move(graph_def), graphptr.get()));
278
279 // Optimize the graph.
280 ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
281 optimizer.Optimize(flr, env, cpu_device, &graphptr,
282 tensorflow::GraphOptimizer::Options());
283 graphptr->ToGraphDef(output_graph_def);
284
285 // The default values of attributes might have been stripped by the optimizer.
286 // Add them back.
287 return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
288 0, true);
289}
290
291std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
292 const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
293 if (id.empty()) {
294 LOG(ERROR) << "id must be non-empty.";
295 return nullptr;
296 }
297 std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
298 new_item->id = id;
299 new_item->graph = meta_graph.graph_def();
300
301 // Fill in feed nodes from config, if any provided.
302 for (const auto& feed_node : cfg.feed_nodes) {
303 const string feed_name = NodeName(feed_node);
304 new_item->feed.emplace_back(feed_name, Tensor());
305 }
306 for (const auto& fetch_node : cfg.fetch_nodes) {
307 new_item->fetch.emplace_back(NodeName(fetch_node));
308 }
309
310 // Attempt to detect the fetch node(s) if they were not set explicitly.
311 if (new_item->fetch.empty() &&
312 meta_graph.collection_def().count("train_op") > 0) {
313 const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
314 if (nodes.has_node_list()) {
315 for (const auto& node : nodes.node_list().value()) {
316 new_item->fetch.push_back(NodeName(node));
317 }
318 }
319 }
320
321 // Detect feed and fetch nodes from signature defs. Signatures may share same
322 // inputs or outputs.
323 std::unordered_set<string> signature_feed_nodes;
324 std::unordered_set<string> signature_fetch_nodes;
325 for (const auto& name_and_signature : meta_graph.signature_def()) {
326 for (const auto& name_and_input : name_and_signature.second.inputs()) {
327 const TensorInfo& input = name_and_input.second;
328 if (input.has_coo_sparse()) {
329 // Define the shapes following the comment of CooSparse.
330 // TODO(yuefengz): we probably want to use different dim values for the
331 // three tensors of a SparseTensor.
332 int64_t dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
333 TensorShape shape_1d({dim});
334 TensorShape shape_2d({dim, dim});
335
336 if (gtl::InsertIfNotPresent(
337 &signature_feed_nodes,
338 NodeName(input.coo_sparse().values_tensor_name()))) {
339 Tensor value_tensor(input.dtype(), shape_1d);
340 InitializeTensor(input.dtype(), &value_tensor);
341 new_item->feed.emplace_back(
342 NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
343 }
344 if (gtl::InsertIfNotPresent(
345 &signature_feed_nodes,
346 NodeName(input.coo_sparse().indices_tensor_name()))) {
347 Tensor indices_tensor(DT_INT64, shape_2d);
348 InitializeTensor(input.dtype(), &indices_tensor);
349 new_item->feed.emplace_back(
350 NodeName(input.coo_sparse().indices_tensor_name()),
351 indices_tensor);
352 }
353 if (gtl::InsertIfNotPresent(
354 &signature_feed_nodes,
355 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
356 Tensor dense_shape_tensor(DT_INT64, shape_1d);
357 InitializeTensor(input.dtype(), &dense_shape_tensor);
358 new_item->feed.emplace_back(
359 NodeName(input.coo_sparse().dense_shape_tensor_name()),
360 dense_shape_tensor);
361 }
362 } else {
363 if (gtl::InsertIfNotPresent(&signature_feed_nodes,
364 NodeName(input.name()))) {
365 TensorShape shape;
366 TensorShapeProto shape_proto;
367 Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
368 &shape_proto, &shape);
369 if (!s.ok()) {
370 LOG(ERROR) << "Invalid shape for signature input " << input.name()
371 << ": " << s << ", skipping this input";
372 return nullptr;
373 }
374
375 Tensor fake_input(input.dtype(), shape);
376 InitializeTensor(input.dtype(), &fake_input);
377 new_item->feed.emplace_back(NodeName(input.name()), fake_input);
378 }
379 }
380 }
381 for (const auto& name_and_output : name_and_signature.second.outputs()) {
382 const TensorInfo& output = name_and_output.second;
383 if (output.has_coo_sparse()) {
384 if (gtl::InsertIfNotPresent(
385 &signature_fetch_nodes,
386 NodeName(output.coo_sparse().values_tensor_name()))) {
387 new_item->fetch.push_back(
388 NodeName(output.coo_sparse().values_tensor_name()));
389 }
390 if (gtl::InsertIfNotPresent(
391 &signature_fetch_nodes,
392 NodeName(output.coo_sparse().indices_tensor_name()))) {
393 new_item->fetch.push_back(
394 NodeName(output.coo_sparse().indices_tensor_name()));
395 }
396 if (gtl::InsertIfNotPresent(
397 &signature_fetch_nodes,
398 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
399 new_item->fetch.push_back(
400 NodeName(output.coo_sparse().dense_shape_tensor_name()));
401 }
402 } else {
403 if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
404 NodeName(output.name()))) {
405 new_item->fetch.push_back(NodeName(output.name()));
406 }
407 }
408 }
409 }
410
411 for (const auto& feed : new_item->feed) {
412 if (feed.first.empty()) {
413 LOG(ERROR) << "Invalid feed node name skipping this input";
414 return nullptr;
415 } else {
416 VLOG(1) << "Will use feed node " << feed.first;
417 }
418 }
419
420 for (const auto& fetch : new_item->fetch) {
421 if (fetch.empty()) {
422 LOG(ERROR) << "Invalid fetch node name skipping this input";
423 return nullptr;
424 } else {
425 VLOG(1) << "Will use fetch node " << fetch;
426 }
427 }
428
429 if (new_item->fetch.empty()) {
430 LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
431 return nullptr;
432 }
433
434 // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
435 // The reason why they are difficult to handle is because they may not intend
436 // to initialize all variables that are required to run fetch nodes. We may
437 // have to run restore op first.
438
439 // Try to find initializers from variables and tables as init ops.
440 for (const string& var_collection :
441 {"variables", "local_variables", "model_variables",
442 "trainable_variables"}) {
443 if (meta_graph.collection_def().count(var_collection) == 0) {
444 continue;
445 }
446 const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
447 for (const auto& raw_var : vars.bytes_list().value()) {
448 VariableDef var;
449 var.ParseFromString(raw_var);
450 if (!var.initializer_name().empty()) {
451 new_item->init_ops.push_back(NodeName(var.initializer_name()));
452 }
453 }
454 }
455
456 if (meta_graph.collection_def().count("table_initializer") > 0) {
457 const CollectionDef& inits =
458 meta_graph.collection_def().at("table_initializer");
459 if (inits.has_node_list()) {
460 for (const auto& node : inits.node_list().value()) {
461 new_item->init_ops.push_back(NodeName(node));
462 // Tables are initialized from files, which can take a long time. Add
463 // 30 minutes to the initialization time for each table to avoid
464 // timing out.
465 // TODO(bsteiner): adjust the timeout based on the file size.
466 new_item->expected_init_time += 30 * 60;
467 }
468 }
469 }
470
471 // We keep the mapping from asset node to asset files. This should have been
472 // used as feed but since asset node is usually a constant node, we will fill
473 // the values of these constant nodes with their actual asset file paths.
474 std::unordered_map<string, string> asset_node_to_value;
475
476 // Assets file may have changed their directory, we assemble their new paths
477 // if assets_directory_override is set. We also make sure we still can
478 // access these asset files.
479 if (!cfg.assets_directory_override.empty()) {
480 if (meta_graph.collection_def().count("saved_model_assets") > 0) {
481 const CollectionDef& collection =
482 meta_graph.collection_def().at("saved_model_assets");
483 const auto& any_assets = collection.any_list().value();
484 if (!any_assets.empty()) {
485 if (std::is_base_of<protobuf::Message, AssetFileDef>()) {
486 for (const auto& any_asset : any_assets) {
487 AssetFileDef asset_file_def;
488 if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
489 .ok()) {
490 LOG(ERROR) << "Failed to parse AssetFile.";
491 continue;
492 }
493 string asset_filepath = io::JoinPath(cfg.assets_directory_override,
494 asset_file_def.filename());
495 if (!FilesExist({asset_filepath}, nullptr)) {
496 LOG(ERROR) << "Can't access one or more of the asset files "
497 << asset_filepath << ", skipping this input";
498 return nullptr;
499 }
500 asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
501 asset_filepath;
502 }
503 } else {
504 LOG(ERROR) << "Can't parse AssetFileDef when using lite protos.";
505 return nullptr;
506 }
507 }
508 }
509 } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
510 const CollectionDef& file_paths =
511 meta_graph.collection_def().at("asset_filepaths");
512 std::vector<string> paths;
513 for (const auto& raw_path : file_paths.bytes_list().value()) {
514 paths.push_back(raw_path);
515 }
516 if (!FilesExist(paths, nullptr)) {
517 LOG(ERROR) << "Can't access one or more of the asset files, skipping "
518 "this input";
519 return nullptr;
520 }
521 }
522
523 if (meta_graph.collection_def().count("queue_runners") > 0) {
524 const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
525 for (const auto& raw : vars.bytes_list().value()) {
526 QueueRunnerDef queue_runner;
527 if (!queue_runner.ParseFromString(raw)) {
528 LOG(ERROR) << "Could not parse queue_runners, skipping this input";
529 return nullptr;
530 }
531 if (queue_runner.cancel_op_name().empty()) {
532 LOG(ERROR) << "Queue without a cancel op, skipping this input";
533 return nullptr;
534 }
535 new_item->queue_runners.push_back(queue_runner);
536 }
537 }
538
539 // Add each node referenced in a collection to the list of nodes to keep.
540 for (const auto& col : meta_graph.collection_def()) {
541 const CollectionDef& collection = col.second;
542 for (const string& node : collection.node_list().value()) {
543 new_item->keep_ops.push_back(NodeName(node));
544 }
545 }
546
547 for (auto& node : *new_item->graph.mutable_node()) {
548 if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
549 Status s = UpdatePlaceholderShape(cfg, signature_feed_nodes,
550 new_item.get(), &node);
551 if (!s.ok()) return nullptr;
552 } else if (IsConstant(node)) {
553 auto it = asset_node_to_value.find(node.name());
554 if (it != asset_node_to_value.end()) {
555 auto iter = node.mutable_attr()->find("value");
556 if (iter == node.attr().end()) {
557 LOG(ERROR) << "Value attribute expected in const op for asset files";
558 return nullptr;
559 }
560 if (!iter->second.has_tensor() ||
561 iter->second.tensor().string_val_size() != 1) {
562 LOG(INFO) << "Unexpected AttrValue proto: "
563 << iter->second.DebugString();
564 return nullptr;
565 }
566 LOG(INFO) << "Using asset file " << it->second << " for node "
567 << node.name();
568 *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
569 }
570 }
571
572 // Erase the recorded result of any previous shape inference to start again
573 // from scratch.
574 node.mutable_attr()->erase("_output_shapes");
575
576 // Delete user specified placement if requested.
577 if (cfg.ignore_user_placement) {
578 node.clear_device();
579 }
580 // Delete colocation constraints if requested.
581 if (cfg.ignore_colocation) {
582 auto attr = node.mutable_attr();
583 auto it = attr->find("_class");
584 if (it != attr->end()) {
585 attr->erase(it);
586 }
587 }
588 }
589
590 if (meta_graph.collection_def().count("savers") > 0) {
591 const CollectionDef& savers = meta_graph.collection_def().at("savers");
592 for (const auto& raw : savers.bytes_list().value()) {
593 SaverDef saver;
594 // Skip bad savers since we don't need saves/restores to be able to run a
595 // graph.
596 if (!saver.ParseFromString(raw)) {
597 continue;
598 }
599 if (saver.filename_tensor_name().empty()) {
600 continue;
601 }
602 new_item->save_op = saver.save_tensor_name();
603 new_item->restore_op = saver.restore_op_name();
604 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
605 // Only use the first saver since it's not clear what to do if there's
606 // more than one.
607 break;
608 }
609 } else {
610 const SaverDef& saver = meta_graph.saver_def();
611 new_item->save_op = saver.save_tensor_name();
612 new_item->restore_op = saver.restore_op_name();
613 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
614 }
615
616 // Instantiate all the missing attributes with their default values.
617 Status attr_status = AddDefaultAttrsToGraphDef(
618 &new_item->graph,
619 FunctionLibraryDefinition(OpRegistry::Global(),
620 new_item->graph.library()),
621 0, true);
622 if (!attr_status.ok()) {
623 LOG(ERROR) << "Failed to instantiate default attribute values: "
624 << attr_status.error_message();
625 return nullptr;
626 }
627
628 // Optimize the graph (function inlining, l1 optimizations, etc).
629 VLOG(1) << "Number of nodes in graph before RuntimeGraphOptimizer: "
630 << new_item->graph.node_size();
631 Status optimize_status =
632 RuntimeGraphOptimizer(new_item->graph, &new_item->graph, cfg);
633 if (!optimize_status.ok()) {
634 LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
635 return nullptr;
636 }
637 VLOG(1) << "Number of nodes in graph after RuntimeGraphOptimizer: "
638 << new_item->graph.node_size();
639
640 if (cfg.prune_graph) {
641 VLOG(1) << "Pruning graph...";
642 auto status = PruneGraph(new_item.get());
643 if (!status.ok()) {
644 LOG(ERROR) << "Pruning failed: " << status.error_message();
645 return nullptr;
646 }
647 VLOG(1) << "Number of nodes in graph after pruning: "
648 << new_item->graph.node_size();
649 }
650
651 // Validate feed, fetch and init nodes
652 std::unordered_set<string> nodes;
653 for (const auto& node : new_item->graph.node()) {
654 nodes.insert(node.name());
655 }
656 for (const auto& feed : new_item->feed) {
657 if (nodes.find(feed.first) == nodes.end()) {
658 LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
659 return nullptr;
660 }
661 }
662 for (const auto& fetch : new_item->fetch) {
663 if (nodes.find(fetch) == nodes.end()) {
664 LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
665 return nullptr;
666 }
667 }
668 for (const auto& init : new_item->init_ops) {
669 if (nodes.find(init) == nodes.end()) {
670 LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
671 return nullptr;
672 }
673 }
674 return new_item;
675}
676
677std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
678 const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
679 MetaGraphDef meta_graph;
680 if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
681 LOG(ERROR) << "Failed to read " << meta_graph_file;
682 return nullptr;
683 }
684 return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
685}
686
687} // end namespace grappler
688} // end namespace tensorflow
689