1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
17 | #include "tensorflow/core/graph/control_flow.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | |
20 | namespace tensorflow { |
21 | namespace { |
22 | |
static constexpr const char* const kParallelIterationsAttrName =
    "parallel_iterations";
25 | |
26 | Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) { |
27 | Tensor tensor(dtype, TensorShape(shape)); |
28 | |
29 | // Conveniently, all numeric data types have 0x0 == zero. Otherwise we would |
30 | // need a giant switch statement here. |
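  // (IEEE-754 floats, two's-complement integers, and the complex types all
  // encode zero as the all-zeros bit pattern.)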
31 | memset(const_cast<char*>(tensor.tensor_data().data()), 0, |
32 | tensor.tensor_data().size()); |
33 | |
34 | return tensor; |
35 | } |
36 | |
37 | // Replaces occurrences of the "AccumulateNV2" stub operator with a graph of |
38 | // lower-level ops. The graph is equivalent (modulo certain corner cases) |
39 | // to the semantics of the original accumulate_n() Python op in math_ops.py. |
40 | // Implementing the op with a rewrite allows this new variant of accumulate_n |
41 | // to be differentiable. |
42 | // |
43 | // The binary code that generates AccumulateNV2 stub ops is located in a |
44 | // dynamic library built out of tensorflow/contrib/framework. Ideally, this |
45 | // class would also be in contrib, but calls to REGISTER_OPTIMIZATION() from |
46 | // third-party libraries aren't currently supported. |
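//
// As a rough sketch (with illustrative names), a stub node
//   y = AccumulateNV2(x0, x1, x2)
// outside a while loop is rewritten into approximately:
//   tmp = TemporaryVariable(shape=[0], dtype=T)
//   acc = Assign(tmp, Fill(Shape(x0), 0), validate_shape=false)
//   a_i = AssignAdd(acc, x_i)            // one per input, in any order
//   y   = DestroyTemporaryVariable(acc)  // control-dependent on every a_i
// Inside a while loop with parallel_iterations > 1, the stub is instead
// replaced by a plain AddN (see RewriteIntoAddN below).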
47 | class AccumulateNV2RemovePass : public GraphOptimizationPass { |
48 | public: |
49 | Status Run(const GraphOptimizationPassOptions& options) override { |
50 | // TODO([email protected]): Substantial shared code with |
51 | // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes |
52 | // a third similar rewrite. |
53 | if (options.graph == nullptr) { |
54 | // TODO(apassos) returning OK feels weird here as we can't do anything |
55 | // without a graph, but some tests require this. |
56 | return OkStatus(); |
57 | } |
58 | |
59 | Graph* g = options.graph->get(); |
60 | if (g == nullptr) { |
61 | return errors::Internal( |
62 | "AccumulateNV2 removal should happen before partitioning and a " |
63 | "graph should be available." ); |
64 | } |
65 | |
    // Build up a todo list of ops to replace, *then* modify the graph.
67 | gtl::InlinedVector<Node*, 2> matches; |
68 | for (Node* n : g->op_nodes()) { |
      if (n->type_string() == "AccumulateNV2") {
70 | matches.push_back(n); |
71 | } |
72 | } |
73 | if (matches.empty()) return OkStatus(); |
74 | |
75 | std::vector<ControlFlowInfo> control_flow_info; |
76 | TF_RETURN_IF_ERROR(BuildControlFlowInfo(g, &control_flow_info)); |
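    // control_flow_info now maps each node id to its innermost enclosing
    // control-flow frame; nodes outside any while loop report the source
    // frame (checked against Graph::kSourceId below).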
77 | |
78 | for (Node* n : matches) { |
      // Temporary variables do not work inside while loops with parallel
      // iterations: concurrent iterations would share, and race on, the same
      // named temporary variable. If the `AccumulateNV2` node executes inside
      // such a loop, we rewrite it into an `AddN` node instead.
82 | const Node* frame = control_flow_info[n->id()].frame; |
83 | bool is_in_while_loop = frame->id() != Graph::kSourceId; |
84 | |
85 | // With `parallel_iterations == 1` it's safe to use TemporaryVariable. |
86 | if (is_in_while_loop) { |
87 | int parallel_iterations; |
88 | bool found = TryGetNodeAttr(frame->attrs(), kParallelIterationsAttrName, |
89 | ¶llel_iterations); |
90 | if (found && parallel_iterations == 1) { |
91 | is_in_while_loop = false; |
92 | } |
93 | } |
94 | |
95 | if (is_in_while_loop) { |
96 | TF_RETURN_IF_ERROR(RewriteIntoAddN(n, g)); |
97 | } else { |
98 | TF_RETURN_IF_ERROR(RewriteIntoTempVariable(n, g)); |
99 | } |
100 | } |
101 | return OkStatus(); |
102 | } |
103 | |
104 | Status RewriteIntoTempVariable(Node* n, Graph* g) { |
105 | VLOG(3) << "Rewrite AccumulateNV2 into TemporaryVariable and Assign: " |
106 | << SummarizeNode(*n); |
107 | |
108 | AttrSlice n_attrs = n->attrs(); |
109 | auto base_make_node = [n, &n_attrs](const string& op, const string& name) { |
110 | NodeDebugInfo debug_info(*n); |
111 | NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info); |
112 | |
      // The pieces of AccumulateNV2 should all be on the same device.
114 | node_builder.Device(n->requested_device()); |
115 | const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName); |
116 | if (!colo.empty()) { |
117 | node_builder.Attr(kColocationAttrName, colo); |
118 | } |
119 | return node_builder; |
120 | }; |
121 | auto make_node = [n, g, &base_make_node](string op) { |
122 | return base_make_node( |
          op, g->NewName(strings::StrCat(n->name(), "/Internal")));
124 | }; |
125 | |
126 | DataType dtype; |
127 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T" , &dtype)); |
128 | TensorShapeProto shape; |
129 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape" , &shape)); |
130 | |
131 | std::vector<const Edge*> data_edges, control_edges; |
132 | for (const Edge* input_edge : n->in_edges()) { |
133 | if (input_edge->IsControlEdge()) { |
134 | control_edges.push_back(input_edge); |
135 | } else { |
136 | data_edges.push_back(input_edge); |
137 | } |
138 | } |
139 | |
140 | // Create the following ops to replace the AccumulateNV2 placeholder: |
141 | Node* create_accumulator = nullptr; // TemporaryVariable op |
142 | Node* initial_val = nullptr; // Const op |
143 | Node* initialize_accumulator = nullptr; // Assign op |
144 | std::vector<Node*> add_values_to_accumulator; // AssignAdd ops |
145 | Node* clean_up_accumulator = nullptr; // DestroyTemporaryVariable |
146 | |
147 | const string accumulator_name = |
        strings::StrCat(n->name(), "/Internal/Accumulator");
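    // The accumulator variable is created with a dummy shape of [0]; because
    // the Assign below sets validate_shape=false, the variable takes on the
    // true shape of the inputs on the first write.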
149 | TensorShapeProto variable_shape; |
150 | variable_shape.add_dim()->set_size(0); |
151 | TF_RETURN_IF_ERROR(make_node("TemporaryVariable" ) |
152 | .Attr("shape" , variable_shape) |
153 | .Attr("dtype" , dtype) |
154 | .Attr("var_name" , accumulator_name) |
155 | .Finalize(g, &create_accumulator)); |
156 | PartialTensorShape partial_shape(shape); |
    // Add a Fill op that produces a zero tensor with the same shape as the
    // first input.
159 | Node* shape_node; |
160 | TF_RETURN_IF_ERROR( |
161 | make_node("Shape" ) |
162 | .Input(data_edges[0]->src(), data_edges[0]->src_output()) |
163 | .Finalize(g, &shape_node)); |
164 | Node* zero; |
165 | TF_RETURN_IF_ERROR(make_node("Const" ) |
166 | .Attr("value" , make_zeros(dtype, TensorShapeProto())) |
167 | .Attr("dtype" , dtype) |
168 | .Finalize(g, &zero)); |
169 | TF_RETURN_IF_ERROR(make_node("Fill" ) |
170 | .Input(shape_node) |
171 | .Input(zero) |
172 | .Finalize(g, &initial_val)); |
173 | TF_RETURN_IF_ERROR(make_node("Assign" ) |
174 | .Attr("T" , dtype) |
175 | .Input(create_accumulator) // ref: Ref(T) |
176 | .Input(initial_val) // value: T |
177 | .Attr("validate_shape" , false) |
178 | .Finalize(g, &initialize_accumulator)); |
    for (int i = 0; i < data_edges.size(); ++i) {
      Node* assign_add;
      TF_RETURN_IF_ERROR(make_node("AssignAdd")
                             .Attr("T", dtype)
                             .Attr("use_locking", true)
                             .Input(initialize_accumulator)  // ref: Ref(T)
                             .Input(data_edges[i]->src(),
                                    data_edges[i]->src_output())  // value: T
                             .Finalize(g, &assign_add));

      add_values_to_accumulator.push_back(assign_add);
    }
191 | |
    // Note that we use the original placeholder op's name here.
    TF_RETURN_IF_ERROR(base_make_node("DestroyTemporaryVariable", n->name())
                           .Attr("T", dtype)
                           .Attr("var_name", accumulator_name)
196 | .Input(initialize_accumulator) |
197 | .Finalize(g, &clean_up_accumulator)); |
198 | |
199 | // Add edges to the graph to ensure that operations occur in the right |
200 | // order: |
201 | // 1. Do anything that had a control edge to the AccumulateNV2 placeholder |
202 | // 2. Initialize accumulator |
203 | // 3. Add input values to accumulator (already handled by data edges |
204 | // added above) |
205 | // 4. Reclaim the buffer that held the accumulator |
206 | // 5. Do anything that depended on the AccumulateNV2 placeholder |
207 | for (const Edge* control_edge : control_edges) { |
208 | g->AddControlEdge(control_edge->src(), initialize_accumulator); |
209 | } |
210 | |
211 | for (Node* assign_add : add_values_to_accumulator) { |
212 | g->AddControlEdge(assign_add, clean_up_accumulator); |
213 | } |
214 | |
215 | for (const Edge* out_edge : n->out_edges()) { |
216 | if (out_edge->IsControlEdge()) { |
217 | g->AddControlEdge(clean_up_accumulator, out_edge->dst()); |
218 | } else { |
219 | g->AddEdge(clean_up_accumulator, 0, out_edge->dst(), |
220 | out_edge->dst_input()); |
221 | } |
222 | } |
223 | |
224 | // Remove the original AccumulateNV2 placeholder op. |
225 | // This removal modifies the op and must happen after we have finished |
226 | // using its incoming/outgoing edge sets. |
227 | g->RemoveNode(n); |
228 | |
229 | return OkStatus(); |
230 | } |
231 | |
232 | Status RewriteIntoAddN(Node* n, Graph* g) { |
233 | VLOG(3) << "Rewrite AccumulateNV2 into AddN: " << SummarizeNode(*n); |
234 | |
235 | AttrSlice n_attrs = n->attrs(); |
236 | DataType dtype; |
237 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T" , &dtype)); |
238 | int num_inputs; |
239 | TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "N" , &num_inputs)); |
240 | |
241 | Node* add_n_node = nullptr; |
242 | |
243 | std::vector<NodeBuilder::NodeOut> data_inputs; |
244 | std::vector<Node*> control_inputs; |
245 | data_inputs.reserve(n->num_inputs()); |
246 | control_inputs.reserve(n->in_edges().size() - n->num_inputs()); |
247 | for (const Edge* in_edge : n->in_edges()) { |
248 | if (in_edge->IsControlEdge()) { |
249 | control_inputs.push_back(in_edge->src()); |
250 | } else { |
251 | data_inputs.emplace_back(in_edge->src(), in_edge->src_output()); |
252 | } |
253 | } |
254 | |
    // Rewrite the `AccumulateNV2` node into an `AddN` node.
256 | NodeDebugInfo debug_info(*n); |
257 | NodeBuilder builder = |
258 | NodeBuilder(n->name(), "AddN" , OpRegistry::Global(), &debug_info) |
259 | .Device(n->requested_device()) |
260 | .Attr("N" , num_inputs) |
261 | .Attr("T" , dtype) |
262 | .Input(data_inputs) |
263 | .ControlInputs(control_inputs); |
264 | const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName); |
265 | if (!colo.empty()) { |
266 | builder.Attr(kColocationAttrName, colo); |
267 | } |
268 | TF_RETURN_IF_ERROR(builder.Finalize(g, &add_n_node)); |
269 | |
270 | // Forward all consumers to the new node. |
271 | for (const Edge* out_edge : n->out_edges()) { |
272 | if (out_edge->IsControlEdge()) { |
273 | g->AddControlEdge(add_n_node, out_edge->dst()); |
274 | } else { |
275 | g->AddEdge(add_n_node, 0, out_edge->dst(), out_edge->dst_input()); |
276 | } |
277 | } |
278 | |
279 | // Remove the original AccumulateNV2 placeholder op. |
280 | // This removal modifies the op and must happen after we have finished |
281 | // using its incoming/outgoing edge sets. |
282 | g->RemoveNode(n); |
283 | |
284 | return OkStatus(); |
285 | } |
286 | }; |
287 | REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10, |
288 | AccumulateNV2RemovePass); |
289 | |
290 | } // namespace |
291 | } // namespace tensorflow |
292 | |