1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#if defined(INTEL_MKL) && defined(ENABLE_MKL)
17
18#include "tensorflow/core/common_runtime/mkl_tfconversion_pass.h"
19
20#include <memory>
21#include <queue>
22#include <set>
23#include <utility>
24#include <vector>
25
26#include "tensorflow/core/common_runtime/function.h"
27#include "tensorflow/core/common_runtime/optimization_registry.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/graph/algorithm.h"
30#include "tensorflow/core/graph/graph.h"
31#include "tensorflow/core/graph/mkl_graph_util.h"
32#include "tensorflow/core/graph/node_builder.h"
33#include "tensorflow/core/lib/core/status.h"
34#include "tensorflow/core/lib/gtl/map_util.h"
35#include "tensorflow/core/lib/hash/hash.h"
36#include "tensorflow/core/platform/logging.h"
37#include "tensorflow/core/util/util.h"
38
39namespace tensorflow {
40
41// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
42// in the graph in between A and B, where A and B match any one
43// of the following cases:
44//
45// 1) A = a node that generates output in the Mkl format and,
46// B = a node that does not accept input in the Mkl format and,
//    A -> B (there is a direct edge between A and B), then
48// We will insert C such that A->C->B.
49//
50// 2) A = a node that generates output in the Mkl format and,
51// B = NULL (in other words, A is the last node in the graph), then
//    We will insert C such that A->C. (C will be the last node.)
53//
54// Note that case 1 applies to all outputs of A that are input to B.
55// In other words, the conversions will be required for every output
56// of A that is input to B. For example, let us say the output of A
57// is A1, A2, A3, of which A1 and A2 are in Mkl format, but A3 is not
58// in Mkl format, and all of them are input to B. In such case, we will
59// do the conversion for A1 and A2 only. We do not need to do any conversion
60// for A3.
61//
// This pass relies on ops registering their own Mkl compliance.
63// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs
64// in the Mkl format. Non-compliant ops accept inputs and outputs in the
65// TensorFlow format.
66//
67// ADDENDUM: For element-wise ops, we may or may not need a conversion to
68// take place before we hit the op. For this, we add a new op before each
69// element-wise MKL op to deal with the inputs, called _MklInputConversion.
70// This pass has been enhanced to add this capability.
71//
72// The _MklInputConversion op will check the inputs to the elementwise op and
73// make sure that either both are in MKL format or both are in TF format,
74// depending on their initial state and whether broadcast is needed or not.
75
// Graph optimization pass that inserts the Mkl-to-TF layout conversion nodes
// described in the comment block above.
class MklToTfConversionPass : public GraphOptimizationPass {
 public:
  MklToTfConversionPass() {}
  // GraphOptimizationPass entry point. Invokes RunPass on options.graph or,
  // for the post-partitioning grouping, on every partition graph.
  Status Run(const GraphOptimizationPassOptions& options);

  // Insert layout conversion node in the graph pointed by g.
  // Function scans the graph for candidate edges where we
  // need to insert conversion nodes.
  //
  // @return true if at least one conversion node is inserted;
  // false, otherwise.
  bool RunPass(std::unique_ptr<Graph>* g);

 private:
  // Is the input Op supported by Mkl-specific layout?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported; false, otherwise.
  inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklOp(op_name, T, false);
  }

  // Is the input Op supported by Mkl-specific layout AND
  // is it element-wise?
  //
  // @input op_name string of the op
  // @input T Datatype to use for checking input op
  // @return true if op is Mkl supported; false, otherwise.
  inline bool IsMklElementWiseOp(const string& op_name, DataType T) const {
    return mkl_op_registry::IsMklElementWiseOp(op_name, T);
  }

  // Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
  //
  // Edge will be deleted once a call to this function is successful.
  // Any attempt to use the edge after this call
  // will lead to undefined behaviors.
  //
  // @return Success:OK() if insertion is successful, otherwise returns
  // appropriate error status code.
  Status InsertConversionNodeOnEdge(std::unique_ptr<Graph>* g, Edge*);

  // For element-wise ops, we need to sanitize the inputs. For this, we add a
  // new node at the input of the replacement element-wise node that checks
  // the inputs and converts one/both of them as required. See the op code
  // comments for details.
  //
  // Insert input conversion node as parent of 'n' from graph 'g'.
  //
  // @return Success:OK() if insertion is successful, otherwise returns
  // appropriate error status code.
  Status InsertInputConversionNode(std::unique_ptr<Graph>* g, Node*);
};
130
// We register MklToTf insertion for phase 2 in post-partition grouping
// because we register MklLayoutRewritePass for phase 1 in post-partition
// grouping. We register this pass after partitioning so that we get a
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
#ifdef ENABLE_MKL
// NOTE: this guard is redundant — the entire file is already compiled only
// when ENABLE_MKL is defined (see the #if at the top of the file).
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
#endif  // ENABLE_MKL
140
141Status MklToTfConversionPass::InsertConversionNodeOnEdge(
142 std::unique_ptr<Graph>* g, Edge* e) {
143 CHECK_NOTNULL(e);
144
145 Node* src = e->src();
146 Node* dst = e->dst();
147
148 CHECK_NOTNULL(src);
149 CHECK_NOTNULL(dst);
150
151 Node* conversion_node = nullptr;
152 DataType src_datatype = src->output_type(e->src_output());
153 DataType dst_datatype = dst->input_type(e->dst_input());
154 string data_format;
155
156 // We compare source and destination datatypes only when both are found.
157 if (src_datatype != dst_datatype) {
158 string err_msg = "T attribute of " + src->name() + ":" +
159 std::to_string(e->src_output()) + " and " + dst->name() +
160 ":" + std::to_string(e->dst_input()) +
161 " do not"
162 " match. Will not insert MklToTf node in such case.";
163 return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
164 }
165
166 TF_CHECK_OK(
167 NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
168 .Input(src, e->src_output())
169 .Input(src, DataIndexToMetaDataIndex(
170 e->src_output(),
171 src->num_outputs())) // Get an Mkl tensor slot
172 // from the Tf tensor slot.
173 .Device(src->def().device()) // We want to get conversion node
174 // on same device as source node.
175 .Attr("T", src_datatype)
176 .Finalize(&**g, &conversion_node));
177
178 CHECK_NOTNULL(conversion_node);
179 // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
180 // using data_format. This code might be redundant.
181 if (GetNodeAttr(src->def(), "data_format", &data_format) == OkStatus() &&
182 (data_format == ToString(FORMAT_NHWC) ||
183 data_format == ToString(FORMAT_NCHW))) {
184 conversion_node->AddAttr("data_format", data_format);
185 }
186
187 // Get assigned device from source node and apply it to conversion node.
188 // We want conversion node to be on the same device as the source node.
189 conversion_node->set_assigned_device_name(src->assigned_device_name());
190
191 // Set the Mkl op label for this op.
192 conversion_node->AddAttr("_kernel",
193 mkl_op_registry::kMklLayoutDependentOpLabel);
194
195 // Now that we have added edge from src->conversion_node, let's add edge from
196 // output of conversion_node to the dest node. Since conversion_node
197 // has only 1 output, the src_output of conversion_node is 0.
198 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, dst, e->dst_input()));
199
200 VLOG(1) << "MklToTfConversionPass: Inserting Conversion node on: "
201 << src->type_string() << " and " << dst->type_string()
202 << " successful.";
203
204 // Remove src->dst edge now.
205 (*g)->RemoveEdge(e);
206 return OkStatus();
207}
208
209Status MklToTfConversionPass::InsertInputConversionNode(
210 std::unique_ptr<Graph>* g, Node* n) {
211 CHECK_NOTNULL(n);
212
213 // Get the input nodes and edges
214 std::vector<const Edge*> edges;
215 TF_CHECK_OK(n->input_edges(&edges));
216 if (edges.size() != 4) {
217 return Status(error::Code::INVALID_ARGUMENT,
218 "MKL Binary Element-wise op should have exactly 2 data"
219 " inputs and 2 metadata inputs");
220 }
221
222 // Sanity check: ensure that both inputs are of the expected type, and the
223 // same type as input type
224 CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
225 BaseType(edges[1]->src()->output_type(edges[1]->src_output())));
226 CHECK_EQ(BaseType(edges[0]->src()->output_type(edges[0]->src_output())),
227 BaseType(n->input_type(0)));
228
229 // Check ordering of edges
230 for (uint32 i = 0; i < 4; i++) {
231 CHECK_EQ((edges[i]->dst_input() == i), true);
232 }
233
234 // Build the conversion node and specify src as input.
235 Node* conversion_node = nullptr;
236
237 TF_CHECK_OK(
238 NodeBuilder((*g)->NewName("MklInputConversion"), "_MklInputConversion")
239 .Input(edges[0]->src(), edges[0]->src_output())
240 .Input(edges[1]->src(), edges[1]->src_output())
241 .Input(edges[2]->src(), edges[2]->src_output())
242 .Input(edges[3]->src(), edges[3]->src_output())
243 .Device(n->def().device())
244 .Attr("T", n->input_type(0))
245 .Finalize(&**g, &conversion_node));
246
247 CHECK_NOTNULL(conversion_node);
248
249 // Change the destination of any control edges to the InputConversion node
250 if (edges.size() != n->in_edges().size()) {
251 std::vector<const Edge*> edges_to_remove;
252 for (const Edge* e : n->in_edges()) {
253 if (e->IsControlEdge()) {
254 CHECK_NOTNULL((*g)->AddControlEdge(e->src(), conversion_node));
255 edges_to_remove.push_back(e);
256 }
257 }
258 for (const Edge* e : edges_to_remove) {
259 (*g)->RemoveEdge(e);
260 }
261 }
262
263 // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
264 // seem to be using data_format. This code might be redundant.
265 string data_format;
266 if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
267 OkStatus() &&
268 (data_format == ToString(FORMAT_NHWC) ||
269 data_format == ToString(FORMAT_NCHW))) {
270 conversion_node->AddAttr("data_format", data_format);
271 }
272
273 // Get assigned device from destination node and apply it to conversion node.
274 // We want conversion node to be on the same device as the destination node.
275 conversion_node->set_assigned_device_name(n->assigned_device_name());
276
277 // Set the Mkl op label for this op.
278 conversion_node->AddAttr("_kernel",
279 mkl_op_registry::kMklLayoutDependentOpLabel);
280
281 // Now that we have added edges from src->conversion_node, let's add edge from
282 // output of conversion_node to the element-wise node.
283 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 0, n, edges[0]->dst_input()));
284 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 1, n, edges[1]->dst_input()));
285 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 2, n, edges[2]->dst_input()));
286 CHECK_NOTNULL((*g)->AddEdge(conversion_node, 3, n, edges[3]->dst_input()));
287
288 VLOG(1) << "MklToTfConversionPass - InputConversion: Inserting input "
289 << "conversion node on: " << n->type_string() << " successful.";
290
291 // Remove src->dst edge now.
292 (*g)->RemoveEdge(edges[0]);
293 (*g)->RemoveEdge(edges[1]);
294 (*g)->RemoveEdge(edges[2]);
295 (*g)->RemoveEdge(edges[3]);
296
297 return OkStatus();
298}
299
300bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
301 bool result = false;
302
303 CHECK_NOTNULL(g);
304
305 DumpGraph("Before MklToTfConversionPass", &**g);
306
307 // Since we are looking for an Mkl-supported op node immediately
308 // followed by a non-Mkl op node, we will just iterate over edge
309 // set of the graph.
310 // edge set whose source and destination are candidates for
311 // inserting conversion node
312 std::vector<Edge*> candidate_edges;
313
314 for (const Edge* e : (*g)->edges()) {
315 Node* src = e->src();
316 Node* dst = e->dst();
317
318 // We skip control edges.
319 if (e->IsControlEdge()) {
320 continue;
321 }
322
323 // We skip adding MklToTf on an edge between X->MklToTf or
324 // MklToTf->X, where X is any node.
325 if (src->type_string().compare("_MklToTf") == 0 ||
326 dst->type_string().compare("_MklToTf") == 0) {
327 continue;
328 }
329
330 VLOG(1) << "MklToTfConversionPass: InsertConversionNodes: "
331 << src->type_string() << " and " << dst->type_string();
332
333 // Let's get source and destination data type.
334 // We cannot check datatype on destination node because destination node
335 // may not be Mkl node.
336 DataType src_datatype;
337 DataType dst_datatype;
338 bool src_is_mkl_op =
339 (GetNodeAttr(src->def(), "T", &src_datatype) == OkStatus() &&
340 IsMklSupportedOp(src->type_string(), src_datatype));
341 bool dst_is_mkl_op =
342 (GetNodeAttr(dst->def(), "T", &dst_datatype) == OkStatus() &&
343 IsMklSupportedOp(dst->type_string(), dst_datatype));
344
345 // Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
346 if (src_is_mkl_op && !dst_is_mkl_op) {
347 VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
348 << " and " << dst->name() << " for inserting conversion nodes";
349 candidate_edges.push_back(const_cast<Edge*>(e));
350 }
351 }
352
353 // Process all candidate edges and insert conversion nodes on them.
354 for (Edge* e : candidate_edges) {
355 // Even if we insert conversion node on a single edge, we
356 // need to return true.
357 string src_name = e->src()->name();
358 string dst_name = e->dst()->name();
359 if (InsertConversionNodeOnEdge(g, e) == OkStatus()) {
360 VLOG(1) << "MklToTfConversionPass: Inserted conversion "
361 << "node on edge between " << src_name << " and " << dst_name;
362 result = true;
363 }
364 }
365
366 DumpGraph("After MklToTfConversionPass", &**g);
367
368 //---------------------------------------------------------------------------
369 // Check all nodes and add an input-conversion-node if the node is an mkl
370 // element-wise node.
371 VLOG(1) << "Before running MklToTfConversionPass - InputConversion";
372
373 std::vector<Node*> candidate_nodes;
374 std::vector<Node*> order;
375 GetReversePostOrder(**g, &order); // This will give us topological sort.
376
377 for (Node* n : order) {
378 // If node is not an op or it does not have a datatype, then skip.
379 DataType datatype;
380 if (!n->IsOp() || (GetNodeAttr(n->def(), "T", &datatype) != OkStatus())) {
381 continue;
382 }
383 if (IsMklElementWiseOp(n->type_string(), datatype)) {
384 // If the input node is an input-conversion op, skip
385 Node* input_node = nullptr;
386 TF_CHECK_OK(n->input_node(0, &input_node));
387 DataType input_datatype;
388 if ((GetNodeAttr(n->def(), "T", &input_datatype) == OkStatus()) &&
389 (input_node->type_string().compare("_MklInputConversion") == 0)) {
390 continue;
391 }
392
393 VLOG(1) << "MklToTfConversionPass: InputConversion: Scheduled node "
394 << n->name() << " for inserting input conversion node";
395 candidate_nodes.push_back(const_cast<Node*>(n));
396 }
397 }
398
399 // Process all candidate edges and insert conversion nodes on them.
400 for (Node* n : candidate_nodes) {
401 // Even if we insert conversion node on a single node, we
402 // need to return true.
403 if (InsertInputConversionNode(g, n) == OkStatus()) {
404 VLOG(1) << "MklToTfConversionPass: Inserted conversion "
405 << "on node " << n->name();
406 result = true;
407 }
408 }
409 DumpGraph("After MklToTfConversionPass - InputConversion", &**g);
410
411 // We need to return true even if we insert one conversion node
412 // anywhere in the graph.
413 return result;
414}
415
416//////////////////////////////////////////////////////////////////////////////
417// Run function for the pass
418//////////////////////////////////////////////////////////////////////////////
419
420bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
421 return MklToTfConversionPass().RunPass(g);
422}
423
424Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
425 if (options.graph == nullptr && options.partition_graphs == nullptr) {
426 return OkStatus();
427 }
428 if (!IsMKLEnabled()) {
429 VLOG(2) << "TF-MKL: MKL is not enabled";
430 return OkStatus();
431 }
432 if (NativeFormatEnabled()) {
433 VLOG(2)
434 << "Running in native format mode, MklToTfConversionPass won't run.";
435 return OkStatus();
436 }
437
438 auto process_graph = [&](std::unique_ptr<Graph>* g) {
439 // Get the ownership of graph
440 std::unique_ptr<Graph>* ng = std::move(g);
441 RunPass(ng);
442 // Return the ownership of graph back
443 g->reset(ng->release());
444 };
445
446 if (kMklTfConvPassGroup != OptimizationPassRegistry::POST_PARTITIONING) {
447 // For any pre-partitioning phase, graph is stored in options.graph.
448 process_graph(options.graph);
449 } else {
450 // For post partitioning phase, graphs are stored in
451 // options.partition_graphs.
452 for (auto& pg : *options.partition_graphs) {
453 process_graph(&pg.second);
454 }
455 }
456
457 return OkStatus();
458}
459
460} // namespace tensorflow
461
462#endif // defined(INTEL_MKL) && defined(ENABLE_MKL)
463