1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
17 | #include "tensorflow/core/graph/algorithm.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/optimizer_cse.h" |
20 | |
21 | namespace tensorflow { |
22 | namespace { |
23 | |
24 | // Replaces occurrences of parallel_concat with the implementation based on |
25 | // unsafe ops. Sets removed_any to true if any parallel_concats were removed; |
26 | // leaves it untouched otherwise. |
27 | class ParallelConcatRemovePass : public GraphOptimizationPass { |
28 | public: |
29 | Status Run(const GraphOptimizationPassOptions& options) override { |
30 | if (options.graph == nullptr) { |
31 | // TODO(apassos) returning OK feels weird here as we can't do anything |
32 | // without a graph, but some tests require this. |
33 | return OkStatus(); |
34 | } |
35 | Graph* g = options.graph->get(); |
36 | if (g == nullptr) { |
37 | return errors::Internal( |
38 | "Parallel concat removal should happen before partitioning and a " |
39 | "graph should be available." ); |
40 | } |
41 | gtl::InlinedVector<Node*, 2> matches; |
42 | for (Node* n : g->op_nodes()) { |
43 | if (n->type_string() == "ParallelConcat" ) { |
44 | matches.push_back(n); |
45 | } |
46 | } |
47 | for (Node* n : matches) { |
48 | AttrSlice n_attrs = n->attrs(); |
49 | auto base_make_node = [n, &n_attrs](const string& op, |
50 | const string& name) { |
51 | NodeDebugInfo debug_info(*n); |
52 | NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info); |
53 | node_builder.Device(n->requested_device()); |
54 | const string& colo = GetNodeAttrString(n_attrs, "_class" ); |
55 | if (!colo.empty()) { |
56 | node_builder.Attr("_class" , colo); |
57 | } |
58 | return node_builder; |
59 | }; |
60 | auto make_node = [n, g, &base_make_node](string op) { |
61 | return base_make_node( |
62 | op, g->NewName(strings::StrCat(n->name(), "/Internal" ))); |
63 | }; |
64 | DataType dtype; |
65 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T" , &dtype)); |
66 | TensorShapeProto shape; |
67 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape" , &shape)); |
68 | |
69 | // Add the start node |
70 | Node* start; |
71 | TF_RETURN_IF_ERROR(make_node("_ParallelConcatStart" ) |
72 | .Attr("shape" , shape) |
73 | .Attr("dtype" , dtype) |
74 | .Finalize(g, &start)); |
75 | |
76 | // Add all the inplace_updates. |
77 | std::vector<Node*> control_nodes; |
78 | for (const Edge* input_edge : n->in_edges()) { |
79 | if (input_edge->IsControlEdge()) { |
80 | g->AddControlEdge(input_edge->src(), start); |
81 | continue; |
82 | } |
83 | |
84 | Node* update; |
85 | TF_RETURN_IF_ERROR( |
86 | make_node("_ParallelConcatUpdate" ) |
87 | .Attr("loc" , input_edge->dst_input()) |
88 | .Input(start) |
89 | .Input(input_edge->src(), input_edge->src_output()) |
90 | .Finalize(g, &update)); |
91 | control_nodes.push_back(update); |
92 | } |
93 | |
94 | // Add the final identity. |
95 | NodeBuilder identity_def = base_make_node("Identity" , n->name()); |
96 | identity_def.Input(start, 0); |
97 | for (Node* s : control_nodes) { |
98 | identity_def.ControlInput(s); |
99 | } |
100 | Node* identity_node; |
101 | TF_RETURN_IF_ERROR(identity_def.Finalize(g, &identity_node)); |
102 | |
103 | // Remove the node and redirect edges. |
104 | for (auto* e : n->out_edges()) { |
105 | if (e->IsControlEdge()) { |
106 | g->AddControlEdge(identity_node, e->dst()); |
107 | } else { |
108 | g->AddEdge(identity_node, 0, e->dst(), e->dst_input()); |
109 | } |
110 | } |
111 | g->RemoveNode(n); |
112 | } |
113 | return OkStatus(); |
114 | } |
115 | }; |
116 | REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10, |
117 | ParallelConcatRemovePass); |
118 | |
119 | } // namespace |
120 | } // namespace tensorflow |
121 | |