1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #if defined(INTEL_MKL) && defined(ENABLE_MKL) |
17 | |
18 | #include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h" |
19 | |
20 | #include <memory> |
21 | #include <queue> |
22 | #include <set> |
23 | #include <utility> |
24 | #include <vector> |
25 | |
26 | #include "tensorflow/core/common_runtime/function.h" |
27 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/graph/algorithm.h" |
30 | #include "tensorflow/core/graph/graph.h" |
31 | #include "tensorflow/core/graph/mkl_graph_util.h" |
32 | #include "tensorflow/core/graph/node_builder.h" |
33 | #include "tensorflow/core/lib/core/status.h" |
34 | #include "tensorflow/core/lib/gtl/map_util.h" |
35 | #include "tensorflow/core/lib/hash/hash.h" |
36 | #include "tensorflow/core/platform/logging.h" |
37 | #include "tensorflow/core/util/util.h" |
38 | |
39 | namespace tensorflow { |
40 | |
41 | // This pass inserts Mkl to Tf tensor conversion nodes (represented by C) |
42 | // in the graph in between A and B, where A and B match any one |
43 | // of the following cases: |
44 | // |
45 | // 1) A = a node that generates output in the Mkl format and, |
46 | // B = a node that does not accept input in the Mkl format and, |
// A -> B (there is a direct edge between A and B), then
// we will insert C such that A->C->B.
49 | // |
// 2) A = a node that generates output in the Mkl format and,
// B = NULL (in other words, A is the last node in the graph), then
// we will insert C such that A->C. (C will be the last node.)
53 | // |
54 | // Note that case 1 applies to all outputs of A that are input to B. |
55 | // In other words, the conversions will be required for every output |
56 | // of A that is input to B. For example, let us say the output of A |
57 | // is A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not |
58 | // in Mkl format, and all of them are input to B. In such case, we will |
59 | // do the conversion for A1 and A2 only. We do not need to do any conversion |
60 | // for A3. |
61 | // |
62 | // This pass relies on ops registering themselves about their Mkl compliance. |
63 | // An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs |
64 | // in the Mkl format. Non-compliant ops accept inputs and outputs in the |
65 | // TensorFlow format. |
66 | // |
67 | // ADDENDUM: For element-wise ops, we may or may not need a conversion to |
68 | // take place before we hit the op. For this, we add a new op before each |
69 | // element-wise MKL op to deal with the inputs, called _MklInputConversion. |
70 | // This pass has been enhanced to add this capability. |
71 | // |
72 | // The _MklInputConversion op will check the inputs to the elementwise op and |
73 | // make sure that either both are in MKL format or both are in TF format, |
74 | // depending on their initial state and whether broadcast is needed or not. |
75 | |
// Graph optimization pass that inserts layout-conversion nodes:
// _MklToTf nodes on edges from Mkl-layout producers to plain-TF consumers,
// and _MklInputConversion nodes in front of Mkl element-wise ops.
// See the file-level comment above for the full insertion rules.
class MklToTfConversionPass : public GraphOptimizationPass {
 public:
  MklToTfConversionPass() {}
  // GraphOptimizationPass entry point. Applies RunPass to the graph (or, for
  // the post-partitioning phase, to each partitioned graph) in 'options'.
  Status Run(const GraphOptimizationPassOptions& options);

  // Insert layout conversion node in the graph pointed by g.
  // Function scans the graph for candidate edges where we
  // need to insert conversion nodes.
  //
  // @return true even if single conversion node is inserted;
  // false, otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  // Is the input Op supported by Mkl-specific layout?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported; false, otherwise.
  inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklOp(op_name, T, false);
  }

  // Is the input Op supported by Mkl-specific layout AND
  // is it element-wise?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported; false, otherwise.
  inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklElementWiseOp(op_name, T);
  }

  // Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
  //
  // Edge will be deleted once a call to this function is successful.
  // Any attempt to use the edge after this call
  // will lead to undefined behaviors.
  //
  // @return Success:OK() if insertion is successful, otherwise returns
  // appropriate error status code.
  Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);

  // For element-wise ops, we need to sanitize the inputs. For this, we add a
  // new node at the input of the replacement element-wise node that checks
  // the inputs and converts one/both of them as required. See the op code
  // comments for details.
  //
  // Insert input conversion node as parent of 'n' from graph 'g'.
  //
  // @return Success:OK() if insertion is successful, otherwise returns
  // appropriate error status code.
  Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};
130 | |
131 | // We register MklToTf insertion for phase 2 in post-partition grouping |
132 | // because we register MklLayoutRewritePass for phase 1 in post-partition |
133 | // grouping. We register this pass after partitioning so that we get a |
134 | // complete picture of inputs and outputs of the nodes in the graphs. |
135 | const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = |
136 | OptimizationPassRegistry::POST_PARTITIONING; |
137 | #ifdef ENABLE_MKL |
138 | REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); |
139 | #endif // ENABLE_MKL |
140 | |
141 | Status MklToTfConversionPass::InsertConversionNodeOnEdge( |
142 | std::unique_ptr<Graph>* g, Edge* e) { |
143 | CHECK_NOTNULL(e); |
144 | |
145 | Node* src = e->src(); |
146 | Node* dst = e->dst(); |
147 | |
148 | CHECK_NOTNULL(src); |
149 | CHECK_NOTNULL(dst); |
150 | |
151 | Node* conversion_node = nullptr; |
152 | DataType src_datatype = src->output_type(e->src_output()); |
153 | DataType dst_datatype = dst->input_type(e->dst_input()); |
154 | string data_format; |
155 | |
156 | // We compare source and destination datatypes only when both are found. |
157 | if (src_datatype != dst_datatype) { |
158 | string err_msg = "T attribute of " + src->name() + ":" + |
159 | std::to_string(e->src_output()) + " and " + dst->name() + |
160 | ":" + std::to_string(e->dst_input()) + |
161 | " do not" |
162 | " match. Will not insert MklToTf node in such case." ; |
163 | return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str()); |
164 | } |
165 | |
166 | TF_CHECK_OK( |
167 | NodeBuilder((*g)->NewName("Mkl2Tf" ), "_MklToTf" ) |
168 | .Input(src, e->src_output()) |
169 | .Input(src, DataIndexToMetaDataIndex( |
170 | e->src_output(), |
171 | src->num_outputs())) // Get an Mkl tensor slot |
172 | // from the Tf tensor slot. |
173 | .Device(src->def().device()) // We want to get conversion node |
174 | // on same device as source node. |
175 | .Attr("T" , src_datatype) |
176 | .Finalize(&**g, &conversion_node)); |
177 | |
178 | CHECK_NOTNULL(conversion_node); |
179 | // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be |
180 | // using data_format. This code might be redundant. |
181 | if (GetNodeAttr(src->def(), "data_format" , &data_format) == OkStatus() && |
182 | (data_format == ToString(FORMAT_NHWC) || |
183 | data_format == ToString(FORMAT_NCHW))) { |
184 | conversion_node->AddAttr("data_format" , data_format); |
185 | } |
186 | |
187 | // Get assigned device from source node and apply it to conversion node. |
188 | // We want conversion node to be on the same device as the source node. |
189 | conversion_node->set_assigned_device_name(src->assigned_device_name()); |
190 | |
191 | // Set the Mkl op label for this op. |
192 | conversion_node->AddAttr("_kernel" , |
193 | mkl_op_registry::kMklLayoutDependentOpLabel); |
194 | |
195 | // Now that we have added edge from src->conversion_node, let's add edge from |
196 | // output of conversion_node to the dest node. Since conversion_node |
197 | // has only 1 output, the src_output of conversion_node is 0. |
198 | CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input())); |
199 | |
200 | VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: " |
201 | << src->type_string() << " and " << dst->type_string() |
202 | << " successful." ; |
203 | |
204 | // Remove src->dst edge now. |
205 | (*g)->RemoveEdge(e); |
206 | return OkStatus(); |
207 | } |
208 | |
209 | Status MklToTfConversionPass::InsertInputConversionNode( |
210 | std::unique_ptr<Graph>* g, Node* n) { |
211 | CHECK_NOTNULL(n); |
212 | |
213 | // Get the input nodes and edges |
214 | std::vector<const Edge*> edges; |
215 | TF_CHECK_OK(n->input_edges(&edges)); |
216 | if (edges.size() != 4) { |
217 | return Status(error::Code::INVALID_ARGUMENT, |
218 | "MKL Binary Element-wise op should have exactly 2 data" |
219 | " inputs and 2 metadata inputs" ); |
220 | } |
221 | |
222 | // Sanity check: ensure that both inputs are of the expected type, and the |
223 | // same type as input type |
224 | CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), |
225 | BaseType(edges[1]->src()->output_type(edges[1]->src_output()))); |
226 | CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())), |
227 | BaseType(n->input_type(0))); |
228 | |
229 | // Check ordering of edges |
230 | for (uint32 i = 0; i < 4; i++) { |
231 | CHECK_EQ((edges[i]->dst_input() == i), true); |
232 | } |
233 | |
234 | // Build the conversion node and specify src as input. |
235 | Node* conversion_node = nullptr; |
236 | |
237 | TF_CHECK_OK( |
238 | NodeBuilder((*g)->NewName("MklInputConversion" ), "_MklInputConversion" ) |
239 | .Input(edges[0]->src(), edges[0]->src_output()) |
240 | .Input(edges[1]->src(), edges[1]->src_output()) |
241 | .Input(edges[2]->src(), edges[2]->src_output()) |
242 | .Input(edges[3]->src(), edges[3]->src_output()) |
243 | .Device(n->def().device()) |
244 | .Attr("T" , n->input_type(0)) |
245 | .Finalize(&**g, &conversion_node)); |
246 | |
247 | CHECK_NOTNULL(conversion_node); |
248 | |
249 | // Change the destination of any control edges to the InputConversion node |
250 | if (edges.size() != n->in_edges().size()) { |
251 | std::vector<const Edge*> edges_to_remove; |
252 | for (const Edge* e : n->in_edges()) { |
253 | if (e->IsControlEdge()) { |
254 | CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node)); |
255 | edges_to_remove.push_back(e); |
256 | } |
257 | } |
258 | for (const Edge* e : edges_to_remove) { |
259 | (*g)->RemoveEdge(e); |
260 | } |
261 | } |
262 | |
263 | // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't |
264 | // seem to be using data_format. This code might be redundant. |
265 | string data_format; |
266 | if (GetNodeAttr(edges[0]->src()->def(), "data_format" , &data_format) == |
267 | OkStatus() && |
268 | (data_format == ToString(FORMAT_NHWC) || |
269 | data_format == ToString(FORMAT_NCHW))) { |
270 | conversion_node->AddAttr("data_format" , data_format); |
271 | } |
272 | |
273 | // Get assigned device from destination node and apply it to conversion node. |
274 | // We want conversion node to be on the same device as the destination node. |
275 | conversion_node->set_assigned_device_name(n->assigned_device_name()); |
276 | |
277 | // Set the Mkl op label for this op. |
278 | conversion_node->AddAttr("_kernel" , |
279 | mkl_op_registry::kMklLayoutDependentOpLabel); |
280 | |
281 | // Now that we have added edges from src->conversion_node, let's add edge from |
282 | // output of conversion_node to the element-wise node. |
283 | CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input())); |
284 | CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input())); |
285 | CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input())); |
286 | CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input())); |
287 | |
288 | VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input " |
289 | << "conversion node on: " << n->type_string() << " successful." ; |
290 | |
291 | // Remove src->dst edge now. |
292 | (*g)->RemoveEdge(edges[0]); |
293 | (*g)->RemoveEdge(edges[1]); |
294 | (*g)->RemoveEdge(edges[2]); |
295 | (*g)->RemoveEdge(edges[3]); |
296 | |
297 | return OkStatus(); |
298 | } |
299 | |
300 | bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) { |
301 | bool result = false; |
302 | |
303 | CHECK_NOTNULL(g); |
304 | |
305 | DumpGraph("Before MklToTfConversionPass" , &**g); |
306 | |
307 | // Since we are looking for an Mkl-supported op node immediately |
308 | // followed by a non-Mkl op node, we will just iterate over edge |
309 | // set of the graph. |
310 | // edge set whose source and destination are candidates for |
311 | // inserting conversion node |
312 | std::vector<Edge*> candidate_edges; |
313 | |
314 | for (const Edge* e : (*g)->edges()) { |
315 | Node* src = e->src(); |
316 | Node* dst = e->dst(); |
317 | |
318 | // We skip control edges. |
319 | if (e->IsControlEdge()) { |
320 | continue; |
321 | } |
322 | |
323 | // We skip adding MklToTf on an edge between X->MklToTf or |
324 | // MklToTf->X, where X is any node. |
325 | if (src->type_string().compare("_MklToTf" ) == 0 || |
326 | dst->type_string().compare("_MklToTf" ) == 0) { |
327 | continue; |
328 | } |
329 | |
330 | VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: " |
331 | << src->type_string() << " and " << dst->type_string(); |
332 | |
333 | // Let's get source and destination data type. |
334 | // We cannot check datatype on destination node because destination node |
335 | // may not be Mkl node. |
336 | DataType src_datatype; |
337 | DataType dst_datatype; |
338 | bool src_is_mkl_op = |
339 | (GetNodeAttr(src->def(), "T" , &src_datatype) == OkStatus() && |
340 | IsMklSupportedOp(src->type_string(), src_datatype)); |
341 | bool dst_is_mkl_op = |
342 | (GetNodeAttr(dst->def(), "T" , &dst_datatype) == OkStatus() && |
343 | IsMklSupportedOp(dst->type_string(), dst_datatype)); |
344 | |
345 | // Check if src with is Mkl-compliant, while dst is not Mkl-compliant. |
346 | if (src_is_mkl_op && !dst_is_mkl_op) { |
347 | VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name() |
348 | << " and " << dst->name() << " for inserting conversion nodes" ; |
349 | candidate_edges.push_back(const_cast<Edge*>(e)); |
350 | } |
351 | } |
352 | |
353 | // Process all candidate edges and insert conversion nodes on them. |
354 | for (Edge* e : candidate_edges) { |
355 | // Even if we insert conversion node on a single edge, we |
356 | // need to return true. |
357 | string src_name = e->src()->name(); |
358 | string dst_name = e->dst()->name(); |
359 | if (InsertConversionNodeOnEdge(g, e) == OkStatus()) { |
360 | VLOG(1) << "MklToTfConversionPass: Inserted conversion " |
361 | << "node on edge between " << src_name << " and " << dst_name; |
362 | result = true; |
363 | } |
364 | } |
365 | |
366 | DumpGraph("After MklToTfConversionPass" , &**g); |
367 | |
368 | //--------------------------------------------------------------------------- |
369 | // Check all nodes and add an input-conversion-node if the node is an mkl |
370 | // element-wise node. |
371 | VLOG(1) << "Before running MklToTfConversionPass - InputConversion" ; |
372 | |
373 | std::vector<Node*> candidate_nodes; |
374 | std::vector<Node*> order; |
375 | GetReversePostOrder(**g, &order); // This will give us topological sort. |
376 | |
377 | for (Node* n : order) { |
378 | // If node is not an op or it does not have a datatype, then skip. |
379 | DataType datatype; |
380 | if (!n->IsOp() || (GetNodeAttr(n->def(), "T" , &datatype) != OkStatus())) { |
381 | continue; |
382 | } |
383 | if (IsMklElementWiseOp(n->type_string(), datatype)) { |
384 | // If the input node is an input-conversion op, skip |
385 | Node* input_node = nullptr; |
386 | TF_CHECK_OK(n->input_node(0, &input_node)); |
387 | DataType input_datatype; |
388 | if ((GetNodeAttr(n->def(), "T" , &input_datatype) == OkStatus()) && |
389 | (input_node->type_string().compare("_MklInputConversion" ) == 0)) { |
390 | continue; |
391 | } |
392 | |
393 | VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node " |
394 | << n->name() << " for inserting input conversion node" ; |
395 | candidate_nodes.push_back(const_cast<Node*>(n)); |
396 | } |
397 | } |
398 | |
399 | // Process all candidate edges and insert conversion nodes on them. |
400 | for (Node* n : candidate_nodes) { |
401 | // Even if we insert conversion node on a single node, we |
402 | // need to return true. |
403 | if (InsertInputConversionNode(g, n) == OkStatus()) { |
404 | VLOG(1) << "MklToTfConversionPass: Inserted conversion " |
405 | << "on node " << n->name(); |
406 | result = true; |
407 | } |
408 | } |
409 | DumpGraph("After MklToTfConversionPass - InputConversion" , &**g); |
410 | |
411 | // We need to return true even if we insert one conversion node |
412 | // anywhere in the graph. |
413 | return result; |
414 | } |
415 | |
416 | ////////////////////////////////////////////////////////////////////////////// |
417 | // Run function for the pass |
418 | ////////////////////////////////////////////////////////////////////////////// |
419 | |
420 | bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) { |
421 | return MklToTfConversionPass().RunPass(g); |
422 | } |
423 | |
424 | Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) { |
425 | if (options.graph == nullptr && options.partition_graphs == nullptr) { |
426 | return OkStatus(); |
427 | } |
428 | if (!IsMKLEnabled()) { |
429 | VLOG(2) << "TF-MKL: MKL is not enabled" ; |
430 | return OkStatus(); |
431 | } |
432 | if (NativeFormatEnabled()) { |
433 | VLOG(2) |
434 | << "Running in native format mode, MklToTfConversionPass won't run." ; |
435 | return OkStatus(); |
436 | } |
437 | |
438 | auto process_graph = [&](std::unique_ptr<Graph>* g) { |
439 | // Get the ownership of graph |
440 | std::unique_ptr<Graph>* ng = std::move(g); |
441 | RunPass(ng); |
442 | // Return the ownership of graph back |
443 | g->reset(ng->release()); |
444 | }; |
445 | |
446 | if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) { |
447 | // For any pre-partitioning phase, graph is stored in options.graph. |
448 | process_graph(options.graph); |
449 | } else { |
450 | // For post partitioning phase, graphs are stored in |
451 | // options.partition_graphs. |
452 | for (auto& pg : *options.partition_graphs) { |
453 | process_graph(&pg.second); |
454 | } |
455 | } |
456 | |
457 | return OkStatus(); |
458 | } |
459 | |
460 | } // namespace tensorflow |
461 | |
462 | #endif // defined(INTEL_MKL) && defined(ENABLE_MKL) |
463 | |