/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"

namespace tensorflow {
namespace {

static constexpr const char* const kParallelIterationsAttrName =
    "parallel_iterations";

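// Returns a tensor of the given numeric dtype and shape with every element
// set to zero.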
Tensor make_zeros(const DataType& dtype, const TensorShapeProto& shape) {
  Tensor tensor(dtype, TensorShape(shape));

  // Conveniently, all numeric data types have 0x0 == zero. Otherwise we would
  // need a giant switch statement here.
  memset(const_cast<char*>(tensor.tensor_data().data()), 0,
         tensor.tensor_data().size());

  return tensor;
}

// Replaces occurrences of the "AccumulateNV2" stub operator with a graph of
// lower-level ops. The graph is equivalent (modulo certain corner cases)
// to the semantics of the original accumulate_n() Python op in math_ops.py.
// Implementing the op with a rewrite allows this new variant of accumulate_n
// to be differentiable.
//
// The binary code that generates AccumulateNV2 stub ops is located in a
// dynamic library built out of tensorflow/contrib/framework. Ideally, this
// class would also be in contrib, but calls to REGISTER_OPTIMIZATION() from
// third-party libraries aren't currently supported.
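//
// As a rough sketch (RewriteIntoTempVariable below has the details), a node
// such as
//
//   out = AccumulateNV2(x, y, z)
//
// becomes, when it is not inside a multi-iteration while loop:
//
//   tmp = TemporaryVariable(shape=[0], dtype=T)
//   acc = Assign(tmp, Fill(Shape(x), 0), validate_shape=false)
//   AssignAdd(acc, x), AssignAdd(acc, y), AssignAdd(acc, z)
//   out = DestroyTemporaryVariable(acc)  // after all the AssignAdds
//
// Inside such a loop it is rewritten into a plain AddN instead.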
class AccumulateNV2RemovePass : public GraphOptimizationPass {
 public:
  Status Run(const GraphOptimizationPassOptions& options) override {
    // TODO([email protected]): Substantial shared code with
    // ParallelConcatRemovePass::Run(). Consider refactoring if someone makes
    // a third similar rewrite.
    if (options.graph == nullptr) {
      // TODO(apassos) returning OK feels weird here as we can't do anything
      // without a graph, but some tests require this.
      return OkStatus();
    }

    Graph* g = options.graph->get();
    if (g == nullptr) {
      return errors::Internal(
          "AccumulateNV2 removal should happen before partitioning and a "
          "graph should be available.");
    }

    // Build up a todo list of ops to replace, *then* modify the graph.
    gtl::InlinedVector<Node*, 2> matches;
    for (Node* n : g->op_nodes()) {
      if (n->type_string() == "AccumulateNV2") {
        matches.push_back(n);
      }
    }
    if (matches.empty()) return OkStatus();

    std::vector<ControlFlowInfo> control_flow_info;
    TF_RETURN_IF_ERROR(BuildControlFlowInfo(g, &control_flow_info));

    for (Node* n : matches) {
      // Temporary variables do not work inside while loops with parallel
      // iterations. If the `AccumulateNV2` node is executed inside a loop,
      // we rewrite it into an `AddN` node.
      const Node* frame = control_flow_info[n->id()].frame;
      bool is_in_while_loop = frame->id() != Graph::kSourceId;

      // With `parallel_iterations == 1` it is safe to use TemporaryVariable
      // even inside a while loop.
      if (is_in_while_loop) {
        int parallel_iterations;
        bool found = TryGetNodeAttr(frame->attrs(), kParallelIterationsAttrName,
                                    &parallel_iterations);
        if (found && parallel_iterations == 1) {
          is_in_while_loop = false;
        }
      }

      if (is_in_while_loop) {
        TF_RETURN_IF_ERROR(RewriteIntoAddN(n, g));
      } else {
        TF_RETURN_IF_ERROR(RewriteIntoTempVariable(n, g));
      }
    }
    return OkStatus();
  }

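  // Replaces a single AccumulateNV2 node with a chain of TemporaryVariable,
  // Assign, and AssignAdd ops, finished by a DestroyTemporaryVariable that
  // takes over the original node's name, so downstream consumers are rewired
  // to the accumulated result.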
  Status RewriteIntoTempVariable(Node* n, Graph* g) {
    VLOG(3) << "Rewrite AccumulateNV2 into TemporaryVariable and Assign: "
            << SummarizeNode(*n);

    AttrSlice n_attrs = n->attrs();
    auto base_make_node = [n, &n_attrs](const string& op, const string& name) {
      NodeDebugInfo debug_info(*n);
      NodeBuilder node_builder(name, op, OpRegistry::Global(), &debug_info);

      // The pieces of the rewritten AccumulateNV2 should all be placed on the
      // same device as the original node.
      node_builder.Device(n->requested_device());
      const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
      if (!colo.empty()) {
        node_builder.Attr(kColocationAttrName, colo);
      }
      return node_builder;
    };
    auto make_node = [n, g, &base_make_node](string op) {
      return base_make_node(
          op, g->NewName(strings::StrCat(n->name(), "/Internal")));
    };

    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    TensorShapeProto shape;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));

    std::vector<const Edge*> data_edges, control_edges;
    for (const Edge* input_edge : n->in_edges()) {
      if (input_edge->IsControlEdge()) {
        control_edges.push_back(input_edge);
      } else {
        data_edges.push_back(input_edge);
      }
    }
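
    // The data edges become the inputs of the AssignAdd nodes created below,
    // while the control edges are reattached to the Assign that initializes
    // the accumulator.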

    // Create the following ops to replace the AccumulateNV2 placeholder:
    Node* create_accumulator = nullptr;            // TemporaryVariable op
    Node* initial_val = nullptr;                   // Const op
    Node* initialize_accumulator = nullptr;        // Assign op
    std::vector<Node*> add_values_to_accumulator;  // AssignAdd ops
    Node* clean_up_accumulator = nullptr;          // DestroyTemporaryVariable

    const string accumulator_name =
        strings::StrCat(n->name(), "/Internal/Accumulator");
    // The variable is created with a dummy one-dimensional shape of size 0;
    // the Assign below uses validate_shape=false, so the accumulator takes on
    // the shape of the zero tensor it is initialized with.
    TensorShapeProto variable_shape;
    variable_shape.add_dim()->set_size(0);
    TF_RETURN_IF_ERROR(make_node("TemporaryVariable")
                           .Attr("shape", variable_shape)
                           .Attr("dtype", dtype)
                           .Attr("var_name", accumulator_name)
                           .Finalize(g, &create_accumulator));
    PartialTensorShape partial_shape(shape);
    // Use a Fill operation to create a zero tensor with the shape of the
    // first input.
    Node* shape_node;
    TF_RETURN_IF_ERROR(
        make_node("Shape")
            .Input(data_edges[0]->src(), data_edges[0]->src_output())
            .Finalize(g, &shape_node));
    Node* zero;
    TF_RETURN_IF_ERROR(make_node("Const")
                           .Attr("value", make_zeros(dtype, TensorShapeProto()))
                           .Attr("dtype", dtype)
                           .Finalize(g, &zero));
    TF_RETURN_IF_ERROR(make_node("Fill")
                           .Input(shape_node)
                           .Input(zero)
                           .Finalize(g, &initial_val));
    TF_RETURN_IF_ERROR(make_node("Assign")
                           .Attr("T", dtype)
                           .Input(create_accumulator)  // ref: Ref(T)
                           .Input(initial_val)         // value: T
                           .Attr("validate_shape", false)
                           .Finalize(g, &initialize_accumulator));
    for (int i = 0; i < data_edges.size(); ++i) {
      Node* assign_add;
      TF_RETURN_IF_ERROR(make_node("AssignAdd")
                             .Attr("T", dtype)
                             .Attr("use_locking", true)
                             .Input(initialize_accumulator)  // ref: Ref(T)
                             .Input(data_edges[i]->src(),
                                    data_edges[i]->src_output())  // value: T
                             .Finalize(g, &assign_add));

      add_values_to_accumulator.push_back(assign_add);
    }
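
    // The AssignAdd updates carry no ordering constraints among themselves;
    // use_locking=true makes the concurrent updates to the accumulator safe.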

    // Note that we use the original placeholder op's name here.
    TF_RETURN_IF_ERROR(base_make_node("DestroyTemporaryVariable", n->name())
                           .Attr("T", dtype)
                           .Attr("var_name", accumulator_name)
                           .Input(initialize_accumulator)
                           .Finalize(g, &clean_up_accumulator));

    // Add edges to the graph to ensure that operations occur in the right
    // order:
    // 1. Do anything that had a control edge to the AccumulateNV2 placeholder
    // 2. Initialize accumulator
    // 3. Add input values to accumulator (already handled by data edges
    //    added above)
    // 4. Reclaim the buffer that held the accumulator
    // 5. Do anything that depended on the AccumulateNV2 placeholder
    for (const Edge* control_edge : control_edges) {
      g->AddControlEdge(control_edge->src(), initialize_accumulator);
    }

    for (Node* assign_add : add_values_to_accumulator) {
      g->AddControlEdge(assign_add, clean_up_accumulator);
    }

    for (const Edge* out_edge : n->out_edges()) {
      if (out_edge->IsControlEdge()) {
        g->AddControlEdge(clean_up_accumulator, out_edge->dst());
      } else {
        g->AddEdge(clean_up_accumulator, 0, out_edge->dst(),
                   out_edge->dst_input());
      }
    }

    // Remove the original AccumulateNV2 placeholder op.
    // This removal modifies the op and must happen after we have finished
    // using its incoming/outgoing edge sets.
    g->RemoveNode(n);

    return OkStatus();
  }

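  // Replaces a single AccumulateNV2 node with an equivalent AddN node. AddN
  // waits for all of its inputs before running, so this form is safe inside
  // while loops with parallel iterations, at the cost of keeping all of the
  // input tensors alive at once.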
  Status RewriteIntoAddN(Node* n, Graph* g) {
    VLOG(3) << "Rewrite AccumulateNV2 into AddN: " << SummarizeNode(*n);

    AttrSlice n_attrs = n->attrs();
    DataType dtype;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
    int num_inputs;
    TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "N", &num_inputs));

    Node* add_n_node = nullptr;

    std::vector<NodeBuilder::NodeOut> data_inputs;
    std::vector<Node*> control_inputs;
    data_inputs.reserve(n->num_inputs());
    control_inputs.reserve(n->in_edges().size() - n->num_inputs());
    for (const Edge* in_edge : n->in_edges()) {
      if (in_edge->IsControlEdge()) {
        control_inputs.push_back(in_edge->src());
      } else {
        data_inputs.emplace_back(in_edge->src(), in_edge->src_output());
      }
    }

    // Rewrite the `AccumulateNV2` node into an `AddN` node, reusing the
    // original node's name.
    NodeDebugInfo debug_info(*n);
    NodeBuilder builder =
        NodeBuilder(n->name(), "AddN", OpRegistry::Global(), &debug_info)
            .Device(n->requested_device())
            .Attr("N", num_inputs)
            .Attr("T", dtype)
            .Input(data_inputs)
            .ControlInputs(control_inputs);
    const string& colo = GetNodeAttrString(n_attrs, kColocationAttrName);
    if (!colo.empty()) {
      builder.Attr(kColocationAttrName, colo);
    }
    TF_RETURN_IF_ERROR(builder.Finalize(g, &add_n_node));

    // Forward all consumers to the new node.
    for (const Edge* out_edge : n->out_edges()) {
      if (out_edge->IsControlEdge()) {
        g->AddControlEdge(add_n_node, out_edge->dst());
      } else {
        g->AddEdge(add_n_node, 0, out_edge->dst(), out_edge->dst_input());
      }
    }

    // Remove the original AccumulateNV2 placeholder op.
    // This removal modifies the op and must happen after we have finished
    // using its incoming/outgoing edge sets.
    g->RemoveNode(n);

    return OkStatus();
  }
};
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
                      AccumulateNV2RemovePass);
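
// Registering the pass in the PRE_PLACEMENT phase (at ordering 10) runs the
// rewrite before placement and partitioning, matching the Internal error
// raised above, so the replacement nodes are placed like ordinary nodes.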

}  // namespace
}  // namespace tensorflow
292