1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/optimization_registry.h"
17#include "tensorflow/core/graph/algorithm.h"
18#include "tensorflow/core/graph/node_builder.h"
19#include "tensorflow/core/graph/optimizer_cse.h"
20
21namespace tensorflow {
22namespace {
23
// Replaces each occurrence of ParallelConcat with an equivalent subgraph
// built from internal in-place ops (_ParallelConcatStart /
// _ParallelConcatUpdate) plus a final Identity carrying the original name.
27class ParallelConcatRemovePass : public GraphOptimizationPass {
28 public:
29 Status Run(const GraphOptimizationPassOptions& options) override {
30 if (options.graph == nullptr) {
31 // TODO(apassos) returning OK feels weird here as we can't do anything
32 // without a graph, but some tests require this.
33 return OkStatus();
34 }
35 Graph* g = options.graph->get();
36 if (g == nullptr) {
37 return errors::Internal(
38 "Parallel concat removal should happen before partitioning and a "
39 "graph should be available.");
40 }
41 gtl::InlinedVector<Node*, 2> matches;
42 for (Node* n : g->op_nodes()) {
43 if (n->type_string() == "ParallelConcat") {
44 matches.push_back(n);
45 }
46 }
47 for (Node* n : matches) {
48 AttrSlice n_attrs = n->attrs();
49 auto base_make_node = [n, &n_attrs](const string& op,
50 const string& name) {
51 NodeDebugInfo debug_info(*n);
52 NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);
53 node_builder.Device(n->requested_device());
54 const string& colo = GetNodeAttrString(n_attrs, "_class");
55 if (!colo.empty()) {
56 node_builder.Attr("_class", colo);
57 }
58 return node_builder;
59 };
60 auto make_node = [n, g, &base_make_node](string op) {
61 return base_make_node(
62 op, g->NewName(strings::StrCat(n->name(), "/Internal")));
63 };
64 DataType dtype;
65 TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
66 TensorShapeProto shape;
67 TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
68
69 // Add the start node
70 Node* start;
71 TF_RETURN_IF_ERROR(make_node("_ParallelConcatStart")
72 .Attr("shape", shape)
73 .Attr("dtype", dtype)
74 .Finalize(g, &start));
75
76 // Add all the inplace_updates.
77 std::vector<Node*> control_nodes;
78 for (const Edge* input_edge : n->in_edges()) {
79 if (input_edge->IsControlEdge()) {
80 g->AddControlEdge(input_edge->src(), start);
81 continue;
82 }
83
84 Node* update;
85 TF_RETURN_IF_ERROR(
86 make_node("_ParallelConcatUpdate")
87 .Attr("loc", input_edge->dst_input())
88 .Input(start)
89 .Input(input_edge->src(), input_edge->src_output())
90 .Finalize(g, &update));
91 control_nodes.push_back(update);
92 }
93
94 // Add the final identity.
95 NodeBuilder identity_def = base_make_node("Identity", n->name());
96 identity_def.Input(start, 0);
97 for (Node* s : control_nodes) {
98 identity_def.ControlInput(s);
99 }
100 Node* identity_node;
101 TF_RETURN_IF_ERROR(identity_def.Finalize(g, &identity_node));
102
103 // Remove the node and redirect edges.
104 for (auto* e : n->out_edges()) {
105 if (e->IsControlEdge()) {
106 g->AddControlEdge(identity_node, e->dst());
107 } else {
108 g->AddEdge(identity_node, 0, e->dst(), e->dst_input());
109 }
110 }
111 g->RemoveNode(n);
112 }
113 return OkStatus();
114 }
115};
// Register at the PRE_PLACEMENT phase (priority 10): the pass itself errors
// out if no graph is available, stating it must run before partitioning.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
                      ParallelConcatRemovePass);
118
119} // namespace
120} // namespace tensorflow
121