/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(intel): Improve error handling in this file; instead of CHECK failing
// all over the place, we should log an error and execute the original graph.
#ifdef INTEL_MKL

#include "tensorflow/core/common_runtime/mkl_layout_pass.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <queue>
#include <set>
#include <stack>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/base/call_once.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

// This pass implements rewriting of the graph to support the following
// scenarios:
// (A) Merging nodes in the graph
// (B) Rewriting a node in the graph to a new node
// A rewrite happens under the following scenario:
// - Propagating the Mkl layout as an additional output tensor (we will
//   loosely call a tensor that carries the Mkl layout an Mkl tensor
//   henceforth) from every Mkl-supported NN layer.
//
// Example of A : Merging nodes in the graph
// -----------------------------------------
// Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as:
//
//   O = Conv2D(A, B)
//   P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
//   P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
// Merge rules:
//  - The merge of Conv2D and BiasAdd happens only when the output of Conv2D
//    goes _solely_ to BiasAdd.
//  - The attributes common to both nodes must have the same values.
//  - Both nodes must have been assigned to the same device (if any).
//
// Example of B.1 : Rewriting nodes to Mkl nodes
// ---------------------------------------------
// Consider a Relu node. The current definition of the Relu node looks like:
//
//   O = Relu(A)
//
// Relu has 1 input (A) and 1 output (O).
//
// This rewrite pass will generate a new graph node for Relu (the new node is
// called MklRelu) as:
//
//   O, O_m = MklRelu(A, A_m)
//
// MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is
// the same as input A of Relu; output O is the same as output O of Relu. O_m
// is the additional output tensor that will be set by MklRelu, and it
// represents the Mkl tensor corresponding to O -- in other words, O_m is a
// kind of metadata for O. A_m is an additional input of MklRelu, and it
// represents metadata for A -- just as O_m is metadata for O, A_m is metadata
// for A. MklRelu receives this metadata from the previous node in the graph.
//
// When the previous node in the graph is an Mkl node, A_m will represent a
// valid Mkl tensor. But when the previous node is not an Mkl node, A_m will
// represent a dummy Mkl tensor.
//
// Rewriting rules:
//  - Selection of a node for rewriting happens by registering the op type of
//    the node with the rewriting pass. If the op type is not registered, then
//    nodes of that op type will not be rewritten.
//  - Number of inputs after rewriting:
//    Since for every input Tensorflow tensor the rewritten node gets an Mkl
//    tensor, the rewritten node gets 2*N inputs, where N is the number of
//    inputs of the original node.
//  - Number of outputs after rewriting:
//    Since for every output Tensorflow tensor the rewritten node generates an
//    Mkl tensor, the rewritten node generates 2*N outputs, where N is the
//    number of outputs of the original node.
//  - Ordering of Tensorflow tensors and Mkl tensors:
//    Since every rewritten node generates twice the number of inputs and
//    outputs, one could imagine various orderings among Tensorflow tensors
//    and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
//    inputs; then the new op '_MklConv2D' can take the inputs A, B, A_m and
//    B_m in A, A_m, B, B_m order, or it can also take them in A, B, A_m, B_m
//    order. One can get many such permutations of the 2*N inputs.
//
//    So the question is: which order do we follow? We support 2 types of
//    orderings: (1) interleaved, and (2) contiguous. Interleaved ordering
//    follows an intuitive order where an Mkl tensor follows the
//    corresponding Tensorflow tensor immediately. In the context of the
//    above example, it will be: A, A_m, B, B_m. Note that the ordering rule
//    applies to both the inputs and outputs. Contiguous ordering means
//    all the Tensorflow tensors are contiguous, followed by all the Mkl
//    tensors. We use contiguous ordering by default.
//
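//    For example, with the default contiguous ordering, the rewritten
//    Conv2D above is invoked as:
//
//      O, O_m = _MklConv2D(A, B, A_m, B_m)
//
//    whereas the interleaved ordering would be:
//
//      O, O_m = _MklConv2D(A, A_m, B, B_m)
//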
// Graph rewrite algorithm:
//      Algorithm: Graph Rewrite
//      Input: Graph G, Names of the nodes to rewrite and their new names
//      Output: Modified Graph G' if the nodes are modified, G otherwise.
//      Start:
//        N = Topological_Sort(G) // N is a set of nodes in toposort order.
//        foreach node n in N
//        do
//          if (Is_MKL_Op(n))  // Can this node accept an Mkl layout as input.
//          then
//            E = set of <incoming edge and its src_output slot> of n
//            E' = {}  // a new set of edges for rewritten node
//            foreach <e,s> in E
//            do
//              E' U {<e,s>}  // First copy edge which generates Tensorflow
//                            // tensor as it is
//              m = Source node of edge e
//              if Is_Rewritten(m)  // Did we rewrite this node in this pass?
//              then
//                E' U {<m,s+1>}  // If yes, then m will generate an Mkl
//                                // tensor as an additional output.
//              else
//                d = Generate_Dummy_Mkl_Tensor()  // If not, generate a dummy
//                                                 // Mkl tensor.
//                E' U {<d,0>}  // The dummy Mkl tensor has only 1 output slot.
//              fi
//            done
//            n' = Build_New_Node(G,new_name,E')
//            Mark_Rewritten(n')  // Mark the new node as being rewritten.
//          fi
//        done
//
// Explanation:
// For graph rewrite, we visit the nodes of the input graph in
// topological sort order. With this ordering, we visit nodes in a
// top-to-bottom fashion. We need this order because, while visiting a
// node, we want all of its input nodes to have been visited (and rewritten
// if applicable). This is because if we need to rewrite a given node,
// then all of its input nodes need to be fixed first (in other words,
// they cannot be deleted later).
//
// While visiting a node, we first check if the op type of the node is
// an Mkl op. If it is, then we rewrite that node after constructing
// new inputs to the node. If the op type of the node is not an Mkl op,
// then we do not rewrite that node.
//
// Handling workspace propagation for certain ops:
//
// Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require
// a workspace to be passed from their respective forward ops. Workspace
// tensors provide memory for storing results of intermediate operations
// which are helpful in backward propagation. TensorFlow does not have
// a notion of a workspace and as a result does not allow producing
// additional outputs from these forward ops. For these ops, we need
// to add 2 extra edges between forward ops and their corresponding
// backward ops - the first extra edge carries a workspace tensor and
// the second one carries an Mkl tensor for the workspace tensor.
//
// Example:
//
// A typical graph for MaxPool and its gradient looks like:
//
//   A = MaxPool(T)
//   B = MaxPoolGrad(X, A, Y)
//
// We will transform this graph to propagate the workspace as
// (with the contiguous ordering):
//
//   A, W, A_m, W_m = MklMaxPool(T, T_m)
//   B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m)
//
// Here W is the workspace tensor. Transformed tensor names with the
// suffix _m are Mkl tensors, and this transformation has been done
// using the algorithm discussed earlier. The transformation for
// workspace propagation only adds extra outputs (W, W_m) for a forward
// op and connects them to the corresponding backward ops.
//
// Terms:
//
// Forward op name = name of the op in the forward pass
//   where a workspace tensor originates (MaxPool in this example)
// Backward op name = name of the op in the backward pass that receives
//   a workspace tensor from the forward op (MaxPoolGrad in the example)
// Slot = position of the output or input slot that will be
//   used by the workspace tensor (1 for MklMaxPool as W is the 2nd
//   output of MaxPool (0 is the 1st); 3 for MklMaxPoolGrad)
//
// Question:
//
// How do we associate a backward op with a forward op? There can be more
// than one node of the same op type in the graph.
//
// In this example, we associate MaxPoolGrad with MaxPool. But there
// could be more than one MaxPool op. To solve this problem, we look
// for a _direct_ edge between a forward op and a backward op (tensor A is
// flowing along this edge in the example).
//
// How do we transform forward and backward ops when there is no direct
// edge between them? In such a case, we generate dummy tensors for
// the workspace tensors. For the example, the transformation of MaxPool
// will be exactly the same as it would be when there is a direct edge
// between the forward and the backward op --- it is just that MaxPool
// won't generate any workspace tensor. For MaxPoolGrad, the transformation
// will also be the same, but instead of connecting W and W_m with the
// outputs of MaxPool, we will produce dummy tensors for them, and we
// will set the workspace_enabled attribute to false.
//
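// For example, when MaxPoolGrad has no direct edge from the corresponding
// MaxPool, it is rewritten (with the contiguous ordering) roughly as:
//
//   B, B_m = MklMaxPoolGrad(X, A, Y, DMY, X_m, A_m, Y_m, DMY_m)
//
// where DMY and DMY_m are placeholder names (used here only for
// illustration) for the dummy workspace tensor and its dummy Mkl metadata,
// and the workspace_enabled attribute is set to false.
//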
240class MklLayoutRewritePass : public GraphOptimizationPass {
241 public:
242 MklLayoutRewritePass() {
243 // NOTE: names are alphabetically sorted.
244 csinfo_.addn = "AddN";
245 csinfo_.avg_pool = "AvgPool";
246 csinfo_.avg_pool_grad = "AvgPoolGrad";
247 csinfo_.avg_pool3d = "AvgPool3D";
248 csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
249 csinfo_.batch_matmul = "BatchMatMul";
250 csinfo_.batch_matmul_v2 = "BatchMatMulV2";
251 csinfo_.bias_add = "BiasAdd";
252 csinfo_.bias_add_grad = "BiasAddGrad";
253 csinfo_.concat = "Concat";
254 csinfo_.concatv2 = "ConcatV2";
255 csinfo_.conjugate_transpose = "ConjugateTranspose";
256 csinfo_.conv2d = "Conv2D";
257 csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias";
258 csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
259 csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
260 csinfo_.conv2d_grad_filter_with_bias =
261 "__MklDummyConv2DBackpropFilterWithBias";
262 csinfo_.conv3d = "Conv3D";
263 csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
264 csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
265 csinfo_.depthwise_conv2d = "DepthwiseConv2dNative";
266 csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput";
267 csinfo_.depthwise_conv2d_grad_filter =
268 "DepthwiseConv2dNativeBackpropFilter";
269 csinfo_.dequantize = "Dequantize";
270 csinfo_.einsum = "Einsum";
271 csinfo_.fused_batch_norm = "FusedBatchNorm";
272 csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
273 csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx";
274 csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2";
275 csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2";
276 csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
277 csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
278 csinfo_.fused_conv2d = "_FusedConv2D";
279 csinfo_.fused_conv3d = "_FusedConv3D";
280 csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
281 csinfo_.fused_matmul = "_FusedMatMul";
282 csinfo_.identity = "Identity";
283 csinfo_.leakyrelu = "LeakyRelu";
284 csinfo_.leakyrelu_grad = "LeakyReluGrad";
285 csinfo_.lrn = "LRN";
286 csinfo_.lrn_grad = "LRNGrad";
287 csinfo_.matmul = "MatMul";
288 csinfo_.max_pool = "MaxPool";
289 csinfo_.max_pool_grad = "MaxPoolGrad";
290 csinfo_.max_pool3d = "MaxPool3D";
291 csinfo_.max_pool3d_grad = "MaxPool3DGrad";
292 csinfo_.mkl_conv2d = "_MklConv2D";
293 csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
294 csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
295 csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
296 csinfo_.mkl_conv2d_grad_filter_with_bias =
297 "_MklConv2DBackpropFilterWithBias";
298 csinfo_.mkl_depthwise_conv2d_grad_input =
299 "_MklDepthwiseConv2dNativeBackpropInput";
300 csinfo_.mkl_depthwise_conv2d_grad_filter =
301 "_MklDepthwiseConv2dNativeBackpropFilter";
302 csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx";
303 csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
304 csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
305 csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
306 csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias";
307 csinfo_.mkl_native_conv2d_grad_filter_with_bias =
308 "_MklNativeConv2DBackpropFilterWithBias";
309 csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx";
310 csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D";
311 csinfo_.mkl_native_fused_conv3d = "_MklNativeFusedConv3D";
312 csinfo_.mkl_native_fused_depthwise_conv2d =
313 "_MklNativeFusedDepthwiseConv2dNative";
314 csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul";
315 csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D";
316 csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D";
317 csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
318 csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
319 csinfo_.pad = "Pad";
320 csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D";
321 csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D";
322 csinfo_.quantized_avg_pool = "QuantizedAvgPool";
323 csinfo_.quantized_concatv2 = "QuantizedConcatV2";
324 csinfo_.quantized_conv2d = "QuantizedConv2D";
325 csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel";
326 csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize";
327 csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias";
328 csinfo_.quantized_conv2d_with_bias_and_requantize =
329 "QuantizedConv2DWithBiasAndRequantize";
330 csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu";
331 csinfo_.quantized_conv2d_and_relu_and_requantize =
332 "QuantizedConv2DAndReluAndRequantize";
333 csinfo_.quantized_conv2d_with_bias_and_relu =
334 "QuantizedConv2DWithBiasAndRelu";
335 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize =
336 "QuantizedConv2DWithBiasAndReluAndRequantize";
337 csinfo_.quantized_max_pool = "QuantizedMaxPool";
338 csinfo_.quantized_conv2d_with_bias_sum_and_relu =
339 "QuantizedConv2DWithBiasSumAndRelu";
340 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize =
341 "QuantizedConv2DWithBiasSumAndReluAndRequantize";
342 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize =
343 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize";
344 csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias";
345 csinfo_.quantized_matmul_with_bias_and_relu =
346 "QuantizedMatMulWithBiasAndRelu";
347 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize =
348 "QuantizedMatMulWithBiasAndReluAndRequantize";
349 csinfo_.quantized_matmul_with_bias_and_dequantize =
350 "QuantizedMatMulWithBiasAndDequantize";
351 csinfo_.quantized_matmul_with_bias_and_requantize =
352 "QuantizedMatMulWithBiasAndRequantize";
353 csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D";
354 csinfo_.quantized_depthwise_conv2d_with_bias =
355 "QuantizedDepthwiseConv2DWithBias";
356 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu =
357 "QuantizedDepthwiseConv2DWithBiasAndRelu";
358 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize =
359 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize";
360 csinfo_.quantize_v2 = "QuantizeV2";
361 csinfo_.relu = "Relu";
362 csinfo_.relu_grad = "ReluGrad";
363 csinfo_.relu6 = "Relu6";
364 csinfo_.relu6_grad = "Relu6Grad";
365 csinfo_.requantize = "Requantize";
366 csinfo_.tanh = "Tanh";
367 csinfo_.tanh_grad = "TanhGrad";
368 csinfo_.reshape = "Reshape";
369 csinfo_.slice = "Slice";
370 csinfo_.softmax = "Softmax";
371 csinfo_.split = "Split";
372 csinfo_.transpose = "Transpose";
373 // Element-wise ops. Ensure you also add any new ops to IsOpElementWise
374 // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
375 // MklInputConversion op is added before it.
376 csinfo_.add = "Add";
377 csinfo_.add_v2 = "AddV2";
378 csinfo_.maximum = "Maximum";
379 csinfo_.mul = "Mul";
380 csinfo_.squared_difference = "SquaredDifference";
381 csinfo_.sub = "Sub";
382 // End - element-wise ops. See note above.
383
384 const bool native_fmt = NativeFormatEnabled();
385 // NOTE: names are alphabetically sorted.
386 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
387 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
388 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
389 CopyAttrsAll, RewriteIfAtleastOneMklInput,
390 GetRewriteCause()});
391 rinfo_.push_back(
392 {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2),
393 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
394 rinfo_.push_back({csinfo_.avg_pool,
395 mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
396 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
397 rinfo_.push_back({csinfo_.avg_pool_grad,
398 mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
399 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
400 rinfo_.push_back({csinfo_.avg_pool3d,
401 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
402 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
403 rinfo_.push_back({csinfo_.avg_pool3d_grad,
404 mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
405 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
406 rinfo_.push_back({csinfo_.batch_matmul,
407 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul),
408 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
409 rinfo_.push_back({csinfo_.einsum,
410 mkl_op_registry::GetMklOpName(csinfo_.einsum),
411 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
412 rinfo_.push_back({csinfo_.batch_matmul_v2,
413 mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2),
414 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
415 rinfo_.push_back({csinfo_.concat,
416 mkl_op_registry::GetMklOpName(csinfo_.concat),
417 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
418 rinfo_.push_back({csinfo_.concatv2,
419 mkl_op_registry::GetMklOpName(csinfo_.concatv2),
420 CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
421 rinfo_.push_back(
422 {csinfo_.conjugate_transpose,
423 mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
424 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
425 rinfo_.push_back(
426 {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d),
427 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
428 rinfo_.push_back({csinfo_.conv2d_with_bias,
429 native_fmt ? csinfo_.mkl_native_conv2d_with_bias
430 : csinfo_.mkl_conv2d_with_bias,
431 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
432 GetRewriteCause()});
433 rinfo_.push_back({csinfo_.conv2d_grad_filter,
434 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
435 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
436 rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
437 native_fmt
438 ? csinfo_.mkl_native_conv2d_grad_filter_with_bias
439 : csinfo_.mkl_conv2d_grad_filter_with_bias,
440 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
441 rinfo_.push_back({csinfo_.conv2d_grad_input,
442 mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
443 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
444 rinfo_.push_back(
445 {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d),
446 CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()});
447 rinfo_.push_back({csinfo_.conv3d_grad_filter,
448 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
449 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
450 rinfo_.push_back({csinfo_.conv3d_grad_input,
451 mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
452 CopyAttrsConv, AlwaysRewrite, GetRewriteCause()});
453 rinfo_.push_back({csinfo_.depthwise_conv2d,
454 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
455 CopyAttrsConvCheckConstFilter, AlwaysRewrite,
456 GetRewriteCause()});
457 rinfo_.push_back(
458 {csinfo_.depthwise_conv2d_grad_input,
459 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
460 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
461 rinfo_.push_back(
462 {csinfo_.depthwise_conv2d_grad_filter,
463 mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
464 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
465 rinfo_.push_back(
466 {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
467 CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
468 rinfo_.push_back({csinfo_.fused_batch_norm,
469 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
470 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
471 rinfo_.push_back(
472 {csinfo_.fused_batch_norm_grad,
473 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad),
474 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
475 rinfo_.push_back(
476 {csinfo_.fused_batch_norm_v2,
477 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2),
478 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
479 rinfo_.push_back(
480 {csinfo_.fused_batch_norm_grad_v2,
481 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2),
482 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
483
484 // Using CopyAttrsAll for V3 on CPU, as there are no additional
485 // attributes.
486 rinfo_.push_back(
487 {csinfo_.fused_batch_norm_v3,
488 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3),
489 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
490 rinfo_.push_back(
491 {csinfo_.fused_batch_norm_grad_v3,
492 mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3),
493 CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()});
494 rinfo_.push_back({csinfo_.fused_batch_norm_ex,
495 native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex
496 : csinfo_.mkl_fused_batch_norm_ex,
497 CopyAttrsAll, FusedBatchNormExRewrite,
498 GetRewriteCause()});
499 rinfo_.push_back({csinfo_.fused_conv2d,
500 native_fmt ? csinfo_.mkl_native_fused_conv2d
501 : csinfo_.mkl_fused_conv2d,
502 CopyAttrsFusedConv2DCheckConstFilter, FusedConv2DRewrite,
503 GetRewriteCause()});
504 rinfo_.push_back({csinfo_.fused_conv3d, csinfo_.mkl_native_fused_conv3d,
505 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
506 kRewriteForOpNameChange});
507 rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
508 native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d
509 : csinfo_.mkl_fused_depthwise_conv2d,
510 CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
511 GetRewriteCause()});
512 rinfo_.push_back({csinfo_.fused_matmul,
513 native_fmt ? csinfo_.mkl_native_fused_matmul
514 : csinfo_.mkl_fused_matmul,
515 CopyAttrsAllCheckConstFilter, FusedMatMulRewrite,
516 GetRewriteCause()});
517 rinfo_.push_back(
518 {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity),
519 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
520 rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn),
521 CopyAttrsAll, LrnRewrite, GetRewriteCause()});
522 rinfo_.push_back({csinfo_.lrn_grad,
523 mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
524 CopyAttrsAll, LrnGradRewrite, GetRewriteCause()});
525 rinfo_.push_back({csinfo_.matmul,
526 mkl_op_registry::GetMklOpName(csinfo_.matmul),
527 CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange});
528 rinfo_.push_back({csinfo_.leakyrelu,
529 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu),
530 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
531 rinfo_.push_back({csinfo_.leakyrelu_grad,
532 mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad),
533 CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()});
534 rinfo_.push_back(
535 {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool),
536 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
537 rinfo_.push_back({csinfo_.max_pool_grad,
538 mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
539 CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()});
540 rinfo_.push_back(
541 {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
542 CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()});
543 rinfo_.push_back({csinfo_.max_pool3d_grad,
544 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
545 CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()});
546 rinfo_.push_back(
547 {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum),
548 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
549 rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
550 CopyAttrsAll, RewriteIfAtleastOneMklInput,
551 GetRewriteCause()});
552 rinfo_.push_back({csinfo_.pad_with_conv2d,
553 native_fmt ? csinfo_.mkl_native_pad_with_conv2d
554 : csinfo_.mkl_pad_with_conv2d,
555 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
556 GetRewriteCause()});
557 rinfo_.push_back({csinfo_.pad_with_fused_conv2d,
558 native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d
559 : csinfo_.mkl_pad_with_fused_conv2d,
560 CopyAttrsAllCheckConstFilter, AlwaysRewrite,
561 GetRewriteCause()});
562 rinfo_.push_back({csinfo_.quantized_avg_pool,
563 mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
564 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
565 rinfo_.push_back({csinfo_.quantized_concatv2,
566 mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2),
567 CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange});
568 rinfo_.push_back({csinfo_.quantized_conv2d,
569 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d),
570 CopyAttrsQuantizedConv2D, AlwaysRewrite,
571 kRewriteForOpNameChange});
572 rinfo_.push_back(
573 {csinfo_.quantized_conv2d_per_channel,
574 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel),
575 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
576 rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize,
577 mkl_op_registry::GetMklOpName(
578 csinfo_.quantized_conv2d_with_requantize),
579 CopyAttrsQuantizedConv2D, AlwaysRewrite,
580 kRewriteForOpNameChange});
581 rinfo_.push_back(
582 {csinfo_.quantized_conv2d_with_bias,
583 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias),
584 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
585 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize,
586 mkl_op_registry::GetMklOpName(
587 csinfo_.quantized_conv2d_with_bias_and_requantize),
588 CopyAttrsQuantizedConv2D, AlwaysRewrite,
589 kRewriteForOpNameChange});
590 rinfo_.push_back(
591 {csinfo_.quantized_conv2d_and_relu,
592 mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu),
593 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
594 rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize,
595 mkl_op_registry::GetMklOpName(
596 csinfo_.quantized_conv2d_and_relu_and_requantize),
597 CopyAttrsQuantizedConv2D, AlwaysRewrite,
598 kRewriteForOpNameChange});
599 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu,
600 mkl_op_registry::GetMklOpName(
601 csinfo_.quantized_conv2d_with_bias_and_relu),
602 CopyAttrsQuantizedConv2D, AlwaysRewrite,
603 kRewriteForOpNameChange});
604 rinfo_.push_back(
605 {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize,
606 mkl_op_registry::GetMklOpName(
607 csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize),
608 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
609 rinfo_.push_back({csinfo_.quantized_max_pool,
610 mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool),
611 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
612 rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu,
613 mkl_op_registry::GetMklOpName(
614 csinfo_.quantized_conv2d_with_bias_sum_and_relu),
615 CopyAttrsQuantizedConv2D, AlwaysRewrite,
616 kRewriteForOpNameChange});
617 rinfo_.push_back(
618 {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize,
619 mkl_op_registry::GetMklOpName(
620 csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize),
621 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
622 rinfo_.push_back(
623 {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize,
624 mkl_op_registry::GetMklOpName(
625 csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize),
626 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
627 rinfo_.push_back(
628 {csinfo_.quantized_matmul_with_bias,
629 mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias),
630 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
631 kRewriteForOpNameChange});
632 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu,
633 mkl_op_registry::GetMklOpName(
634 csinfo_.quantized_matmul_with_bias_and_relu),
635 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
636 kRewriteForOpNameChange});
637 rinfo_.push_back(
638 {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize,
639 mkl_op_registry::GetMklOpName(
640 csinfo_.quantized_matmul_with_bias_and_relu_and_requantize),
641 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
642 kRewriteForOpNameChange});
643 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize,
644 mkl_op_registry::GetMklOpName(
645 csinfo_.quantized_matmul_with_bias_and_requantize),
646 CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite,
647 kRewriteForOpNameChange});
648 rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize,
649 mkl_op_registry::GetMklOpName(
650 csinfo_.quantized_matmul_with_bias_and_dequantize),
651 CopyAttrsQuantizedMatMulWithBiasAndDequantize,
652 AlwaysRewrite, kRewriteForOpNameChange});
653 rinfo_.push_back(
654 {csinfo_.quantized_depthwise_conv2d,
655 mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d),
656 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
657 rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias,
658 mkl_op_registry::GetMklOpName(
659 csinfo_.quantized_depthwise_conv2d_with_bias),
660 CopyAttrsQuantizedConv2D, AlwaysRewrite,
661 kRewriteForOpNameChange});
662 rinfo_.push_back(
663 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu,
664 mkl_op_registry::GetMklOpName(
665 csinfo_.quantized_depthwise_conv2d_with_bias_and_relu),
666 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
667 rinfo_.push_back(
668 {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize,
669 mkl_op_registry::GetMklOpName(
670 csinfo_
671 .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize),
672 CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange});
673 rinfo_.push_back({csinfo_.quantize_v2,
674 mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
675 CopyAttrsAll, QuantizeOpRewrite,
676 kRewriteForOpNameChange});
677 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
678 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
679 rinfo_.push_back({csinfo_.relu_grad,
680 mkl_op_registry::GetMklOpName(csinfo_.relu_grad),
681 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
682 rinfo_.push_back({csinfo_.relu6,
683 mkl_op_registry::GetMklOpName(csinfo_.relu6),
684 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
685 rinfo_.push_back({csinfo_.relu6_grad,
686 mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
687 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
688 rinfo_.push_back({csinfo_.requantize,
689 mkl_op_registry::GetMklOpName(csinfo_.requantize),
690 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
691 // Optimized TanhGrad support exists only in DNNL 1.x.
692 rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh),
693 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
694 rinfo_.push_back({csinfo_.tanh_grad,
695 mkl_op_registry::GetMklOpName(csinfo_.tanh_grad),
696 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
697 rinfo_.push_back({csinfo_.reshape,
698 mkl_op_registry::GetMklOpName(csinfo_.reshape),
699 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
700 rinfo_.push_back(
701 {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice),
702 CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()});
703 rinfo_.push_back({csinfo_.softmax,
704 mkl_op_registry::GetMklOpName(csinfo_.softmax),
705 CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
706
707 rinfo_.push_back({csinfo_.squared_difference,
708 mkl_op_registry::GetMklOpName(csinfo_.squared_difference),
709 CopyAttrsAll, RewriteIfAtleastOneMklInput,
710 GetRewriteCause()});
711 rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
712 CopyAttrsAll, RewriteIfAtleastOneMklInput,
713 GetRewriteCause()});
714 rinfo_.push_back({csinfo_.transpose,
715 mkl_op_registry::GetMklOpName(csinfo_.transpose),
716 CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
717
718 // Add info about which ops to add workspace edge to and the slots.
719 wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
720 wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
721 wsinfo_.push_back(
722 {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
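    // For example, the MaxPool entry above means: MaxPool's output slot 0
    // feeds MaxPoolGrad's input slot 1, the workspace tensor is emitted at
    // MaxPool's output slot 1, and it is consumed at MaxPoolGrad's input
    // slot 3 (see WorkSpaceInfo below for the field meanings).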
723
724 // Add a rule for merging nodes
725 minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
726 csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd});
727
728 // Merge Pad and Conv2d, only if the pad op is "Pad"
729 // Doesn't merge if pad op is "PadV2" or "MirrorPad"
730 minfo_.push_back(
731 {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D});
732
733 minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d,
734 csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D});
735
736 minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad,
737 csinfo_.conv2d_grad_filter_with_bias,
738 GetConv2DBackpropFilterOrBiasAddGrad});
739
    // The fusion patterns in "finfo_" that show up first will get applied
    // first. For example, for graph "A->B->C->D", if finfo_ is
    // {A->B->C to ABC, A->B->C->D to ABCD}, since the first pattern gets
    // applied first, the final graph will be ABC->D.
744 }
745
746 // Standard interface to run pass
747 Status Run(const GraphOptimizationPassOptions& options);
748
  // Helper function which does most of the heavy lifting for rewriting
  // Mkl nodes to propagate Mkl tensors as additional outputs.
751 //
752 // Extracts common functionality between Run public interface and
753 // test interface.
754 //
755 // @return true, if and only if graph is mutated; false otherwise.
756 bool RunPass(std::unique_ptr<Graph>* g);
757
758 /// Cause for rewrite
759 /// Currently, we only support 2 causes - either for Mkl layout propagation
760 /// which is the most common case, or for just a name change (used in case
761 /// of ops like MatMul, Transpose, which do not support Mkl layout)
762 enum RewriteCause { kRewriteForLayoutPropagation, kRewriteForOpNameChange };
763
764 // Get the op rewrite cause depending on whether native format mode
765 // is enabled or not.
766 RewriteCause GetRewriteCause() {
767 if (NativeFormatEnabled()) {
768 return kRewriteForOpNameChange;
769 } else {
770 return kRewriteForLayoutPropagation;
771 }
772 }
773
  /// Structure to specify the name of an original node, its new name after
  /// rewrite, the function to be used to copy attributes for the op, the
  /// rule (if any) which must hold for rewriting the node, and the reason
  /// for the rewrite.
778 typedef struct {
779 string name; // Original name of op of the node in the graph
780 string new_name; // New name of the op of the node in the graph
781 // A function handler to copy attributes from an old node to a new node.
782 std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs;
783 // A rule under which to rewrite this node
784 std::function<bool(const Node*)> rewrite_rule;
785 // Why are we rewriting?
786 RewriteCause rewrite_cause;
787 } RewriteInfo;
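  // For example (a minimal illustration, assuming GetMklOpName simply
  // prepends the "_Mkl" prefix to the op name), the entry
  //   {csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
  //    CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}
  // registered in the constructor rewrites every "Relu" node to "_MklRelu",
  // copying all of its attributes.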
788
789 /// Structure to specify a forward op, a backward op, and the slot numbers
790 /// in the forward and backward ops where we will add a workspace edge.
791 typedef struct {
792 string fwd_op; // Name of a forward op in the graph
793 string bwd_op; // Name of a backward op in the graph
794 int fwd_slot; // Output slot in the forward op node where actual
795 // output tensor resides
796 int bwd_slot; // Input slot in the backward op node where actual
797 // input tensor resides
798 int ws_fwd_slot; // Output slot in the forward op node where workspace
799 // edge is added
800 int ws_bwd_slot; // Input slot in the backward op node where workspace
801 // edge is added
802 } WorkSpaceInfo;
803
804 /// Structure to specify information used in node merge of 2 operators
805 typedef struct {
806 string op1; // Node string for one operator.
807 string op2; // Node string for second operator.
808 string new_node; // Name of the node after merge
809 // Function that enables user of the node merger to specify how to find
810 // second operator given the first operator.
811 std::function<Node*(const Node*)> get_node_to_be_merged;
812 } MergeInfo;
813
814 // Structure to specify information used in node fusion of 3+ operators
815 typedef struct {
816 std::string pattern_name; // Name to describe this pattern, such as
817 // "Transpose_Mklop_Transpose".
818 std::vector<std::function<bool(const Node*)> >
819 node_checkers; // Extra restriction checker for these ops
820 std::function<Status(
821 std::unique_ptr<Graph>*, std::vector<Node*>&,
822 std::function<void(const Node*, NodeBuilder* nb, bool)>)>
823 fuse_func;
824 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs;
825 } FusionInfo;
826
  //
  // Dimension indices for 4D tensors (used by 2D ops such as Conv2D).
  //
830 struct NCHW {
831 enum dim { N = 0, C = 1, H = 2, W = 3 };
832 };
833
834 struct NHWC {
835 enum dim { N = 0, H = 1, W = 2, C = 3 };
836 };
837
  //
  // Dimension indices for 5D tensors (used by 3D ops such as Conv3D).
  //
841 struct NCDHW {
842 enum dim { N = 0, C = 1, D = 2, H = 3, W = 4 };
843 };
844
845 struct NDHWC {
846 enum dim { N = 0, D = 1, H = 2, W = 3, C = 4 };
847 };
848
849 /// Structure to store all constant strings
850 /// NOTE: names are alphabetically sorted.
851 typedef struct {
852 string addn;
853 string add;
854 string add_v2;
855 string avg_pool;
856 string avg_pool_grad;
857 string avg_pool3d;
858 string avg_pool3d_grad;
859 string batch_matmul;
860 string batch_matmul_v2;
861 string bias_add;
862 string bias_add_grad;
863 string concat;
864 string concatv2;
865 string conjugate_transpose;
866 string conv2d;
867 string conv2d_with_bias;
868 string conv2d_grad_input;
869 string conv2d_grad_filter;
870 string conv2d_grad_filter_with_bias;
871 string conv3d;
872 string conv3d_grad_input;
873 string conv3d_grad_filter;
874 string depthwise_conv2d;
875 string depthwise_conv2d_grad_input;
876 string depthwise_conv2d_grad_filter;
877 string dequantize;
878 string einsum;
879 string fused_batch_norm;
880 string fused_batch_norm_grad;
881 string fused_batch_norm_ex;
882 string fused_batch_norm_v2;
883 string fused_batch_norm_grad_v2;
884 string fused_batch_norm_v3;
885 string fused_batch_norm_grad_v3;
886 string fused_conv2d;
887 string fused_conv3d;
888 string fused_depthwise_conv2d;
889 string fused_matmul;
890 string identity;
891 string leakyrelu;
892 string leakyrelu_grad;
893 string lrn;
894 string lrn_grad;
895 string matmul;
896 string max_pool;
897 string max_pool_grad;
898 string max_pool3d;
899 string max_pool3d_grad;
900 string maximum;
901 string mkl_conv2d;
902 string mkl_conv2d_grad_input;
903 string mkl_conv2d_grad_filter;
904 string mkl_conv2d_grad_filter_with_bias;
905 string mkl_conv2d_with_bias;
906 string mkl_depthwise_conv2d_grad_input;
907 string mkl_depthwise_conv2d_grad_filter;
908 string mkl_fused_batch_norm_ex;
909 string mkl_fused_conv2d;
910 string mkl_fused_depthwise_conv2d;
911 string mkl_fused_matmul;
912 string mkl_native_conv2d_with_bias;
913 string mkl_native_conv2d_grad_filter_with_bias;
914 string mkl_native_fused_batch_norm_ex;
915 string mkl_native_fused_conv2d;
916 string mkl_native_fused_conv3d;
917 string mkl_native_fused_depthwise_conv2d;
918 string mkl_native_fused_matmul;
919 string mkl_native_pad_with_conv2d;
920 string mkl_native_pad_with_fused_conv2d;
921 string mkl_pad_with_conv2d;
922 string mkl_pad_with_fused_conv2d;
923 string mul;
924 string pad;
925 string pad_with_conv2d;
926 string pad_with_fused_conv2d;
927 string quantized_avg_pool;
928 string quantized_conv2d;
929 string quantized_conv2d_per_channel;
930 string quantized_conv2d_with_requantize;
931 string quantized_conv2d_with_bias;
932 string quantized_conv2d_with_bias_and_requantize;
933 string quantized_conv2d_and_relu;
934 string quantized_conv2d_and_relu_and_requantize;
935 string quantized_conv2d_with_bias_and_relu;
936 string quantized_conv2d_with_bias_and_relu_and_requantize;
937 string quantized_concatv2;
938 string quantized_max_pool;
939 string quantized_conv2d_with_bias_sum_and_relu;
940 string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
941 string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
942 string quantized_matmul_with_bias;
943 string quantized_matmul_with_bias_and_relu;
944 string quantized_matmul_with_bias_and_relu_and_requantize;
945 string quantized_matmul_with_bias_and_requantize;
946 string quantized_matmul_with_bias_and_dequantize;
947 string quantized_depthwise_conv2d;
948 string quantized_depthwise_conv2d_with_bias;
949 string quantized_depthwise_conv2d_with_bias_and_relu;
950 string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
951 string quantize_v2;
952 string relu;
953 string relu_grad;
954 string relu6;
955 string relu6_grad;
956 string requantize;
957 string tanh;
958 string tanh_grad;
959 string transpose;
960 string reshape;
961 string slice;
962 string softmax;
963 string split;
964 string squared_difference;
965 string sub;
966 } ConstStringsInfo;
967
968 private:
969 /// Maintain info about nodes to rewrite
970 std::vector<RewriteInfo> rinfo_;
971
972 /// Maintain info about nodes to add workspace edge
973 std::vector<WorkSpaceInfo> wsinfo_;
974
975 /// Maintain info about nodes to be merged
976 std::vector<MergeInfo> minfo_;
977
978 /// Maintain info about nodes to be fused
979 std::vector<FusionInfo> finfo_;
980
981 /// Maintain structure of constant strings
982 static ConstStringsInfo csinfo_;
983
984 private:
  // Is OpDef::ArgDef a list type? It could be N * T or list(type).
  // Refer to op_def.proto for details of list type.
987 inline bool ArgIsList(const OpDef::ArgDef& arg) const {
988 return !arg.type_list_attr().empty() || !arg.number_attr().empty();
989 }
990
991 // Get length of a list in 'n' if 'arg' is of list type. Refer to
992 // description of ArgIsList for definition of list type.
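  // For example, for an AddN node whose "inputs" argument is declared as
  // "N * T" with attr N = 4, ArgIsList returns true and this function
  // returns 4.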
993 inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) {
994 CHECK_EQ(ArgIsList(arg), true);
995 int N = 0;
996 const string attr_name = !arg.type_list_attr().empty()
997 ? arg.type_list_attr()
998 : arg.number_attr();
999 if (!arg.type_list_attr().empty()) {
1000 std::vector<DataType> value;
1001 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
1002 N = value.size();
1003 } else {
1004 TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N));
1005 }
1006 return N;
1007 }
1008
1009 // Can op represented by node 'n' run on DEVICE_CPU?
1010 // Op can run on CPU with MKL if the runtime assigned device or the
1011 // user requested device contains device CPU, or both are empty.
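  // For example, a node whose assigned device name contains "CPU" (e.g.,
  // "/device:CPU:0") may be rewritten, while nodes placed on a GPU device or
  // on an XLA_CPU device (e.g., "/device:XLA_CPU:0") are skipped.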
1012 bool CanOpRunOnCPUDevice(const Node* n) {
1013 bool result = true;
1014 string reason;
1015
1016 // Substring that should be checked for in device name for CPU device.
1017 const char* const kCPUDeviceSubStr = "CPU";
1018 const char* const kXLACPUDeviceSubStr = "XLA_CPU";
1019
    // If the op has been specifically assigned to a non-CPU device or to an
    // XLA_CPU device, then do not rewrite it.
1022 if (!n->assigned_device_name().empty() &&
1023 (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) ||
1024 absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) {
1025 result = false;
1026 reason = "Op has been assigned a runtime device that is not CPU.";
1027 }
1028
    // If the user has specifically assigned this op to a non-CPU device or to
    // an XLA_CPU device, then do not rewrite it.
1031 if (!n->def().device().empty() &&
1032 (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) ||
1033 absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) {
1034 result = false;
1035 reason = "User has assigned a device that is not CPU.";
1036 }
1037
1038 if (result == false) {
1039 VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node "
1040 << n->type_string() << ", reason: " << reason;
1041 }
1042
1043 // Otherwise Yes.
1044 return result;
1045 }
1046
1047 // Return a node that can be merged with input node 'n'
1048 //
1049 // @return pointer to the node if we can find such a
1050 // node. Otherwise, it returns nullptr.
1051 Node* CheckForNodeMerge(const Node* n) const;
1052
1053 // Merge node 'm' with node 'n'.
1054 // Currently, we merge (1) Conv2D with BiasAdd, and (2) BiasAddGrad with
1055 // Conv2DBackpropFilter.
1056 //
  // Input nodes m and n may be deleted if the call to
  // this function is successful. Attempting to use the pointers
  // after the call may result in undefined behavior.
1060 //
1061 // @input g - input graph, m - graph node, n - graph node to be merged with m
1062 // @return OkStatus(), if merging is successful and supported.
1063 // Returns appropriate Status error code otherwise.
1064 // Graph is updated in case nodes are merged. Otherwise, it is
1065 // not updated.
1066 Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1067
1068 // Helper function to merge different nodes
1069 Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
1070 Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
1071 Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
1072 Node* m, Node* n);
1073
1074 // Find BiasAdd or Conv2D node that can be merged with input node 'm'.
1075 // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be
1076 // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd
1077 // node that can be merged with 'm'.
1078 static Node* GetConv2DOrBiasAdd(const Node* m) {
1079 DCHECK(m);
1080 Node* n = nullptr;
1081
1082 DataType T_m;
1083 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1084
1085 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1086 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1087
1088 if (m->type_string() == csinfo_.bias_add) {
      // If m is BiasAdd, then Conv2D is the 0th input of BiasAdd.
1090 TF_CHECK_OK(m->input_node(0, &n));
1091 } else {
1092 CHECK_EQ(m->type_string(), csinfo_.conv2d);
1093 // Go over all output edges and search for BiasAdd Node.
1094 // 0th input of BiasAdd is Conv2D.
1095 for (const Edge* e : m->out_edges()) {
1096 if (!e->IsControlEdge() &&
1097 e->dst()->type_string() == csinfo_.bias_add &&
1098 e->dst_input() == 0) {
1099 n = e->dst();
1100 break;
1101 }
1102 }
1103 }
1104
1105 if (n == nullptr) {
1106 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1107 << "Conv2D and BiasAdd node for merging. Input node: "
1108 << m->DebugString();
1109 }
1110
1111 return n;
1112 }
1113
1114 // Find Pad or Conv2D node that can be merged with input node 'm'.
1115 // If input 'm' is Pad, then check if there exists Conv2D node that can be
1116 // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad
1117 // node that can be merged with 'm'.
1118 static Node* GetPadOrConv2D(const Node* m) {
1119 DCHECK(m);
1120 Node* n = nullptr;
1121
1122 DataType T_m;
1123 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1124
1125 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1126 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1127
    const Node* conv_node = nullptr;
1129 if (m->type_string() == csinfo_.pad) {
1130 // If m is Pad, then Conv2D is the output of Pad.
1131 for (const Edge* e : m->out_edges()) {
1132 if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) {
1133 n = e->dst();
1134 conv_node = n;
1135 break;
1136 }
1137 }
1138 } else {
1139 DCHECK_EQ(m->type_string(), csinfo_.conv2d);
      // If m is Conv2D, go over all input edges
      // and search for a Pad node.
1142 for (const Edge* e : m->in_edges()) {
1143 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1144 n = e->src();
1145 conv_node = m;
1146 break;
1147 }
1148 }
1149 }
    // Check that the convolution op uses only the VALID type of padding.
1152 if (n != nullptr) {
1153 string padding;
1154 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1155 if (padding != "VALID")
1156 // Then do not merge.
1157 // Only VALID type of padding in conv op can be
1158 // merged with Pad op.
1159 n = nullptr;
1160 } else {
1161 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1162 << "Pad and Conv2D node for merging. Input node: "
1163 << m->DebugString();
1164 }
1165
1166 return n;
1167 }
1168
1169 // Find Pad or _FusedConv2D node that can be merged with input node 'm'.
1170 // If input 'm' is Pad, then check if there exists _FusedConv2D node that can
1171 // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there
1172 // exists Pad node that can be merged with 'm'.
1173 static Node* GetPadOrFusedConv2D(const Node* m) {
1174 DCHECK(m);
1175 Node* n = nullptr;
1176
    const Node* conv_node = nullptr;
1178 if (m->type_string() == csinfo_.pad) {
1179 // If m is Pad, then _FusedConv2D is the output of Pad.
1180 for (const Edge* e : m->out_edges()) {
1181 if (!e->IsControlEdge() &&
1182 e->dst()->type_string() == csinfo_.fused_conv2d) {
1183 n = e->dst();
1184 conv_node = n;
1185 break;
1186 }
1187 }
1188 } else {
1189 DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d);
      // If m is _FusedConv2D, go over all input edges
      // and search for a Pad node.
1192 for (const Edge* e : m->in_edges()) {
1193 if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) {
1194 n = e->src();
1195 conv_node = m;
1196 break;
1197 }
1198 }
1199 }
1200 // Check if only VALID type of padding is used or not.
1201 if (n != nullptr) {
1202 string padding;
1203 string data_format;
1204 string filter_format;
1205 int num_host_args = 0;
1206 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding", &padding));
1207 TF_CHECK_OK(GetNodeAttr(conv_node->def(), "data_format", &data_format));
1208 TF_CHECK_OK(
1209 GetNodeAttr(conv_node->def(), "filter_format", &filter_format));
1210 TF_CHECK_OK(
1211 GetNodeAttr(conv_node->def(), "num_host_args", &num_host_args));
1212
1213 if ((data_format != "NCHW" && data_format != "NHWC") ||
1214 (filter_format != "HWIO" && filter_format != "OIHW")) {
1215 n = nullptr;
        VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
                << "nodes but cannot merge them. Only conv ops with NCHW or "
                << "NHWC data format and HWIO or OIHW filter format can be "
                << "merged with Pad op. Input node: " << m->DebugString();
1220 } else if (padding != "VALID") {
1221 // Then do not merge.
1222 n = nullptr;
        VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
                << "nodes but cannot merge them. Only conv ops with padding "
                << "type VALID can be merged with Pad op. Input node: "
                << m->DebugString();
1227 } else if (num_host_args != 0) {
1228 n = nullptr;
        VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D "
                << "nodes but cannot merge them. Only conv ops without host "
                << "args can be merged with Pad op. Input node: "
                << m->DebugString();
1233 }
1234 } else {
1235 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1236 << "Pad and _FusedConv2D node for merging. Input node: "
1237 << m->DebugString();
1238 }
1239
1240 return n;
1241 }
1242
1243 // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input
1244 // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists
1245 // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad,
1246 // then check if there exists Conv2DBackpropFilter node that can be merged
1247 // with 'm'.
1248 //
1249 // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad
1250 // would look like:
1251 //
1252 // _ = Conv2DBackpropFilter(F, _, G)
1253 // _ = BiasAddGrad(G)
1254 //
1255 // So 1st input of BiasAddGrad connects with 3rd input of
1256 // Conv2DBackpropFilter and vice versa.
1257 static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) {
1258 DCHECK(m);
1259 Node* n = nullptr;
1260 const Node* conv2d_backprop_filter = nullptr;
1261
1262 DataType T_m;
1263 TF_CHECK_OK(GetNodeAttr(m->def(), "T", &T_m));
1264
1265 // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16
1266 if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n;
1267
1268 if (m->type_string() == csinfo_.bias_add_grad) {
1269 // Get 1st input 'g' of BiasAddGrad.
1270 Node* g = nullptr;
1271 TF_CHECK_OK(m->input_node(0, &g));
1272 // Now traverse all outgoing edges from g that have destination node as
1273 // Conv2DBackpropFilter.
1274 for (const Edge* e : g->out_edges()) {
1275 if (!e->IsControlEdge() &&
1276 e->dst()->type_string() == csinfo_.conv2d_grad_filter &&
1277 e->dst_input() == 2 /* 3rd input of BackpropFilter */) {
1278 n = e->dst();
1279 conv2d_backprop_filter = n;
1280 break;
1281 }
1282 }
1283 } else {
1284 conv2d_backprop_filter = m;
1285 CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter);
1286 // Get 3rd input 'g' of Conv2DBackpropFilter.
1287 Node* g = nullptr;
1288 TF_CHECK_OK(m->input_node(2, &g));
1289 // Now traverse all outgoing edges from g that have destination node as
1290 // BiasAddGrad.
1291 for (const Edge* e : g->out_edges()) {
1292 if (!e->IsControlEdge() &&
1293 e->dst()->type_string() == csinfo_.bias_add_grad &&
1294 e->dst_input() == 0 /* 1st input of BiasAddGrad */) {
1295 n = e->dst();
1296 break;
1297 }
1298 }
1299 }
1300
1301 // Do not merge if padding type is EXPLICIT.
1302 // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter.
1303 if (conv2d_backprop_filter != nullptr) {
1304 string padding;
1305 TF_CHECK_OK(
1306 GetNodeAttr(conv2d_backprop_filter->def(), "padding", &padding));
1307 if (padding == "EXPLICIT") {
1308 // Then do not merge.
1309 VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter "
1310 << "and BiasAddGrad nodes but cannot merge them. "
1311 << "EXPLICIT padding is not supported now. "
1312 << conv2d_backprop_filter->DebugString();
1313 return nullptr;
1314 }
1315 }
1316
1317 if (n == nullptr) {
1318 VLOG(1) << "MklLayoutRewritePass: Could not find matching "
1319 << "Conv2DBackpropFilter and BiasAddGrad node for merging. "
1320 << "Input node: " << m->DebugString();
1321 }
1322 return n;
1323 }
1324
1325 // Return a node that can be fused with input node 'n'
1326 //
  // @return a tuple. If we can find such nodes, the first
  // element of the tuple is true. Otherwise, it is false.
1329 std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
1330 CheckForNodeFusion(Node* n) const;
1331
1332 // Fuse nodes in the vector "nodes"
1333 Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1334 const MklLayoutRewritePass::FusionInfo fi);
1335
1336 // Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
1337 // mklop("NCHW").
1338 // Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
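  //
  // A minimal sketch of the pattern (node and tensor names are hypothetical,
  // assuming a 4D NCHW graph):
  //
  //   T1 = Transpose(X, perm={0,2,3,1})   // NCHW -> NHWC
  //   Y  = mklop(T1, ...)                 // data_format = "NHWC"
  //   T2 = Transpose(Y, perm={0,3,1,2})   // NHWC -> NCHW
  //
  // is rewritten, roughly, into:
  //
  //   Y' = mklop(X, ...)                  // data_format = "NCHW"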
1339 static Status FuseTransposeMklOpTranspose(
1340 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
1341 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
1342 string data_format);
1343
1344 static bool CheckForTranspose(const Node* node, std::vector<int> perm) {
1345 // Check if node's type is "Transpose"
1346 if (node->type_string() != "Transpose") return false;
1347
1348 // If "Transpose" has multiple output data edges, also don't fuse it.
1349 if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false;
1350
    // Check if the node has an outgoing control edge. If so, this is likely a
    // training graph; we currently focus on inference and do not fuse in
    // training. Note: this constraint can eventually be removed once the
    // fusion is enabled for training.
1356 for (const Edge* e : node->out_edges()) {
1357 if (e->IsControlEdge()) {
1358 return false;
1359 }
1360 }
1361
1362 // If "Transpose" has input control edges, don't fuse on it.
1363 for (const Edge* e : node->in_edges()) {
1364 if (e->IsControlEdge()) {
1365 return false;
1366 }
1367 }
1368
    // Compare the tensor containing the permutation order ("perm_node")
    // with our desired order ("perm"). If they match exactly, this check
    // succeeds and returns true.
1372 for (const Edge* e : node->in_edges()) {
1373 if (!e->IsControlEdge()) {
1374 const Node* perm_node = e->src();
1375
1376 const int kPermTensorIndex = 1;
1377 if (perm_node->type_string() == "Const" &&
1378 e->dst_input() == kPermTensorIndex) {
          // We found the "perm" node; now try to retrieve its value.
1380 const TensorProto* proto = nullptr;
1381 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value", &proto));
1382
1383 DataType type;
1384 TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype", &type));
1385
1386 Tensor tensor;
1387 if (!tensor.FromProto(*proto)) {
1388 TF_CHECK_OK(errors::InvalidArgument(
1389 "Could not construct Tensor from TensorProto in node: ",
1390 node->name()));
1391 return false;
1392 }
          // The current fusion only supports tensors whose rank matches the
          // length of the `perm` vector (4D or 5D); return false otherwise.
1395 if (tensor.dim_size(0) != perm.size()) return false;
1396 DCHECK_EQ(tensor.dims(), 1);
1397 if (type == DT_INT32) {
1398 const auto tensor_content = tensor.flat<int>().data();
1399 for (int i = 0; i < perm.size(); ++i)
1400 if (tensor_content[i] != perm[i]) return false;
1401 return true;
1402 } else if (type == DT_INT64) {
1403 const auto tensor_content = tensor.flat<int64_t>().data();
1404 for (int i = 0; i < perm.size(); ++i)
1405 if (tensor_content[i] != perm[i]) return false;
1406 return true;
1407 }
1408 return false;
1409 }
1410 }
1411 }
1412 return false;
1413 }
1414
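  // Checks whether 'node' can participate in the fusion as the middle
  // "mklop": it must match 'name' if one is given, produce a single output
  // with a single outgoing edge, and have an MKL-registered counterpart for
  // its datatype 'T'.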
1415 static bool CheckForMklOp(const Node* node, string name = "") {
1416 if (node == nullptr) return false;
1417
1418 if (!name.empty() && node->type_string() != name) {
1419 return false;
1420 }
1421
1422 // if mklop has multiple outputs, don't fuse it.
1423 if (node->num_outputs() > 1) return false;
1424
1425 if (node->out_edges().size() > 1) return false;
1426
1427 DataType T;
1428 TF_CHECK_OK(GetNodeAttr(node->def(), "T", &T));
1429 return mkl_op_registry::IsMklOp(
1430 mkl_op_registry::GetMklOpName(node->type_string()), T);
1431 }
1432
1433 // Check if the node 'n' has any applicable rewrite rule
1434 // We check for 2 scenarios for rewrite.
1435 //
1436 // @return RewriteInfo* for the applicable rewrite rule
1437 const RewriteInfo* CheckForNodeRewrite(const Node* n) const;
1438 const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const;
1439
1440 // Default rewrite rule to be used in scenario 1 for rewrite.
1441 // @return - true (since we want to always rewrite)
1442 static bool AlwaysRewrite(const Node* n) { return true; }
1443
  // Rewrite rule that considers the "context" of the current node to decide
  // whether we should rewrite it. By "context" we currently mean all the
  // inputs of the current node. The idea is that if none of the inputs of the
  // current node are MKL nodes, then rewriting the current node to an MKL
  // node _may not_ offer any performance improvement.
1449 //
1450 // One such case is element-wise ops. For such ops, we reuse the Eigen
1451 // implementation and pass the MKL metadata tensor through so we can avoid
1452 // conversions. However, if all incoming edges are in TF format, we don't
1453 // need all this overhead, so replace the elementwise node only if at least
1454 // one of its parents is a MKL node.
1455 //
1456 // More generally, all memory- or IO-bound ops (such as Identity) may fall
1457 // under this category.
1458 //
1459 // @input - Input graph node to be rewritten
1460 // @return - true if node is to be rewritten as MKL node; false otherwise.
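  //
  // A rough sketch of the intent (names hypothetical):
  //
  //   A = _MklConv2D(...)   // MKL node
  //   B = Add(A, C)         // element-wise op
  //
  // Add is rewritten to its MKL variant because one of its inputs (A) is an
  // MKL node; if neither A nor C were MKL nodes, Add would be left as-is.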
1461 static bool RewriteIfAtleastOneMklInput(const Node* n) {
1462 DataType T;
1463 if (GetNodeAttr(n->def(), "T", &T).ok() &&
1464 mkl_op_registry::IsMklOp(
1465 mkl_op_registry::GetMklOpName(n->type_string()), T)) {
1466 for (auto e : n->in_edges()) {
1467 if (e->IsControlEdge()) continue;
1468 if (mkl_op_registry::IsMklOp(e->src())) {
1469 return true;
1470 }
1471 }
1472 }
1473 return false;
1474 }
1475
1476 static bool MatMulRewrite(const Node* n) {
1477 DataType T;
1478 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
1479 if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) {
1480 VLOG(2) << "Rewriting MatMul to _MklMatMul";
1481 return true;
1482 }
1483 return false;
1484 }
1485 // For oneDNN, only int32 is supported for axis data type
1486 static bool ConcatV2Rewrite(const Node* n) {
1487 DataType T;
1488 TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx", &T));
1489 return (T == DT_INT32);
1490 }
1491
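  // Rewrite rule for Dequantize: rewrite only when the mode is "SCALED",
  // dequantization is per-tensor (axis == -1), and the input is not a
  // constant; other cases fall back to the Eigen kernel.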
1492 static bool DequantizeRewrite(const Node* n) {
1493 DCHECK(n);
1494 Node* input = nullptr;
1495 TF_CHECK_OK(n->input_node(0, &input));
1496 string mode_string;
1497 int axis = -1;
1498 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1499 TF_CHECK_OK(GetNodeAttr(n->def(), "axis", &axis));
1500 if (mode_string != "SCALED") {
1501 VLOG(1) << "DequantizeRewrite: Mode is not SCALED. "
1502 << "This case is not optimized by Intel MKL kernel, thus using "
1503 "Eigen op for Dequantize op.";
1504 return false;
1505 }
1506 if (input->IsConstant()) {
1507 VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which "
1508 << "could possibly be a filter. "
1509 << "This case is not supported by Intel MKL kernel, thus using "
1510 "Eigen op for Dequantize op.";
1511 return false;
1512 }
1513
1514 if (axis != -1) {
1515 VLOG(1) << "DequantizeRewrite: Using Eigen op for Dequantize op because "
1516 << "dimension is specified for per slice dequantization. ";
1517 return false;
1518 }
1519 return true;
1520 }
1521
1522 // Rewrite rule for _FusedMatMul.
  // @return - true if the first input is not transposed (i.e., the
  //           'transpose_a' attribute is false); false otherwise.
1525 static bool FusedMatMulRewrite(const Node* n) {
1526 bool trans_a;
1527
1528 // Do not rewrite with transpose attribute because reorder has performance
1529 // impact.
1530 TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a));
1531
1532 return !trans_a;
1533 }
1534
  // Check if we are performing pooling on the depth or batch dimension. If
  // so, we do not rewrite the MaxPool node to its Mkl version.
1537 // @return - true (if it is not a depth/batch wise pooling case);
1538 // false otherwise.
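  //
  // For example (an illustrative case): with data_format = "NHWC",
  // ksize = {1, 2, 2, 1} and strides = {1, 2, 2, 1}, the batch ('N') and
  // depth ('C') entries are all 1, so this is a spatial pooling and the
  // node remains eligible for rewrite.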
1539 static bool NonDepthBatchWisePoolRewrite(const Node* n) {
1540 DCHECK(n);
1541
1542 string data_format_str;
1543 TensorFormat data_format;
1544 std::vector<int32> ksize, strides;
1545 TF_CHECK_OK(GetNodeAttr(n->def(), "ksize", &ksize));
1546 TF_CHECK_OK(GetNodeAttr(n->def(), "strides", &strides));
1547 TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format_str));
1548 bool result = FormatFromString(data_format_str, &data_format);
1549 DCHECK(result);
1550
1551 // Condition that specifies non-batch-wise and non-depth-wise pooling.
1552 if (GetTensorDim(ksize, data_format, 'N') == 1 &&
1553 GetTensorDim(strides, data_format, 'N') == 1 &&
1554 GetTensorDim(ksize, data_format, 'C') == 1 &&
1555 GetTensorDim(strides, data_format, 'C') == 1) {
1556 return true;
1557 }
1558
1559 return false;
1560 }
1561
  // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
  // path, which is slow. In that case we don't rewrite the node and use the
  // default Eigen implementation. For depth_radius=2, the MKL DNN optimized
  // path is taken, i.e., the Eigen node is rewritten to an MKL DNN node.
1566 static bool LrnRewrite(const Node* n) {
1567 DCHECK(n);
1568
1569 int depth_radius;
1570 TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius", &depth_radius));
1571
    // If the depth_radius of LRN is not 2, don't rewrite the node to MKL DNN;
    // use the Eigen node instead.
1574 if (depth_radius == 2) {
1575 return true;
1576 }
    VLOG(1) << "LrnRewrite: The model sets depth_radius to a value other "
            << "than 2, which is not optimized by Intel MKL, thus using the "
            << "Eigen op for LRN.";
1580
1581 return false;
1582 }
1583
1584 static bool LrnGradRewrite(const Node* n) {
1585 DCHECK(n);
1586 bool do_rewrite = false;
1587
1588 for (const Edge* e : n->in_edges()) {
      // Rewrite only if there is a corresponding LRN, i.e., the workspace is
      // available.
1590 if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
1591 e->src()->type_string() ==
1592 mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
1593 e->src_output() == 0) {
1594 do_rewrite = true;
1595 break;
1596 }
1597 }
1598 return do_rewrite;
1599 }
1600
1601 // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or
1602 // feature * alpha (otherwise),
1603 // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha).
1604 // These two algorithms are not consistent when alpha > 1,
1605 // so we only rewrite LeakyRelu to MKL OP when alpha <= 1.
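  // For example, with feature = -2 and alpha = 2: MKL-DNN would produce
  // -2 * 2 = -4, while TensorFlow produces max(-2, -4) = -2, so the results
  // diverge and we must not rewrite.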
1606 static bool LeakyReluRewrite(const Node* n) {
1607 DCHECK(n);
1608
1609 float alpha;
1610 bool has_attr = TryGetNodeAttr(n->def(), "alpha", &alpha);
1611 DCHECK(has_attr);
1612
    // If the alpha of LeakyRelu is less than or equal to 1, rewrite the node.
    // Otherwise the Eigen node is used instead.
1615 if (alpha <= 1) {
1616 return true;
1617 }
    VLOG(1) << "LeakyReluRewrite: The model sets alpha greater than 1, "
            << "which is not optimized by Intel MKL, thus using the Eigen op "
            << "for LeakyRelu.";
1621
1622 return false;
1623 }
1624
1625 static bool QuantizeOpRewrite(const Node* n) {
1626 DCHECK(n);
1627 Node* filter_node = nullptr;
1628 TF_CHECK_OK(n->input_node(0, &filter_node));
1629 bool narrow_range = false;
1630 int axis = -1;
1631 string mode_string;
1632 string round_mode_string;
1633 DataType type;
1634 TryGetNodeAttr(n->def(), "narrow_range", &narrow_range);
1635 TryGetNodeAttr(n->def(), "axis", &axis);
1636 TF_CHECK_OK(GetNodeAttr(n->def(), "mode", &mode_string));
1637 TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode", &round_mode_string));
1638 TF_CHECK_OK(GetNodeAttr(n->def(), "T", &type));
1639
1640 if (narrow_range) {
      VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for "
              << "quantization. This case is not optimized by Intel MKL, "
              << "thus using the Eigen op for the Quantize op.";
1644 return false;
1645 }
1646 if (axis != -1) {
      VLOG(1) << "QuantizeOpRewrite: axis is specified for "
              << "per-slice quantization. "
              << "This case is not optimized by Intel MKL, "
              << "thus using the Eigen op for the Quantize op.";
1651 return false;
1652 }
1653 if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN") ||
1654 (mode_string == "MIN_FIRST"))) {
      VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or "
              << "the rounding mode is not HALF_TO_EVEN. "
              << "This case is not optimized by Intel MKL, thus using the "
              << "Eigen op for the Quantize op.";
1659 return false;
1660 }
1661 if (filter_node->IsConstant()) {
      VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which "
              << "is a constant. "
              << "This case is not supported by the kernel, thus using the "
              << "Eigen op for the Quantize op.";
1666
1667 return false;
1668 }
1669 if (mode_string == "MIN_FIRST") {
1670 if (type != DT_QUINT8) {
        VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is "
                << "not DT_QUINT8. This case is not optimized by Intel MKL, "
                << "thus using the Eigen op for the Quantize op.";
1674 return false;
1675 }
1676 }
1677 return true;
1678 }
1679
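  // Rewrite MaxPoolGrad only when the corresponding forward MaxPool has
  // already been rewritten to its MKL variant and its output feeds the 2nd
  // input of MaxPoolGrad; this is how we know a workspace tensor will be
  // available for the backward op.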
1680 static bool MaxpoolGradRewrite(const Node* n) {
1681 DCHECK(n);
1682 bool do_rewrite = false;
1683 for (const Edge* e : n->in_edges()) {
1684 // Rewrite only if there is corresponding Maxpool, i.e workspace is
1685 // available
1686 if (e->dst()->type_string() == csinfo_.max_pool_grad &&
1687 e->dst_input() == 1 &&
1688 e->src()->type_string() ==
1689 mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
1690 e->src_output() == 0) {
1691 do_rewrite = true;
1692 break;
1693 }
1694 }
1695 return do_rewrite;
1696 }
1697
1698 static bool Maxpool3DGradRewrite(const Node* n) {
1699 DCHECK(n);
1700 for (const Edge* e : n->in_edges()) {
1701 // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is
1702 // available
1703 if (e->dst()->type_string() == csinfo_.max_pool3d_grad &&
1704 e->dst_input() == 1 &&
1705 e->src()->type_string() ==
1706 mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) &&
1707 e->src_output() == 0) {
1708 return true;
1709 }
1710 }
1711 return false;
1712 }
1713
1714 static bool FusedBatchNormV3Rewrite(const Node* n) {
1715 DCHECK(n);
1716 if (Check5DFormat(n->def())) {
1717 VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not "
1718 << "support 5D tensors.";
1719 return false;
1720 }
1721 return true;
1722 }
1723
1724 static bool FusedBatchNormExRewrite(const Node* n) {
1725 DCHECK(n);
1726
1727 int num_side_inputs;
1728 TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs", &num_side_inputs));
1729 string activation_mode;
1730 TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode", &activation_mode));
1731
1732 // if the num_side_inputs is not 0, don't rewrite the node.
1733 if (num_side_inputs != 0) {
      VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs "
              << "larger than 0, which is not optimized by Intel MKL.";
1736 return false;
1737 }
1738
1739 // if the activation_mode is not 'Relu', don't rewrite the node.
1740 if (activation_mode != "Relu") {
      VLOG(1) << "FusedBatchNormExRewrite: Only the Relu activation mode is "
              << "supported by Intel MKL.";
1743 return false;
1744 }
1745
1746 return true;
1747 }
1748
1749 static bool FusedConv2DRewrite(const Node* n) {
1750 // MKL DNN currently doesn't support all fusions that grappler fuses
1751 // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if
1752 // it includes those we support.
1753 DataType T;
1754 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1755 !mkl_op_registry::IsMklOp(NativeFormatEnabled()
1756 ? csinfo_.mkl_native_fused_conv2d
1757 : csinfo_.mkl_fused_conv2d,
1758 T)) {
1759 return false;
1760 }
1761
1762 string data_format;
1763 string filter_format;
1764 int num_host_args = 0;
1765 TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));
1766 TF_CHECK_OK(GetNodeAttr(n->def(), "filter_format", &filter_format));
1767 TF_CHECK_OK(GetNodeAttr(n->def(), "num_host_args", &num_host_args));
1768 if ((data_format != "NCHW" && data_format != "NHWC") ||
1769 (filter_format != "HWIO" && filter_format != "OIHW") ||
1770 (num_host_args != 0)) {
1771 return false;
1772 }
1773
1774 std::vector<string> fused_ops;
1775 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1776 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1777 fused_ops == std::vector<string>{"Relu"} ||
1778 fused_ops == std::vector<string>{"Relu6"} ||
1779 fused_ops == std::vector<string>{"Elu"} ||
1780 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1781 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1782 fused_ops == std::vector<string>{"BiasAdd", "Elu"} ||
1783 fused_ops == std::vector<string>{"BiasAdd", "Add"} ||
1784 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"} ||
1785 fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"} ||
1786 fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"} ||
1787 fused_ops == std::vector<string>{"LeakyRelu"} ||
1788 fused_ops == std::vector<string>{"BiasAdd", "LeakyRelu"} ||
1789 fused_ops == std::vector<string>{"BiasAdd", "Add", "LeakyRelu"} ||
1790 fused_ops == std::vector<string>{"FusedBatchNorm"} ||
1791 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu"} ||
1792 fused_ops == std::vector<string>{"FusedBatchNorm", "Relu6"} ||
1793 fused_ops == std::vector<string>{"FusedBatchNorm", "Elu"} ||
1794 fused_ops == std::vector<string>{"FusedBatchNorm", "LeakyRelu"});
1795 }
1796
1797 static bool FusedDepthwiseConv2DRewrite(const Node* n) {
1798 // MKL DNN currently doesn't support all fusions that grappler fuses
1799 // together with DepthwiseConv2D (ex. batchnorm). We rewrite
1800 // _FusedDepthwiseConv2DNative only if it includes those we support.
1801 DataType T;
1802 if (!TryGetNodeAttr(n->def(), "T", &T) ||
1803 !mkl_op_registry::IsMklOp(
1804 NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d
1805 : csinfo_.mkl_fused_depthwise_conv2d,
1806 T)) {
1807 return false;
1808 }
1809
1810 std::vector<string> fused_ops;
1811 TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
1812 return (fused_ops == std::vector<string>{"BiasAdd"} ||
1813 fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
1814 fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
1815 fused_ops == std::vector<string>{"BiasAdd", "Elu"});
1816 }
1817
1818 // Rewrites input node to a new node specified by its matching rewrite info.
1819 //
1820 // Method first searches matching rewrite info for input node and then
1821 // uses that info to rewrite.
1822 //
  // The input node may be deleted in case of rewrite. Attempting to use the
  // node after the call can result in undefined behavior.
1825 //
1826 // @input g - input graph, n - Node to be rewritten,
1827 // ri - matching rewriteinfo
1828 // @return OkStatus(), if the input node is rewritten;
1829 // Returns appropriate Status error code otherwise.
1830 // Graph is updated in case the input node is rewritten.
1831 // Otherwise, it is not updated.
1832 Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri);
1833
1834 // Rewrites input node to just change its operator name. The number of
1835 // inputs to the node and the number of outputs remain the same. Attributes
1836 // of the new node could be copied from attributes of the old node or
1837 // modified. copy_attrs field of RewriteInfo controls this.
1838 //
1839 // Conceptually, it allows us to rewrite:
1840 //
1841 // f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y)
1842 //
1843 // Attributes can be altered without any restrictions --- they could be
1844 // copied, modified, or deleted completely.
1845 //
1846 // @input g - input graph, orig_node - Node to be rewritten,
1847 // ri - matching rewriteinfo
1848 // @output new_node - points to newly created node
1849 // @return OkStatus(), if the input node is rewritten;
1850 // Returns appropriate Status error code otherwise.
1851 // Graph is only updated when the input node is rewritten.
1852 Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g,
1853 const Node* orig_node, Node** new_node,
1854 const RewriteInfo* ri);
1855
1856 // Rewrites input node to enable MKL layout propagation. Please also refer to
1857 // documentation for the function RewriteNodeForJustOpNameChange() to
1858 // understand what it means.
1859 //
1860 // @input g - input graph, orig_node - Node to be rewritten,
1861 // ri - matching rewriteinfo
1862 // @output new_node - points to newly created node
1863 // @return OkStatus(), if the input node is rewritten;
1864 // Returns appropriate Status error code otherwise.
1865 // Graph is updated in case the input node is rewritten.
1866 // Otherwise, it is not updated.
1867 Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g,
1868 const Node* orig_node, Node** new_node,
1869 const RewriteInfo* ri);
1870
1871 // Get nodes that will feed a list of TF tensors to the new
1872 // node that we are constructing.
1873 //
1874 // @input g - input graph,
1875 // @input inputs - inputs to old node that we are using for constructing
1876 // new inputs,
1877 // @input input_idx - the index in the 'inputs' vector pointing to the
1878 // current input that we have processed so far
1879 // @output input_idx - index will be incremented by the number of nodes
1880 // from 'inputs' that are processed
1881 // @input list_length - The expected length of list of TF tensors
1882 // @output output_nodes - the list of new nodes creating TF tensors
1883 //
1884 // @return None
1885 void GetNodesProducingTFTensorList(
1886 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1887 int* input_idx, int list_length,
1888 std::vector<NodeBuilder::NodeOut>* output_nodes);
1889
1890 // Get nodes that will feed a list of Mkl tensors to the new
1891 // node that we are constructing.
1892 //
1893 // @input g - input graph,
1894 // @input orig_node - Original node that we are rewriting
1895 // @input inputs - inputs to old node that we are using for constructing
1896 // new inputs,
1897 // @input input_idx - the index in the 'inputs' vector pointing to the
1898 // current input that we have processed so far
1899 // @output input_idx - index will be incremented by the number of nodes
1900 // from 'inputs' that are processed
1901 // @input list_length - The expected length of list of Mkl tensors
1902 // @output output_nodes - the list of new nodes creating Mkl tensors
1903 //
1904 // @return None
1905 void GetNodesProducingMklTensorList(
1906 std::unique_ptr<Graph>* g, const Node* orig_node,
1907 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1908 int* input_idx, int list_length,
1909 std::vector<NodeBuilder::NodeOut>* output_nodes);
1910
1911 // Get a node that will feed an Mkl tensor to the new
1912 // node that we are constructing. The output node could be (1) 'n'
1913 // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
1914 // if 'n' is not an Mkl layer.
1915 //
1916 // @input g - input graph,
1917 // @input orig_node - Original node that we are rewriting,
1918 // @input n - Node based on which we are creating Mkl node,
1919 // @input n_output_slot - the output slot of node 'n'
1920 // which is feeding to the node that we are constructing
1921 // @output mkl_node - the new node that will feed Mkl tensor
1922 // @output mkl_node_output_slot - the slot number of mkl_node that
1923 // will feed the tensor
1924 // @return None
1925 void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
1926 const Node* orig_node, Node* n,
1927 int n_output_slot, Node** mkl_node,
1928 int* mkl_node_output_slot);
1929
  // Set up new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1931 // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
1932 // set up in contiguous fashion. 'workspace_tensors' carry graph nodes
1933 // producing workspace edges if 'are_workspace_tensors_available' is true.
1934 // Otherwise, 'workspace_tensors' is empty vector.
1935 //
1936 // For details, refer to 'Ordering of inputs after rewriting' section in the
1937 // documentation above.
1938 //
  // Returns the number of input slots that were set up on the new node
  // (i.e., the number of nb->Input() calls made).
1941 int SetUpContiguousInputs(
1942 std::unique_ptr<Graph>* g,
1943 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
1944 NodeBuilder* nb, const Node* old_node,
1945 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
1946 bool are_workspace_tensors_available);
1947
  // Set up new inputs using old inputs 'inputs' for the rewritten node in 'nb'
1949 // in graph 'g'. Original node is input in 'orig_node'.
1950 //
1951 // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors'
1952 // section in the documentation above.
1953 //
1954 // Returns OkStatus() if setting up inputs is successful, otherwise
1955 // returns appropriate status code.
1956 Status SetUpInputs(std::unique_ptr<Graph>* g,
1957 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1958 NodeBuilder* nb, const Node* orig_node);
1959
1960 // Create new inputs by copying old inputs 'inputs' for the rewritten node
1961 // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly
1962 // used in the context of rewrite for just operator name change in which
1963 // inputs of old operator and new operator are same.
1964 //
1965 // Returns OkStatus() if setting up inputs is successful, otherwise
1966 // returns appropriate status code.
1967 Status CopyInputs(const Node* orig_node,
1968 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
1969 NodeBuilder* nb);
1970
1971 // Add workspace edge on the input or output side of Node 'orig_node' by using
1972 // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate
1973 // adding workspace edge then do not add it. Workspace Tensorflow and Mkl
1974 // tensors, if they need to be added, will be set into these tensors.
1975 // If we set workspace tensors, then are_ws_tensors_added should be true.
1976 void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
1977 const Node* orig_node, NodeBuilder* nb,
1978 std::vector<NodeBuilder::NodeOut>* ws_tensors,
1979 bool* are_ws_tensors_added);
1980
1981 // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
1982 // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph
1983 // 'g'. Returns true if fixup was done; otherwise, it returns false.
1984 bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data,
1985 const Edge* e_metadata);
1986
1987 // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
1988 // connected? If not, then fix them. This is needed because a graph may have
1989 // some input Mkl metadata edges incorrectly setup after node merge and
1990 // rewrite passes. This could happen because GetReversePostOrder function may
1991 // not provide topologically sorted order if a graph contains cycles. The
1992 // function returns true if at least one Mkl metadata edge for node 'n' was
1993 // fixed. Otherwise, it returns false.
1994 //
1995 // Example:
1996 //
1997 // X = MklConv2D(_, _, _)
1998 // Y = MklConv2DWithBias(_, _, _, _, _, _)
1999 // Z = MklAdd(X, Y, DummyMklTensor, Y:1)
2000 //
2001 // For a graph such as shown above, note that 3rd argument of MklAdd contains
2002 // DummyMklTensor. Actually, it should be getting the Mkl metadata from
2003 // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible
2004 // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X
2005 // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl
2006 // metadata edges only - it does not rewrite nodes nor does it modify the Mkl
2007 // data edges (1st and 2nd arguments of MklAdd).
2008 bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);
2009
2010 // Functions specific to operators to copy attributes
2011 // We need operator-specific function to copy attributes because the framework
2012 // does not provide any generic function for it.
2013 // NOTE: names are alphabetically sorted.
2014 static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2015 bool change_format = false);
2016 static void CopyAttrsAllCheckConstFilter(const Node* orig_node,
2017 NodeBuilder* nb,
2018 bool change_format = false);
2019
2020 static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2021 bool change_format = false);
2022 static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
2023 NodeBuilder* nb,
2024 bool change_format = false);
2025 static void CopyAttrsFusedConv2DCheckConstFilter(const Node* orig_node,
2026 NodeBuilder* nb,
2027 bool change_format = false);
2028 static void CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2029 const Node* orig_node2, NodeBuilder* nb,
2030 bool change_format = false);
2031 static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1,
2032 const Node* orig_node2,
2033 NodeBuilder* nb,
2034 bool change_format = false);
2035 static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb,
2036 bool change_format = false);
2037 static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb,
2038 const std::vector<int32>& strides,
2039 const std::vector<int32>& dilations,
2040 bool change_format = false);
2041
2042 static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node,
2043 NodeBuilder* nb,
2044 bool change_format = false);
2045 static void CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2046 const Node* orig_node, NodeBuilder* nb, bool change_format = false);
2047 static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb,
2048 bool change_format = false);
2049
2050 // Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
2051 // using node for original node 'orig_node' and return it in '*out'.
2052 // TODO(nhasabni) We should move this to mkl_util.h
2053 void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
2054 const Node* orig_node);
2055 void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
2056 const Node* orig_node);
2057};
2058
2059MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
2060
2061// We register Mkl rewrite pass for phase 1 in post partitioning group.
2062// We register it here so that we get a complete picture of all users of Mkl
2063// nodes. Do not change the ordering of the Mkl passes.
2064const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
2065 OptimizationPassRegistry::POST_PARTITIONING;
2066REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2067
2068//////////////////////////////////////////////////////////////////////////
2069// Helper functions for creating new node
2070//////////////////////////////////////////////////////////////////////////
2071
2072static void FillInputs(const Node* n,
2073 gtl::InlinedVector<Node*, 4>* control_edges,
2074 gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
2075 control_edges->clear();
2076 for (const Edge* e : n->in_edges()) {
2077 if (e->IsControlEdge()) {
2078 control_edges->push_back(e->src());
2079 } else {
2080 (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
2081 }
2082 }
2083 std::sort(control_edges->begin(), control_edges->end());
2084}
2085
2086void MklLayoutRewritePass::GetNodesProducingTFTensorList(
2087 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2088 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2089 CHECK_LT(*input_idx, inputs.size());
2090 CHECK_GT(list_length, 0);
2091 DCHECK(output_nodes);
2092 output_nodes->reserve(list_length);
2093
2094 while (list_length != 0) {
2095 CHECK_GT(list_length, 0);
2096 CHECK_LT(*input_idx, inputs.size());
2097 Node* n = inputs[*input_idx].first;
2098 int slot = inputs[*input_idx].second;
2099 // If input node 'n' is just producing a single tensor at
2100 // output slot 'slot' then we just add that single node.
2101 output_nodes->push_back(NodeBuilder::NodeOut(n, slot));
2102 (*input_idx)++;
2103 list_length--;
2104 }
2105}
2106
2107// TODO(nhasabni) We should move this to mkl_util.h.
2108void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
2109 Node** out,
2110 const Node* orig_node) {
2111 // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
2112 // dummy Mkl tensor. 8 = 2*size_t.
2113 const DataType dt = DataTypeToEnum<uint8>::v();
2114 TensorProto proto;
2115 proto.set_dtype(dt);
2116 uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
2117 proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8));
2118 TensorShape dummy_shape({8});
2119 dummy_shape.AsProto(proto.mutable_tensor_shape());
2120 TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
2121 .Attr("value", proto)
2122 .Attr("dtype", dt)
2123 .Device(orig_node->def().device()) // We place this node on
2124 // the same device as the
2125 // device of the original
2126 // node.
2127 .Finalize(&**g, out));
2128 DCHECK(*out); // Make sure we got a valid object before using it
2129
2130 // If number of inputs to the original node is > 0, then we add
2131 // control dependency between 1st input (index 0) of the original node and
2132 // the dummy Mkl node. This is needed because control-flow ops such as Enter,
2133 // Merge, etc, require frame_name of the dummy Mkl node to be same as the
2134 // rewritten node. Adding control edge between 1st input of the original node
2135 // and the dummy Mkl node ensures that the dummy node is in the same frame
2136 // as the original node. Choosing 1st input is not necessary - any input of
2137 // the original node is fine because all the inputs of a node are always in
2138 // the same frame.
2139 if (orig_node->num_inputs() > 0) {
2140 Node* orig_input0 = nullptr;
2141 TF_CHECK_OK(
2142 orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
2143 auto edge = (*g)->AddControlEdge(orig_input0, *out, false);
2144 DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out));
2145 }
2146
2147 (*out)->set_assigned_device_name(orig_node->assigned_device_name());
2148}
2149
2150void MklLayoutRewritePass::GetNodesProducingMklTensorList(
2151 std::unique_ptr<Graph>* g, const Node* orig_node,
2152 const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
2153 int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
2154 CHECK_LT(*input_idx, inputs.size());
2155 CHECK_GT(list_length, 0);
2156 DCHECK(output_nodes);
2157 output_nodes->reserve(list_length);
2158
2159 while (list_length != 0) {
2160 CHECK_GT(list_length, 0);
2161 CHECK_LT(*input_idx, inputs.size());
2162 Node* n = inputs[*input_idx].first;
2163 int slot = inputs[*input_idx].second;
2164 // If 'n' is producing a single tensor, then create a single Mkl tensor
2165 // node.
2166 Node* mkl_node = nullptr;
2167 int mkl_node_output_slot = 0;
2168 GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node,
2169 &mkl_node_output_slot);
2170 output_nodes->push_back(
2171 NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
2172 (*input_idx)++;
2173 list_length--;
2174 }
2175}
2176
2177// Get an input node that will feed Mkl tensor to the new
2178// node that we are constructing. An input node could be (1) 'n'
2179// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
2180// if 'n' is not an Mkl layer.
2181void MklLayoutRewritePass::GetNodeProducingMklTensor(
2182 std::unique_ptr<Graph>* g, const Node* orig_node, Node* n,
2183 int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
2184 DCHECK(n);
2185 DCHECK(mkl_node);
2186 DCHECK(mkl_node_output_slot);
2187
2188 // If this is an MKL op, then it will create extra output for MKL layout.
2189 DataType T;
2190 if (TryGetNodeAttr(n->def(), "T", &T) &&
2191 mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
2192 // If this is an MKL op, then it will generate an edge that will receive
2193 // Mkl tensor from a node.
2194 // output slot number for Mkl tensor would be N+slot number of TensorFlow
2195 // tensor, where N is total number of TensorFlow tensors.
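    // For example (assuming the contiguous ordering used by this pass), a
    // rewritten op with 2 TensorFlow outputs has 4 outputs in total, so the
    // Mkl metadata tensor for TF output slot 1 lives at output slot
    // 2 + 1 = 3.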
2196 *mkl_node = n;
2197 *mkl_node_output_slot =
2198 GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
2199 } else {
2200 // If we have not visited the node and rewritten it, then we need
2201 // to create a dummy node that will feed a dummy Mkl tensor to this node.
2202 // DummyMklTensor node has no input and generates only 1 output
2203 // (dummy Mkl tensor) as output slot number 0.
2204 GetDummyMklTensorNode(g, mkl_node, orig_node);
2205 DCHECK(*mkl_node);
2206 *mkl_node_output_slot = 0;
2207 }
2208}
2209
2210int MklLayoutRewritePass::SetUpContiguousInputs(
2211 std::unique_ptr<Graph>* g,
2212 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2213 NodeBuilder* nb, const Node* old_node,
2214 std::vector<NodeBuilder::NodeOut>* workspace_tensors,
2215 bool are_workspace_tensors_available) {
2216 DCHECK(workspace_tensors);
2217 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2218
2219 // TODO(nhasabni): Temporary solution to connect filter input of
2220 // BackpropInput with the converted filter from Conv2D.
2221 bool do_connect_conv2d_backprop_input_filter = false;
2222 Node* conv2d_node = nullptr;
2223 // Filter node is 2nd input (slot index 1) of Conv2D.
2224 int kConv2DFilterInputSlotIdx = 1;
2225 int kConv2DBackpropInputFilterInputSlotIdx = 1;
2226 int kConv2DFilterOutputSlotIdx = 1;
2227 if (old_node->type_string() == csinfo_.conv2d_grad_input) {
2228 // We need to find Conv2D node from Conv2DBackpropInput.
2229 // For that let's first find filter node that is 2nd input (slot 1)
2230 // of BackpropInput.
2231 Node* filter_node = nullptr;
2232 TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx,
2233 &filter_node));
2234 DCHECK(filter_node);
2235
2236 // Now check which nodes receive from filter_node. Filter feeds as
2237 // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and
2238 // _MklFusedConv2D.
2239 for (const Edge* e : filter_node->out_edges()) {
2240 if ((e->dst()->type_string() == csinfo_.mkl_conv2d ||
2241 e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d ||
2242 e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d ||
2243 e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias ||
2244 e->dst()->type_string() == csinfo_.mkl_fused_conv2d) &&
2245 e->dst_input() == kConv2DFilterInputSlotIdx
2246 /* filter is 2nd input of Conv2D and _MklConv2D. */) {
2247 if (conv2d_node != nullptr) {
2248 VLOG(1) << "MklLayoutRewritePass: unusual case of same filter"
2249 << " feeding multiple Conv2D nodes: "
2250 << filter_node->DebugString();
2251 // We will not connect filter input of Conv2DBackpropInput
2252 // to be safe here.
2253 do_connect_conv2d_backprop_input_filter = false;
2254 break;
2255 } else {
2256 conv2d_node = e->dst();
2257 do_connect_conv2d_backprop_input_filter = true;
2258 }
2259 }
2260 }
2261 }
2262
2263 // Number of input slots to original op
2264 // Input slots are represented by .Input() calls in REGISTER_OP.
2265 int old_node_input_slots = old_node->op_def().input_arg_size();
2266 int nn_slot_idx = 0; // slot index for inputs of new node
2267
2268 // Let's copy all inputs (TF tensors) of original node to new node.
2269 int iidx = 0;
2270 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2271 // An input slot could be a single tensor or a list. We need
2272 // to handle this case accordingly.
2273 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2274 if (ArgIsList(arg)) {
2275 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2276 int tensor_list_length = GetTensorListLength(arg, old_node);
2277 if (tensor_list_length != 0) {
2278 GetNodesProducingTFTensorList(old_node_inputs, &iidx,
2279 tensor_list_length, &new_node_inputs);
2280 }
2281 nb->Input(new_node_inputs);
2282 nn_slot_idx++;
2283 } else {
2284 // Special case for connecting filter input of Conv2DBackpropInput
2285 if (do_connect_conv2d_backprop_input_filter &&
2286 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2287 nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx);
2288 } else {
2289 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2290 }
2291 iidx++;
2292 nn_slot_idx++;
2293 }
2294 }
2295
2296 // If workspace tensors are available for this op and we are using
2297 // contiguous ordering then we need to add Tensorflow tensor for
2298 // workspace here because Tensorflow tensor for workspace is the
2299 // last tensor in the list of Tensorflow tensors.
2300 if (are_workspace_tensors_available) {
2301 CHECK_EQ(workspace_tensors->size(), 2);
2302 // Tensorflow tensor
2303 nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index);
2304 nn_slot_idx++;
2305 }
2306
  // Let's now set up all Mkl inputs to the new node.
2308 // Number of Mkl inputs must be same as number of TF inputs.
2309 iidx = 0;
2310 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2311 // An input slot could be a single tensor or a list. We need
2312 // to handle this case accordingly.
2313 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2314 if (ArgIsList(arg)) {
2315 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2316 int tensor_list_length = GetTensorListLength(arg, old_node);
2317 if (tensor_list_length != 0) {
2318 GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx,
2319 tensor_list_length, &new_node_inputs);
2320 }
2321 nb->Input(new_node_inputs);
2322 nn_slot_idx++;
2323 } else {
2324 Node* mkl_node = nullptr;
2325 int mkl_node_output_slot = 0;
2326 // Special case for connecting filter input of Conv2DBackpropInput
2327 if (do_connect_conv2d_backprop_input_filter &&
2328 iidx == kConv2DBackpropInputFilterInputSlotIdx) {
2329 GetNodeProducingMklTensor(g, old_node, conv2d_node,
2330 kConv2DFilterOutputSlotIdx, &mkl_node,
2331 &mkl_node_output_slot);
2332 } else {
2333 GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first,
2334 old_node_inputs[iidx].second, &mkl_node,
2335 &mkl_node_output_slot);
2336 }
2337 nb->Input(mkl_node, mkl_node_output_slot);
2338 iidx++;
2339 nn_slot_idx++;
2340 }
2341 }
2342
2343 // If workspace tensors are available for this op and we are using
2344 // contiguous ordering then we need to add Mkl tensor for
2345 // workspace here because Mkl tensor for workspace is the
2346 // last tensor in the list of Mkl tensors.
2347 if (are_workspace_tensors_available) {
2348 CHECK_EQ(workspace_tensors->size(), 2);
2349 // Mkl tensor
2350 nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index);
2351 nn_slot_idx++;
2352 }
2353
2354 return nn_slot_idx;
2355}
2356
// This method determines whether the workspace check is needed. Workspace is
// not used in quantized ops, so the check would fail because quantized ops
// don't have the attribute "T".
2360bool IsWorkspaceCheckNeeded(const Node* node) {
2361 std::vector<string> quant_ops{
2362 "Dequantize",
2363 "QuantizeV2",
2364 "QuantizedConv2D",
2365 "QuantizedConv2DWithBias",
2366 "QuantizedConv2DAndRelu",
2367 "QuantizedConv2DWithBiasAndRelu",
2368 "QuantizedConv2DWithBiasSumAndRelu",
2369 "QuantizedConv2DPerChannel",
2370 "QuantizedConv2DAndRequantize",
2371 "QuantizedConv2DWithBiasAndRequantize",
2372 "QuantizedConv2DAndReluAndRequantize",
2373 "QuantizedConv2DWithBiasAndReluAndRequantize",
2374 "QuantizedConv2DWithBiasSumAndReluAndRequantize",
2375 "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize",
2376 "QuantizedMatMulWithBias",
2377 "QuantizedMatMulWithBiasAndRequantize",
2378 "QuantizedMatMulWithBiasAndDequantize",
2379 "QuantizedMatMulWithBiasAndRelu",
2380 "QuantizedMatMulWithBiasAndReluAndRequantize",
2381 "QuantizedDepthwiseConv2D",
2382 "QuantizedDepthwiseConv2DWithBias",
2383 "QuantizedDepthwiseConv2DWithBiasAndRelu",
2384 "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize"};
2385 return std::find(std::begin(quant_ops), std::end(quant_ops),
2386 node->type_string()) == std::end(quant_ops);
2387}
2388
2389Status MklLayoutRewritePass::SetUpInputs(
2390 std::unique_ptr<Graph>* g,
2391 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2392 NodeBuilder* nb, const Node* old_node) {
2393 // Let's check if we need to add workspace tensors for this node.
2394 // We add workspace edge only for MaxPool, LRN and BatchNorm.
2395 std::vector<NodeBuilder::NodeOut> workspace_tensors;
2396 bool are_workspace_tensors_available = false;
2397
2398 if (IsWorkspaceCheckNeeded(old_node)) {
2399 AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors,
2400 &are_workspace_tensors_available);
2401 }
2402
2403 int new_node_input_slots = 0;
2404 if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    // TODO(nhasabni): implement this function just for the sake of
    // completeness. We do not use interleaved ordering right now.
2407 return Status(
2408 error::Code::UNIMPLEMENTED,
2409 "Interleaved ordering of tensors is currently not supported.");
2410 } else {
2411 CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
2412 new_node_input_slots = SetUpContiguousInputs(
2413 g, old_node_inputs, nb, old_node, &workspace_tensors,
2414 are_workspace_tensors_available);
2415 }
2416
2417 // Sanity check
2418 int old_node_input_slots = old_node->op_def().input_arg_size();
2419 if (!are_workspace_tensors_available) {
2420 // If we are not adding workspace tensors for this op, then the total
2421 // number of input slots to the new node _must_ be 2 times the number
2422 // of input slots to the original node: N original Tensorflow tensors and
    // N Mkl tensors corresponding to each Tensorflow tensor.
2424 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2);
2425 } else {
    // If we are adding workspace tensors for this op, then the total
    // number of input slots to the new node _must_ be 2 times the number
    // of input slots to the original node: N original Tensorflow tensors,
    // N Mkl tensors corresponding to each Tensorflow tensor, plus 2
2430 // (for workspace Tensorflow tensor and workspace Mkl tensor).
2431 CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2);
2432 }
2433
2434 return OkStatus();
2435}
2436
2437Status MklLayoutRewritePass::CopyInputs(
2438 const Node* old_node,
2439 const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
2440 NodeBuilder* nb) {
2441 // Number of input slots to old node
2442 // Input slots are represented by .Input() calls in REGISTER_OP.
2443 int old_node_input_slots = old_node->op_def().input_arg_size();
2444 // Actual number of inputs can be greater than or equal to number
2445 // of Input slots because inputs of type list could be unfolded.
2446 auto old_node_input_size = old_node_inputs.size();
2447
2448 if (old_node->type_string() == "_FusedConv2D") {
    // TODO(intel-tf): commit 5be9a5 updates _FusedConv2D with an additional
    // host_args input, but the MKL version currently doesn't have this input
    // arg, so we need to remove this extra input when replacing the node
    // with the MKL node.
2453 old_node_input_slots--;
2454 }
2455
2456 DCHECK_GE(old_node_input_size, old_node_input_slots);
2457
2458 // Let's copy all inputs of old node to new node.
2459 int iidx = 0;
2460 for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) {
2461 // An input slot could be a single tensor or a list. We need
2462 // to handle this case accordingly.
2463 DCHECK_LT(iidx, old_node_input_size);
2464 const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx);
2465 if (ArgIsList(arg)) {
2466 std::vector<NodeBuilder::NodeOut> new_node_inputs;
2467 int N = GetTensorListLength(arg, old_node);
2468 if (N != 0) {
2469 GetNodesProducingTFTensorList(old_node_inputs, &iidx, N,
2470 &new_node_inputs);
2471 }
2472 nb->Input(new_node_inputs);
2473 } else {
2474 nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second);
2475 iidx++;
2476 }
2477 }
2478 return OkStatus();
2479}
2480
2481//////////////////////////////////////////////////////////////////////////
2482// Helper functions related to workspace pass
2483//////////////////////////////////////////////////////////////////////////
2484
2485// TODO(nhasabni) We should move this to mkl_util.h.
2486void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
2487 std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) {
  // We use a uint8 tensor of shape {8} with content {0,0,0,0,0,0,0,0} to
  // represent the workspace tensor.
2490 GetDummyMklTensorNode(g, out, orig_node);
2491}
2492
2493void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
2494 std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb,
2495 std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
2496 bool workspace_edge_added = false; // Default initializer
2497 DCHECK(are_ws_tensors_added);
2498 *are_ws_tensors_added = false; // Default initializer
2499
2500 DataType T;
2501 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2502 for (auto ws : wsinfo_) {
2503 if (orig_node->type_string() == ws.fwd_op &&
2504 mkl_op_registry::IsMklOp(
2505 mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
2506 // If this op is a fwd op, then we need to check if there is an
2507 // edge from this node's fwd_slot to bwdop's bwd_slot. If there is
2508 // an edge, then we just add an attribute on this node for setting
2509 // workspace_passed to true. We don't add actual workspace edge
2510 // in this node. Actual workspace edge gets added in the backward
2511 // op for this node.
2512 for (const Edge* e : orig_node->out_edges()) {
2513 if (e->src_output() == ws.fwd_slot &&
2514 e->dst()->type_string() == ws.bwd_op &&
2515 e->dst_input() == ws.bwd_slot) {
2516 nb->Attr("workspace_enabled", true);
2517 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2518 << orig_node->type_string();
2519 workspace_edge_added = true;
2520 // We found the edge that we were looking for, so break.
2521 break;
2522 }
2523 }
2524
2525 if (!workspace_edge_added) {
2526 // If we are here, then we did not find backward operator for this
2527 // node.
2528 nb->Attr("workspace_enabled", false);
2529 }
2530 } else if (orig_node->type_string() == ws.bwd_op &&
2531 mkl_op_registry::IsMklOp(
2532 mkl_op_registry::GetMklOpName(orig_node->type_string()),
2533 T)) {
      // If this op is a bwd op, then we need to add a workspace edge and
      // its Mkl tensor edge between its corresponding fwd op and this
2536 // op. Corresponding fwd op is specified in 'fwd_op' field of
2537 // workspace info. fwd_slot and bwd_slot in workspace info specify
2538 // an edge between which slots connect forward and backward op.
2539 // Once all these criteria match, we add a workspace edge between
2540 // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is
2541 // determined by interleaved/contiguous ordering. Function
2542 // DataIndexToMetaDataIndex tells us the location of Mkl tensor
2543 // from the location of the Tensorflow tensor.
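      // For example (an illustrative sketch): for the MaxPool/MaxPoolGrad
      // pair, the rewritten forward _MklMaxPool exposes an extra workspace
      // output; the code below appends that output (and, unless native
      // format is enabled, its Mkl metadata tensor) as the last inputs of
      // the backward op.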
2544 for (const Edge* e : orig_node->in_edges()) {
2545 if (e->src_output() == ws.fwd_slot &&
2546 // We would have rewritten the forward op, so we need to use
2547 // GetMklOpName call to get its Mkl name.
2548 e->src()->type_string() ==
2549 mkl_op_registry::GetMklOpName(ws.fwd_op) &&
2550 e->dst_input() == ws.bwd_slot) {
2551 nb->Attr("workspace_enabled", true);
2552 DCHECK(ws_tensors);
2553 // Add workspace edge between fwd op and bwd op.
2554 ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
2555 // Check if we are running in native format mode. If so,
2556 // we don't need to have an Mkl metadata tensor for the workspace.
2557 if (!NativeFormatEnabled()) {
2558 // Add Mkl tensor edge for workspace edge between fwd op and bwd op.
2559 ws_tensors->push_back(NodeBuilder::NodeOut(
2560 e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
2561 e->src()->num_outputs())));
2562 }
2563 *are_ws_tensors_added = true;
2564 // In terms of input ordering, we add these calls to add Input
2565 // here because workspace edge (and its Mkl tensor) is the last
2566 // edge in the fwdop and bwdop. So all inputs before workspace
2567 // tensor have been added by SetUpInputs function.
2568 VLOG(1) << "MklLayoutRewritePass: workspace_enabled for "
2569 << orig_node->type_string();
2570 workspace_edge_added = true;
2571 // We found the edge that we were looking for, so break.
2572 break;
2573 }
2574 }
2575
      // If we are here, it means we did not find a fwd op that feeds this
      // bwd op. So in this case, we need to generate dummy tensors for
2578 // workspace input and Mkl tensor for workspace, and set
2579 // workspace_enabled to false.
2580 if (!workspace_edge_added) {
2581 nb->Attr("workspace_enabled", false);
2582 Node* dmt_ws = nullptr; // Dummy tensor for workspace
2583 Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace
2584 GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node);
2585 GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node);
2586 DCHECK(dmt_ws);
2587 DCHECK(dmt_mkl_ws);
2588 DCHECK(ws_tensors);
2589 // We add dummy tensor as workspace tensor.
2590 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0));
2591 // We add dummy tensor as Mkl tensor for workspace tensor.
2592 ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0));
2593 *are_ws_tensors_added = true;
2594 VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
2595 << orig_node->type_string();
2596 }
2597 } else {
2598 // If this node does not match any workspace info, then we do not
2599 // do anything special for workspace propagation for it.
2600 }
2601 }
2602}
2603
2604//////////////////////////////////////////////////////////////////////////
2605// Op-specific functions to copy attributes from old node to new node
2606//////////////////////////////////////////////////////////////////////////
2607
2608// Generic function to copy all attributes from original node to target.
2609void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb,
2610 bool change_format) {
2611 string name;
2612 AttrSlice attr_list(orig_node->def());
2613
2614 auto iter = attr_list.begin();
2615 while (iter != attr_list.end()) {
2616 name = iter->first;
2617 auto attr = iter->second;
2618 nb->Attr(name, attr);
2619 ++iter;
2620 }
2621}
2622
2623// Generic function to copy all attributes and check if filter is const.
2624void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node,
2625 NodeBuilder* nb,
2626 bool change_format) {
2627 CopyAttrsAll(orig_node, nb, change_format);
2628
2629 // Check and set filter attribute.
2630 Node* filter_node = nullptr;
2631 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2632
2633 bool is_filter_const = false;
2634 if (HasNodeAttr(orig_node->def(), "is_filter_const")) {
2635 (void)GetNodeAttr(orig_node->def(), "is_filter_const", &is_filter_const);
2636 }
2637
2638 // In case that (1) orig_node does not have attribute 'is_filter_const',
2639 // or (2) it has the attribute but with the false value, then we set the
2640 // attribute for 'nb' with a value based on filter_node being const or not.
2641 // If is_filter_const == true, then there is no need to call nb->Attr() as
2642 // CopyAttrsAll() has already copied the attribute from orig_node to nb.
2643 if (!is_filter_const) {
2644 nb->Attr("is_filter_const", filter_node->IsConstant());
2645 }
2646}
2647
2648void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node,
2649 NodeBuilder* nb,
2650 bool change_format) {
2651 CopyAttrsConv(orig_node, nb, change_format);
2652
2653 // Check and set filter attribute.
2654 Node* filter_node = nullptr;
2655 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2656 nb->Attr("is_filter_const", filter_node->IsConstant());
2657}
2658
2659void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb,
2660 bool change_format) {
2661 DataType T;
2662 string padding;
2663 std::vector<int32> strides;
2664 std::vector<int32> dilations;
2665 std::vector<int32> explicit_paddings;
2666
2667 // Get all attributes from old node.
2668 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2669 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2670 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2671 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2672
2673 // Check `explicit_paddings` first because some Conv ops don't have
2674 // this attribute.
2675 if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2676 &explicit_paddings) &&
2677 !explicit_paddings.empty()) {
2678 nb->Attr("explicit_paddings", explicit_paddings);
2679 }
2680
2681 // Add attributes to new node.
2682 nb->Attr("T", T);
2683 nb->Attr("padding", padding);
2684
2685 // Add attributes related to `data_format`.
2686 CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format);
2687}
2688
2689// Used with MergePadWithConv2D
2690void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1,
2691 const Node* orig_node2,
2692 NodeBuilder* nb,
2693 bool change_format) {
2694 DataType Tpaddings;
2695 DataType T;
2696 string data_format;
2697 string padding;
2698 std::vector<int32> strides;
2699 std::vector<int32> dilations;
2700 bool use_cudnn_on_gpu;
2701
2702 // Get all attributes from old node 1.
2703 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T", &T));
2704 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides", &strides));
2705 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations", &dilations));
2706 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding", &padding));
2707 TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format", &data_format));
2708 TF_CHECK_OK(
2709 GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2710 // Get all attributes from old node 2.
2711 TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings", &Tpaddings));
2712
2713 // Add attributes to new node.
2714 nb->Attr("T", T);
2715 nb->Attr("strides", strides);
2716 nb->Attr("dilations", dilations);
2717 nb->Attr("padding", padding);
2718 nb->Attr("data_format", data_format);
2719 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2720 nb->Attr("Tpaddings", Tpaddings);
2721}
2722
2723void MklLayoutRewritePass::CopyAttrsFusedConv2DCheckConstFilter(
2724 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2725 DataType T;
2726 int num_args;
2727 string data_format;
2728 string padding;
2729 std::vector<int32> strides;
2730 std::vector<int32> dilations;
2731 std::vector<int32> explicit_paddings;
2732 float epsilon;
2733 std::vector<string> fused_ops;
2734 float leakyrelu_alpha;
2735 bool use_cudnn_on_gpu;
2736
2737 // Get all attributes from old node.
2738 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2739 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args", &num_args));
2740 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2741 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2742 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2743 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2744 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops", &fused_ops));
2745 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon", &epsilon));
2746 TF_CHECK_OK(
2747 GetNodeAttr(orig_node->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2748 TF_CHECK_OK(
2749 GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
2750
2751 // Add attributes to new node.
2752 nb->Attr("T", T);
2753 nb->Attr("num_args", num_args);
2754 nb->Attr("strides", strides);
2755 nb->Attr("padding", padding);
2756 nb->Attr("data_format", data_format);
2757 nb->Attr("dilations", dilations);
2758 nb->Attr("epsilon", epsilon);
2759 nb->Attr("fused_ops", fused_ops);
2760 nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2761 nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
2762
2763 // Check `explicit_paddings` first because some Conv ops don't have
2764 // this attribute.
2765 if (TryGetNodeAttr(orig_node->def(), "explicit_paddings",
2766 &explicit_paddings) &&
2767 !explicit_paddings.empty()) {
2768 nb->Attr("explicit_paddings", explicit_paddings);
2769 }
2770
2771 // Check and set filter attribute.
2772 Node* filter_node = nullptr;
2773 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2774 nb->Attr("is_filter_const", filter_node->IsConstant());
2775}
2776
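// Used with MergePadWithConv2D when the convolution being merged is a fused
// Conv2D. Copies attributes from both the FusedConv2D node and the Pad node.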
2777void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D(
2778 const Node* fused_conv2d, const Node* pad, NodeBuilder* nb,
2779 bool change_format) {
2780 DataType T;
2781 int num_args;
2782 string data_format;
2783 string padding;
2784 std::vector<int32> strides;
2785 std::vector<int32> dilations;
2786 float epsilon;
2787 std::vector<string> fused_ops;
2788 DataType Tpaddings;
2789 float leakyrelu_alpha;
2790
2791 // Get all attributes from old node.
2792 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T", &T));
2793 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args", &num_args));
2794 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides", &strides));
2795 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding", &padding));
2796 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format", &data_format));
2797 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations", &dilations));
2798 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops", &fused_ops));
2799 TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon", &epsilon));
2800 TF_CHECK_OK(
2801 GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha", &leakyrelu_alpha));
2802 TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings", &Tpaddings));
2803
2804 // Add attributes to new node.
2805 nb->Attr("T", T);
2806 nb->Attr("num_args", num_args);
2807 nb->Attr("strides", strides);
2808 nb->Attr("padding", padding);
2809 nb->Attr("data_format", data_format);
2810 nb->Attr("dilations", dilations);
2811 nb->Attr("epsilon", epsilon);
2812 nb->Attr("Tpaddings", Tpaddings);
2813 nb->Attr("fused_ops", fused_ops);
2814 nb->Attr("leakyrelu_alpha", leakyrelu_alpha);
2815}
2816
2817void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
2818 NodeBuilder* nb,
2819 bool change_format) {
2820 DataType Tinput, Tfilter, out_type;
2821 string padding;
2822 string data_format("NHWC");
2823 std::vector<int32> strides, dilations, padding_list;
2824 bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
2825
2826 // Get all attributes from old node.
2827 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
2828 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter", &Tfilter));
2829 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type", &out_type));
2830 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2831 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2832 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
2833 if (has_padding_list) {
2834 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
2835 }
2836
2837 Node* filter_node = nullptr;
2838 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2839
2840 // Add attributes to new node.
2841 nb->Attr("Tinput", Tinput);
2842 nb->Attr("Tfilter", Tfilter);
2843 nb->Attr("out_type", out_type);
2844 nb->Attr("padding", padding);
2845 nb->Attr("is_filter_const", filter_node->IsConstant());
2846 nb->Attr("strides", strides);
2847 nb->Attr("dilations", dilations);
2848 nb->Attr("data_format", data_format);
2849 if (has_padding_list) {
2850 nb->Attr("padding_list", padding_list);
2851 }
2852
2853 // Requantization attr Tbias.
2854 DataType Tbias;
2855 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
  if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2857}
2858
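// Copies all attributes and sets 'is_weight_const' based on whether the
// weight input (input 1) of the original node is a Const node.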
2859void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize(
2860 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2861 CopyAttrsAll(orig_node, nb, change_format);
2862
2863 // Check and set filter attribute.
2864 Node* filter_node = nullptr;
2865 TF_CHECK_OK(orig_node->input_node(1, &filter_node));
2866 nb->Attr("is_weight_const", filter_node->IsConstant());
2867}
2868
2869void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias(
2870 const Node* orig_node, NodeBuilder* nb, bool change_format) {
2871 DataType T1, T2, Toutput;
2872
2873 // Get all attributes from old node.
2874 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1", &T1));
2875 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2", &T2));
2876 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput", &Toutput));
2877
2878 Node* weight_node = nullptr;
2879 TF_CHECK_OK(orig_node->input_node(1, &weight_node));
2880
2881 // Add attributes to new node.
2882 nb->Attr("T1", T1);
2883 nb->Attr("T2", T2);
2884 nb->Attr("Toutput", Toutput);
2885 nb->Attr("is_weight_const", weight_node->IsConstant());
2886
2887 // Requantization attr Tbias
2888 DataType Tbias;
2889 Status bias_status = GetNodeAttr(orig_node->def(), "Tbias", &Tbias);
  if (bias_status.ok()) nb->Attr("Tbias", Tbias);
2891}
2892
2893void MklLayoutRewritePass::CopyFormatAttrsConv(
2894 const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides,
2895 const std::vector<int32>& dilations, bool change_format) {
2896 string data_format;
2897
2898 if (!change_format) {
2899 nb->Attr("strides", strides);
2900 nb->Attr("dilations", dilations);
2901
2902 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2903 nb->Attr("data_format", data_format);
2904 } else {
2905 std::vector<int32> new_strides;
2906 std::vector<int32> new_dilations;
2907 if (strides.size() == 5) {
2908 // `strides` and `dilations` also need to be changed according to
2909 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
2910 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2911 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2912 strides[NDHWC::dim::W]};
2913
2914 new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C],
2915 dilations[NDHWC::dim::D], dilations[NDHWC::dim::H],
2916 dilations[NDHWC::dim::W]};
2917 } else {
2918 // `strides` and `dilations` also need to be changed according to
2919 // `data_format`. In this case, from `NHWC` to `NCHW`.
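      // Illustrative example: strides {1, 2, 3, 1} read in NHWC order become
      // {1, 1, 2, 3} in NCHW order, and dilations are permuted the same way.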
2920
2921 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2922 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2923
2924 new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C],
2925 dilations[NHWC::dim::H], dilations[NHWC::dim::W]};
2926 }
2927 nb->Attr("strides", new_strides);
2928 nb->Attr("dilations", new_dilations);
2929 }
2930}
2931
2932void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node,
2933 NodeBuilder* nb,
2934 bool change_format) {
2935 DataType T;
2936 string data_format;
2937 string padding;
2938 std::vector<int32> ksize, strides;
2939
2940 // Get all attributes from old node.
2941 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
2942 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize", &ksize));
2943 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
2944 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
2945 TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
2946
2947 // Add attributes to new node.
2948 nb->Attr("T", T);
2949 nb->Attr("padding", padding);
2950
2951 if (!change_format) {
2952 nb->Attr("strides", strides);
2953 nb->Attr("ksize", ksize);
2954
2955 nb->Attr("data_format", data_format);
2956 } else {
2957 std::vector<int32> new_strides;
2958 std::vector<int32> new_ksize;
2959 if (strides.size() == 5) {
2960 DCHECK(data_format == "NCDHW");
2961 // `strides` and `ksize` also need to be changed according to
2962 // `data_format`. In this case, from `NDHWC` to `NCDHW`.
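      // Illustrative example: ksize {1, 2, 2, 2, 1} read in NDHWC order
      // becomes {1, 1, 2, 2, 2} in NCDHW order, and strides are permuted the
      // same way.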
2963 new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C],
2964 strides[NDHWC::dim::D], strides[NDHWC::dim::H],
2965 strides[NDHWC::dim::W]};
2966
2967 new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C],
2968 ksize[NDHWC::dim::D], ksize[NDHWC::dim::H],
2969 ksize[NDHWC::dim::W]};
2970
2971 } else {
2972 // `strides` and `ksize` also need to be changed according to
2973 // `data_format`. In this case, from `NHWC` to `NCHW`.
2974 DCHECK(data_format == "NCHW");
2975 new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C],
2976 strides[NHWC::dim::H], strides[NHWC::dim::W]};
2977
2978 new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C],
2979 ksize[NHWC::dim::H], ksize[NHWC::dim::W]};
2980 }
2981 nb->Attr("strides", new_strides);
2982 nb->Attr("ksize", new_ksize);
2983 }
2984}
2985
2986//////////////////////////////////////////////////////////////////////////
2987// Helper functions related to node merge pass
2988//////////////////////////////////////////////////////////////////////////
2989
2990Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
2991 // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite
2992 // once we support BiasAddGrad as Mkl layer.
2993
2994 // Search for all matching mergeinfo.
2995 // We allow more than one match for extensibility.
2996 std::vector<const MergeInfo*> matching_mi;
2997 for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
2998 if (a->type_string() == mi->op1 || a->type_string() == mi->op2) {
2999 matching_mi.push_back(&*mi);
3000 }
3001 }
3002
3003 for (const MergeInfo* mi : matching_mi) {
3004 // Get the operand with which 'a' can be merged.
3005 Node* b = nullptr;
3006 if ((b = mi->get_node_to_be_merged(a)) == nullptr) {
3007 continue;
3008 }
3009
3010 // Get the control edges and input of node
3011 const int N_in = a->num_inputs();
3012 gtl::InlinedVector<Node*, 4> a_control_edges;
3013 gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
3014 FillInputs(a, &a_control_edges, &a_in);
3015
3016 const int B_in = b->num_inputs();
3017 gtl::InlinedVector<Node*, 4> b_control_edges;
3018 gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in);
3019 FillInputs(b, &b_control_edges, &b_in);
3020
3021 // Shouldn't merge if a and b have different control edges.
3022 if (a_control_edges != b_control_edges) {
3023 continue;
3024 } else {
3025 // We found a match.
3026 return b;
3027 }
3028 }
3029
3030 return nullptr;
3031}
3032
3033Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
3034 Node* m, Node* n) {
3035 CHECK_EQ(((m->type_string() == csinfo_.bias_add &&
3036 n->type_string() == csinfo_.conv2d)) ||
3037 ((n->type_string() == csinfo_.bias_add &&
3038 m->type_string() == csinfo_.conv2d)),
3039 true);
3040
  // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd,
  // BiasAdd is the successor node and Conv2D the predecessor node.
3043 Node* pred = m->type_string() == csinfo_.bias_add ? n : m;
3044 Node* succ = m->type_string() == csinfo_.bias_add ? m : n;
3045
3046 // 1. Get all attributes from input nodes.
3047 DataType T_pred, T_succ;
3048 string padding;
3049 std::vector<int32> strides;
3050 std::vector<int32> dilations;
3051 string data_format_pred, data_format_succ;
3052 bool use_cudnn_on_gpu;
3053 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3054 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3055 TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
3056 TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
3057 TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations", &dilations));
3058 TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
3059 TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
3060 TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
  // Ensure that the data formats of succ and pred match. We expect them to
  // match, so this could be an assert; however, an assert may be too strict,
  // so we treat it as a check instead. If the check fails, we simply skip the
  // node merge. The same check is applied to the devices.
3066 if (data_format_pred != data_format_succ || T_pred != T_succ ||
3067 pred->assigned_device_name() != succ->assigned_device_name() ||
3068 pred->def().device() != succ->def().device()) {
3069 return Status(error::Code::INVALID_ARGUMENT,
3070 "data_format or T attribute or devices of Conv2D and "
3071 "BiasAdd do not match. Will skip node merge optimization");
3072 }
3073
3074 const int succ_num = succ->num_inputs();
3075 gtl::InlinedVector<Node*, 4> succ_control_edges;
3076 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3077 FillInputs(succ, &succ_control_edges, &succ_in);
3078
3079 const int pred_num = pred->num_inputs();
3080 gtl::InlinedVector<Node*, 4> pred_control_edges;
3081 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3082 FillInputs(pred, &pred_control_edges, &pred_in);
3083
  // We need to ensure that the output of Conv2D feeds only BiasAdd (no other
  // operator expects the output of Conv2D). If this is not the case, we
  // cannot merge Conv2D with BiasAdd.
3087 const int kFirstOutputSlot = 0;
3088 for (const Edge* e : pred->out_edges()) {
3089 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3090 return Status(error::Code::INVALID_ARGUMENT,
3091 "Conv2D does not feed to BiasAdd, or "
3092 "it feeds BiasAdd but has multiple outputs. "
3093 "Will skip node merge optimization");
3094 }
3095 }
3096
3097 // 2. Get inputs from both the nodes.
3098 // Find the 2 inputs from the conv and the bias from the add Bias.
3099 // Get operand 0, 1 of conv2D.
3100 CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs.
3101 // Get operand 1 of add_bias
3102 // BiasAdd must have 2 inputs: Conv, bias
3103 CHECK_EQ(succ->in_edges().size(), 2);
3104
3105 // We will use the node name of BiasAdd as the name of new node
3106 // Build new node. We use same name as original node, but change the op
3107 // name.
3108 NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias);
3109 nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D
3110 // pred_in[1] will be 2nd Tensorflow tensor for Conv2D.
3111 nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D
3112 // In1 of BiasAdd is same as output of Conv2D.
3113 nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
3114
3115 // Copy attributes from Conv2D to Conv2DWithBias.
3116 CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb);
3117
3118 // Copy the device assigned to old node to new node.
3119 nb.Device(succ->def().device());
3120
3121 // Create node.
3122 Node* new_node;
3123 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3124
  // In the remainder of this function, an unordered set is used to make sure
  // no duplicate edges are added to the new node. Therefore, we can pass
  // allow_duplicates = true to AddControlEdge to skip the O(#edges) duplicate
  // check inside that routine.
3129
3130 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3131 // node are already copied in BuildNode. We handle control edges now.
3132 std::unordered_set<Node*> unique_node;
3133 for (const Edge* e : pred->in_edges()) {
3134 if (e->IsControlEdge()) {
3135 auto result = unique_node.insert(e->src());
3136 if (result.second) {
3137 (*g)->AddControlEdge(e->src(), new_node, true);
3138 }
3139 }
3140 }
3141 unique_node.clear();
3142
3143 for (const Edge* e : succ->in_edges()) {
3144 if (e->IsControlEdge()) {
3145 auto result = unique_node.insert(e->src());
3146 if (result.second) {
3147 (*g)->AddControlEdge(e->src(), new_node, true);
3148 }
3149 }
3150 }
3151 unique_node.clear();
3152
3153 // Incoming edges are fixed, we will fix the outgoing edges now.
3154 // First, we will fix outgoing control edges from 'pred' node.
3155 for (const Edge* e : pred->out_edges()) {
3156 if (e->IsControlEdge()) {
3157 auto result = unique_node.insert(e->dst());
3158 if (result.second) {
3159 (*g)->AddControlEdge(new_node, e->dst(), true);
3160 }
3161 }
3162 }
3163 unique_node.clear();
3164
3165 // Second, we will fix outgoing control and data edges from 'succ' node.
3166 for (const Edge* e : succ->out_edges()) {
3167 if (e->IsControlEdge()) {
3168 auto result = unique_node.insert(e->dst());
3169 if (result.second) {
3170 (*g)->AddControlEdge(new_node, e->dst(), true);
3171 }
3172 } else {
3173 // BiasAdd has only 1 output (at slot 0) and merged node also has only 1
3174 // output (at slot 0).
3175 const int kConv2DWithBiasOutputSlot = 0;
3176 auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot,
3177 e->dst(), e->dst_input());
3178 DCHECK(new_edge);
3179 }
3180 }
3181
3182 // Copy device assigned to old node to new node.
3183 // It's ok to use pred or succ as we have enforced a check that
3184 // both have same device assigned.
3185 new_node->set_assigned_device_name(pred->assigned_device_name());
3186
3187 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3188 << ", and node: " << succ->DebugString()
3189 << ", into node:" << new_node->DebugString();
3190
3191 (*g)->RemoveNode(succ);
3192 (*g)->RemoveNode(pred);
3193
3194 return OkStatus();
3195}
3196
3197Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g,
3198 Node* m, Node* n) {
3199 DCHECK((m->type_string() == csinfo_.pad &&
3200 (n->type_string() == csinfo_.conv2d ||
3201 n->type_string() == csinfo_.fused_conv2d)) ||
3202 (n->type_string() == csinfo_.pad &&
3203 (m->type_string() == csinfo_.conv2d ||
3204 m->type_string() == csinfo_.fused_conv2d)));
3205
3206 bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d ||
3207 m->type_string() == csinfo_.fused_conv2d;
3208 // Conv2D is successor node, and Pad predecessor node.
3209 Node* pred = m->type_string() == csinfo_.pad ? m : n;
3210 Node* succ = m->type_string() == csinfo_.pad ? n : m;
3211
3212 // 1. Get all attributes from input nodes.
3213 DataType T_pred, T_succ;
3214 string padding;
3215 std::vector<int32> strides;
3216 std::vector<int32> dilations;
3217 string data_format_pred, data_format_succ;
3218
3219 TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
3220 TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
3221 TF_CHECK_OK(GetNodeAttr(succ->def(), "padding", &padding));
3222 TF_CHECK_OK(GetNodeAttr(succ->def(), "strides", &strides));
3223 TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations", &dilations));
3224 // Check if the devices of both succ and pred are the same.
3225 // Assert is not used because it can be too strict.
3226 // Don't need to check for data formats because it is not available in Pad.
3227 if (T_pred != T_succ ||
3228 pred->assigned_device_name() != succ->assigned_device_name() ||
3229 pred->def().device() != succ->def().device()) {
3230 return Status(error::Code::INVALID_ARGUMENT,
3231 "T attribute or devices of Conv2D and "
3232 "Pad do not match. Will skip node merge optimization");
3233 }
3234
3235 const int succ_num = succ->num_inputs();
3236 gtl::InlinedVector<Node*, 4> succ_control_edges;
3237 gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
3238 FillInputs(succ, &succ_control_edges, &succ_in);
3239
3240 const int pred_num = pred->num_inputs();
3241 gtl::InlinedVector<Node*, 4> pred_control_edges;
3242 gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
3243 FillInputs(pred, &pred_control_edges, &pred_in);
3244
  // We need to ensure that the output of Pad feeds only Conv2D (no other
  // operator expects the output of Pad). If this is not the case, we cannot
  // merge Pad with Conv2D.
3248 const int kFirstOutputSlot = 0;
3249 for (const Edge* e : pred->out_edges()) {
3250 if (e->src_output() == kFirstOutputSlot && e->dst() != succ) {
3251 return Status(error::Code::INVALID_ARGUMENT,
3252 "Pad does not feed to Conv2D, or "
3253 "it feeds Conv2D but has multiple outputs. "
3254 "Will skip node merge optimization");
3255 }
3256 }
3257
3258 // 2. Get inputs from both the nodes.
3259
3260 // Pad must have 2 data inputs: "input" and paddings.
3261 int PadDataInputEdges = 0;
3262 for (const Edge* e : pred->in_edges()) {
3263 if (!e->IsControlEdge()) {
3264 PadDataInputEdges++;
3265 }
3266 }
3267 DCHECK_EQ(PadDataInputEdges, 2);
3268
3269 // Conv2D must have 2 data inputs: Pad output and Filter
3270 // FusedConv2D have 3 data inputs: Pad output, Filter and Args;
3271 int ConvDataInputEdges = 0;
3272 for (const Edge* e : succ->in_edges()) {
3273 if (!e->IsControlEdge()) {
3274 ConvDataInputEdges++;
3275 }
3276 }
3277
3278 DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2);
3279
3280 // We will use the node name of Conv2D as the name of new node
3281 // Build new node. We use same name as original node, but change the op
3282 // name.
3283
3284 NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d
3285 : csinfo_.pad_with_conv2d);
  nb.Input(pred_in[0].first, pred_in[0].second);  // In1 (input data) of Pad
  // In1 of Conv2D is the same as the output of Pad, so only In2 (filter) of
  // Conv2D is added here; pred_in[1] (the paddings input of Pad) is added
  // further below.
  nb.Input(succ_in[1].first, succ_in[1].second);  // In2 (filter) of Conv2D
3291
3292 if (is_fused_conv2d) {
3293 // FusedConv2D has one additional input, args
3294 std::vector<NodeBuilder::NodeOut> args;
3295 int num_args = 1;
3296 TF_CHECK_OK(GetNodeAttr(succ->def(), "num_args", &num_args));
3297 for (int i = 0; i < num_args; i++) {
3298 args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second);
3299 }
3300 nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{
3301 args}); // In3 (args) of FusedConv2D
3302 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3303 // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D.
3304 CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ),
3305 const_cast<const Node*>(pred), &nb);
3306 } else {
3307 nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad
3308 // Copy attributes from Pad and conv2D to PadWithConv2D.
3309 CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ),
3310 const_cast<const Node*>(pred), &nb);
3311 }
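  // At this point the data inputs of the new node are, in order:
  //   PadWithConv2D case:      {Pad's input, Conv2D's filter, Pad's paddings}
  //   PadWithFusedConv2D case: {Pad's input, filter, args..., Pad's paddings}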
3312
3313 // Copy the device assigned to old node to new node.
3314 nb.Device(succ->def().device());
3315
3316 // Create node.
3317 Node* new_node;
3318 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3319 // No need to check if new_node is null because it will be null only when
3320 // Finalize fails.
3321
3322 // Incoming data edges from 'pred' node and 'succ' node to new 'new_node'
3323 // node are already copied in BuildNode.
3324 // We handle control edges now.
3325 for (const Edge* e : pred->in_edges()) {
3326 if (e->IsControlEdge()) {
3327 // Don't allow duplicate edge
3328 (*g)->AddControlEdge(e->src(), new_node, false);
3329 }
3330 }
3331 for (const Edge* e : succ->in_edges()) {
3332 if (e->IsControlEdge()) {
3333 // Don't allow duplicate edge
3334 (*g)->AddControlEdge(e->src(), new_node, false);
3335 }
3336 }
3337
3338 // Incoming edges are fixed, we will fix the outgoing edges now.
3339 // First, we will fix outgoing control edges from 'pred' node.
3340 for (const Edge* e : pred->out_edges()) {
3341 if (e->IsControlEdge()) {
3342 // Don't allow duplicate edge
3343 (*g)->AddControlEdge(new_node, e->dst(), false);
3344 }
3345 }
3346
3347 // Second, we will fix outgoing control and data edges from 'succ' node.
3348 for (const Edge* e : succ->out_edges()) {
3349 if (e->IsControlEdge()) {
      // Attempting to add a duplicate control edge is harmless here:
      // AddControlEdge with allow_duplicates = false simply returns nullptr
      // instead of adding the duplicate edge.
3352 (*g)->AddControlEdge(new_node, e->dst(), false);
3353 } else {
3354 // Conv2D has only 1 output (at slot 0) and merged node also has only 1
3355 // output (at slot 0).
3356 const int kPadWithConv2DOutputSlot = 0;
3357 (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(),
3358 e->dst_input());
3359 }
3360 }
3361
3362 // Copy device assigned to old node to new node.
3363 // It's ok to use pred or succ as we have enforced a check that
3364 // both have same device assigned.
3365 new_node->set_assigned_device_name(pred->assigned_device_name());
3366
3367 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
3368 << ", and node: " << succ->DebugString()
3369 << ", into node:" << new_node->DebugString();
3370
3371 (*g)->RemoveNode(succ);
3372 (*g)->RemoveNode(pred);
3373
3374 return OkStatus();
3375}
3376
3377Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
3378 std::unique_ptr<Graph>* g, Node* m, Node* n) {
3379 CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad &&
3380 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3381 ((n->type_string() == csinfo_.bias_add_grad &&
3382 m->type_string() == csinfo_.conv2d_grad_filter)),
3383 true);
3384
3385 // If 'm' is BiasAddGrad, then 'n' is BackpropFilter.
3386 Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n;
3387 Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m;
3388
3389 // Sanity check for attributes from input nodes.
3390 DataType T_b, T_f;
3391 string data_format_b, data_format_f;
3392 TF_CHECK_OK(GetNodeAttr(badd->def(), "T", &T_b));
3393 TF_CHECK_OK(GetNodeAttr(fltr->def(), "T", &T_f));
3394 TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format", &data_format_b));
3395 TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format", &data_format_f));
3396 if (data_format_b != data_format_f || T_b != T_f ||
3397 badd->assigned_device_name() != fltr->assigned_device_name() ||
3398 badd->def().device() != fltr->def().device()) {
3399 return Status(error::Code::INVALID_ARGUMENT,
3400 "data_format or T attribute or devices of "
3401 "Conv2DBackpropFilter and BiasAddGrad do not match. "
3402 "Will skip node merge optimization");
3403 }
3404
3405 // We will use the node name of Conv2DBackpropFilter as the name of new node.
3406 // This is because BackpropFilterWithBias is going to emit bias output also.
3407 NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias);
3408 // Since Conv2DBackpropFilterWithBias has same number of inputs as
3409 // Conv2DBackpropFilter, we can just copy input edges directly. We don't need
3410 // to copy any data input of BiasAddGrad because that input also goes to
3411 // Conv2DBackpropFilter.
3412 const int fltr_ins = fltr->num_inputs();
3413 gtl::InlinedVector<Node*, 4> fltr_control_edges;
3414 gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins);
3415 FillInputs(fltr, &fltr_control_edges, &fltr_in_edges);
3416 for (int idx = 0; idx < fltr_ins; idx++) {
3417 nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second);
3418 }
3419
3420 // Copy attributes from Conv2DBackpropFilter.
3421 CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
3422
3423 // Copy the device assigned to old node to new node.
3424 nb.Device(fltr->def().device());
3425
3426 // Create node.
3427 Node* new_node;
3428 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3429
  // In the remainder of this function, an unordered set is used to make sure
  // no duplicate edges are added to the new node. Therefore, we can pass
  // allow_duplicates = true to AddControlEdge to skip the O(#edges) duplicate
  // check inside that routine.
3434
3435 // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to
3436 // new 'new_node' node are already copied in BuildNode. We handle control
3437 // edges now.
3438 std::unordered_set<Node*> unique_node;
3439 for (const Edge* e : badd->in_edges()) {
3440 if (e->IsControlEdge()) {
3441 auto result = unique_node.insert(e->src());
3442 if (result.second) {
3443 (*g)->AddControlEdge(e->src(), new_node, true);
3444 }
3445 }
3446 }
3447 unique_node.clear();
3448 for (const Edge* e : fltr->in_edges()) {
3449 if (e->IsControlEdge()) {
3450 auto result = unique_node.insert(e->src());
3451 if (result.second) {
3452 (*g)->AddControlEdge(e->src(), new_node, true);
3453 }
3454 }
3455 }
3456 unique_node.clear();
3457
3458 // Incoming edges are fixed, we will fix the outgoing edges now.
  // First, we will fix outgoing control and data edges from the 'badd' node.
3460 // Conv2DBackpropFilter has 1 output -- filter_grad.
3461 // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and
3462 // bias_grad. But filter_grad is at same slot number (0) in both the
3463 // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while
3464 // it is at slot number 0 in BiasAddGrad.
3465 const int kMergedNodeFilterGradOutputIdx = 0;
3466 const int kMergedNodeBiasGradOutputIdx = 1;
3467
3468 for (const Edge* e : badd->out_edges()) {
3469 if (e->IsControlEdge()) {
3470 auto result = unique_node.insert(e->dst());
3471 if (result.second) {
3472 (*g)->AddControlEdge(new_node, e->dst(), true);
3473 }
3474 } else {
3475 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx,
3476 e->dst(), e->dst_input());
3477 DCHECK(new_edge);
3478 }
3479 }
3480 unique_node.clear();
3481
3482 // Second, we will fix outgoing control and data edges from 'fltr' node.
3483 for (const Edge* e : fltr->out_edges()) {
3484 if (e->IsControlEdge()) {
3485 auto result = unique_node.insert(e->dst());
3486 if (result.second) {
3487 (*g)->AddControlEdge(new_node, e->dst(), true);
3488 }
3489 } else {
3490 auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx,
3491 e->dst(), e->dst_input());
3492 DCHECK(new_edge);
3493 }
3494 }
3495
3496 // Copy device assigned to old node to new node.
3497 // It's ok to use badd or fltr as we have enforced a check that
3498 // both have same device assigned.
3499 new_node->set_assigned_device_name(badd->assigned_device_name());
3500
3501 VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString()
3502 << ", and node: " << fltr->DebugString()
3503 << ", into node:" << new_node->DebugString();
3504
3505 (*g)->RemoveNode(badd);
3506 (*g)->RemoveNode(fltr);
3507
3508 return OkStatus();
3509}
3510
3511Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m,
3512 Node* n) {
3513 DCHECK(m);
3514 DCHECK(n);
3515
3516 if (((m->type_string() == csinfo_.bias_add &&
3517 n->type_string() == csinfo_.conv2d)) ||
3518 ((n->type_string() == csinfo_.bias_add &&
3519 m->type_string() == csinfo_.conv2d))) {
3520 return this->MergeConv2DWithBiasAdd(g, m, n);
3521 }
3522 if ((m->type_string() == csinfo_.pad &&
3523 (n->type_string() == csinfo_.conv2d ||
3524 (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) ||
3525 (n->type_string() == csinfo_.pad &&
3526 (m->type_string() == csinfo_.conv2d ||
3527 (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) {
3528 return this->MergePadWithConv2D(g, m, n);
3529 }
3530
3531 if (((m->type_string() == csinfo_.bias_add_grad &&
3532 n->type_string() == csinfo_.conv2d_grad_filter)) ||
3533 ((n->type_string() == csinfo_.bias_add_grad &&
3534 m->type_string() == csinfo_.conv2d_grad_filter))) {
3535 return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n);
3536 }
3537
3538 return Status(error::Code::UNIMPLEMENTED,
3539 "Unimplemented case for node merge optimization.");
3540}
3541
3542//////////////////////////////////////////////////////////////////////////
3543// Helper functions for node rewrite
3544//////////////////////////////////////////////////////////////////////////
3545
3546Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation(
3547 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3548 const RewriteInfo* ri) {
3549 // Get all data inputs.
3550 int num_data_inputs = orig_node->in_edges().size();
  // Exclude control edges from the data-input count.
3552 for (const Edge* e : orig_node->in_edges()) {
3553 if (e->IsControlEdge()) {
3554 num_data_inputs--;
3555 }
3556 }
3557
3558 gtl::InlinedVector<Node*, 4> control_edges;
3559 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3560 FillInputs(orig_node, &control_edges, &inputs);
3561
3562 // Build new node. We use same name as original node, but change the op name.
3563 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3564 // Copy user-specified device assigned to original node to new node.
3565 nb.Device(orig_node->def().device());
3566 // Set up new inputs to the rewritten node.
3567 Status s = SetUpInputs(g, inputs, &nb, orig_node);
3568 if (s != OkStatus()) {
3569 return s;
3570 }
3571
3572 const bool kPartialCopyAttrs = false;
3573 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs);
3574
3575 // Set the Mkl layer label for this op.
3576 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3577 DataTypeIsQuantized(orig_node->output_type(0))) {
3578 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3579 } else {
3580 nb.Attr("_kernel", mkl_op_registry::kMklLayoutDependentOpLabel);
3581 }
3582 // Finalize graph and get new node.
3583 s = nb.Finalize(&**g, new_node);
3584 if (s != OkStatus()) {
3585 return s;
3586 }
3587
  // In the remainder of this function, an unordered set is used to make sure
  // no duplicate edges are added to the new node. Therefore, we can pass
  // allow_duplicates = true to AddControlEdge to skip the O(#edges) duplicate
  // check inside that routine.
3592
3593 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3594 // already copied in BuildNode. We need to handle control edges now.
3595 std::unordered_set<Node*> unique_node;
3596 for (const Edge* e : orig_node->in_edges()) {
3597 if (e->IsControlEdge()) {
3598 auto result = unique_node.insert(e->src());
3599 if (result.second) {
3600 (*g)->AddControlEdge(e->src(), *new_node, true);
3601 }
3602 }
3603 }
3604 unique_node.clear();
3605
  // Copy outgoing edges from 'orig_node' to 'new_node', since the outputs
  // follow the same ordering between Tensorflow tensors and Mkl tensors. We
  // need to connect the Tensorflow tensors appropriately: the nth output of
  // the original node becomes the 2*nth output of the Mkl node under the
  // interleaved ordering of tensors, and stays the nth output under the
  // contiguous ordering. GetTensorDataIndex provides this mapping.
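  // For example, under the contiguous ordering, if the original node has 2
  // outputs then the Mkl node has 4: data tensors at output slots 0 and 1 and
  // metadata tensors at slots 2 and 3, so original slot n maps to slot n.
  // Under the interleaved ordering, data and metadata alternate and slot n
  // maps to slot 2*n instead.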
3613 for (const Edge* e : orig_node->out_edges()) {
3614 if (e->IsControlEdge()) {
3615 auto result = unique_node.insert(e->dst());
3616 if (result.second) {
3617 (*g)->AddControlEdge(*new_node, e->dst(), true);
3618 }
3619 } else {
3620 auto new_edge = (*g)->AddEdge(
3621 *new_node,
3622 GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
3623 e->dst(), e->dst_input());
3624 DCHECK(new_edge);
3625 }
3626 }
3627 return OkStatus();
3628}
3629
3630Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
3631 std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node,
3632 const RewriteInfo* ri) {
3633 // Get all data inputs.
3634 int num_data_inputs = orig_node->in_edges().size();
  // Exclude control edges from the data-input count.
3636 for (const Edge* e : orig_node->in_edges()) {
3637 if (e->IsControlEdge()) {
3638 num_data_inputs--;
3639 }
3640 }
3641 gtl::InlinedVector<Node*, 4> control_edges;
3642 gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs);
3643 FillInputs(orig_node, &control_edges, &inputs);
3644
3645 // Build new node. We use same name as original node, but change the op name.
3646 NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str());
3647 // Copy user-specified device assigned to original node to new node.
3648 nb.Device(orig_node->def().device());
3649
3650 Status s = CopyInputs(orig_node, inputs, &nb);
3651 if (s != OkStatus()) {
3652 return s;
3653 }
3654
3655 std::vector<NodeBuilder::NodeOut> workspace_tensors;
3656 bool are_workspace_tensors_available = false;
3657 if (IsWorkspaceCheckNeeded(orig_node)) {
3658 AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
3659 &are_workspace_tensors_available);
3660 if (are_workspace_tensors_available) {
3661 DCHECK_EQ(workspace_tensors.size(), 1);
3662 nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
3663 }
3664 }
3665
3666 if (!NativeFormatEnabled()) {
3667 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
3668 } else {
3669 ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false);
3670 }
3671
3672 if (DataTypeIsQuantized(orig_node->input_type(0)) ||
3673 DataTypeIsQuantized(orig_node->output_type(0))) {
3674 nb.Attr("_kernel", mkl_op_registry::kMklQuantizedOpLabel);
3675 } else {
3676 nb.Attr("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
3677 }
3678
3679 // Finalize graph and get new node.
3680 s = nb.Finalize(&**g, new_node);
3681 if (s != OkStatus()) {
3682 return s;
3683 }
3684
  // In the remainder of this function, an unordered set is used to make sure
  // no duplicate edges are added to the new node. Therefore, we can pass
  // allow_duplicates = true to AddControlEdge to skip the O(#edges) duplicate
  // check inside that routine.
3689
3690 // Incoming data edges from 'orig_node' node to new 'new_node' node are
3691 // already copied in BuildNode. We need to handle control edges now.
3692 std::unordered_set<Node*> unique_node;
3693 for (const Edge* e : orig_node->in_edges()) {
3694 if (e->IsControlEdge()) {
3695 auto result = unique_node.insert(e->src());
3696 if (result.second) {
3697 (*g)->AddControlEdge(e->src(), *new_node, true);
3698 }
3699 }
3700 }
3701 unique_node.clear();
3702
3703 // Transfer outgoing edges from 'orig_node' node to new 'new_node' node.
3704 for (const Edge* e : orig_node->out_edges()) {
3705 if (e->IsControlEdge()) {
3706 auto result = unique_node.insert(e->dst());
3707 if (result.second) {
3708 (*g)->AddControlEdge(*new_node, e->dst(), true);
3709 }
3710 } else {
3711 auto result =
3712 (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input());
3713 DCHECK(result != nullptr);
3714 }
3715 }
3716
3717 return OkStatus();
3718}
3719
3720Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
3721 Node* orig_node,
3722 const RewriteInfo* ri) {
3723 DCHECK(ri != nullptr);
3724 DCHECK(orig_node != nullptr);
3725
3726 VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString();
3727
3728 Status ret_status = OkStatus();
3729 Node* new_node = nullptr;
3730 if (ri->rewrite_cause == kRewriteForLayoutPropagation) {
3731 ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri);
3732 } else if (ri->rewrite_cause == kRewriteForOpNameChange) {
3733 ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri);
3734 } else {
3735 ret_status = Status(error::Code::INVALID_ARGUMENT,
                        "Unsupported rewrite cause found. "
                        "RewriteNode will fail.");
3738 }
3739 TF_CHECK_OK(ret_status);
3740
  // Copy the runtime-assigned device from the original node to the new node.
3742 new_node->set_assigned_device_name(orig_node->assigned_device_name());
3743
3744 // Delete original node and mark new node as rewritten.
3745 (*g)->RemoveNode(orig_node);
3746
3747 VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
3748 return ret_status;
3749}
3750
3751// TODO(mdfaijul): Is there any other elegant way to check for quantized ops
3752// having attributes other than "T"?
3753// Current implementation reflects only QuantizedConv2D and its fused Ops.
3754const MklLayoutRewritePass::RewriteInfo*
3755MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const {
3756 DataType T1, T2;
3757 DataType Tinput, Tfilter;
3758 bool type_attrs_present = false;
3759
3760 if (TryGetNodeAttr(n->def(), "Tinput", &Tinput) &&
3761 TryGetNodeAttr(n->def(), "Tfilter", &Tfilter) &&
3762 mkl_op_registry::IsMklQuantizedOp(
3763 mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) {
3764 type_attrs_present = true;
3765 } else if (TryGetNodeAttr(n->def(), "T1", &T1) &&
3766 TryGetNodeAttr(n->def(), "T2", &T2) &&
3767 mkl_op_registry::IsMklQuantizedOp(
3768 mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) {
3769 type_attrs_present = true;
3770 }
3771
3772 if (type_attrs_present) {
3773 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3774 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3775 return &*ri;
3776 }
3777 }
3778 }
3779
3780 return nullptr;
3781}
3782
3783const MklLayoutRewritePass::RewriteInfo*
3784MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
3785 DCHECK(n);
3786
  // Quantized ops may have type attributes other than "T", so that check is
  // decoupled into a separate function, CheckForQuantizedNodeRewrite(const
  // Node*).
3789 const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n);
3790 if (ri != nullptr) return ri;
3791
3792 // First check if node along with its type is supported by MKL layer.
3793 // We do not want to rewrite an op into Mkl op if types are not supported.
3794 // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to
3795 // MklRelu if type is INT32.
3796 DataType T;
3797 if (!TryGetNodeAttr(n->def(), "T", &T)) {
3798 return nullptr;
3799 }
3800
3801 // We make an exception for Conv2DGrad and MaxPool related ops as
3802 // the corresponding MKL ops currently do not support the case
3803 // of padding == EXPLICIT yet.
3804 // TODO(intel): support `EXPLICIT` padding for ConvGrad
3805 if (n->type_string() == csinfo_.conv2d_grad_input ||
3806 n->type_string() == csinfo_.conv2d_grad_filter ||
3807 n->type_string() == csinfo_.depthwise_conv2d_grad_filter ||
3808 n->type_string() == csinfo_.depthwise_conv2d_grad_input ||
3809 n->type_string() == csinfo_.conv3d_grad_filter ||
      n->type_string() == csinfo_.conv3d_grad_input ||
3811 n->type_string() == csinfo_.max_pool ||
3812 n->type_string() == csinfo_.max_pool_grad ||
3813 n->type_string() == csinfo_.max_pool3d ||
3814 n->type_string() == csinfo_.max_pool3d_grad) {
3815 string padding;
3816 TF_CHECK_OK(GetNodeAttr(n->def(), "padding", &padding));
3817 if (padding == "EXPLICIT") return nullptr;
3818 }
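  // For example, a Conv2DBackpropFilter node with padding == "EXPLICIT" is
  // left unrewritten here and keeps its default (non-MKL) kernel.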
3819
3820 // We make an exception for __MklDummyConv2DWithBias,
3821 // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their
3822 // names do not match Mkl node names.
3823 if (n->type_string() != csinfo_.conv2d_with_bias &&
3824 n->type_string() != csinfo_.pad_with_conv2d &&
3825 n->type_string() != csinfo_.pad_with_fused_conv2d &&
3826 n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
3827 n->type_string() != csinfo_.fused_batch_norm_ex &&
3828 n->type_string() != csinfo_.fused_conv2d &&
3829 n->type_string() != csinfo_.fused_depthwise_conv2d &&
3830 n->type_string() != csinfo_.fused_matmul &&
3831 n->type_string() != csinfo_.fused_conv3d &&
3832 !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()),
3833 T)) {
3834 return nullptr;
3835 }
3836
3837 // We now check if rewrite rule applies for this op. If rewrite rule passes
3838 // for this op, then we rewrite it to Mkl op.
3839 // Find matching RewriteInfo and then check that rewrite rule applies.
3840 for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
3841 if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) {
3842 return &*ri;
3843 }
3844 }
3845
3846 // Else return not found.
3847 return nullptr;
3848}
3849
3850//////////////////////////////////////////////////////////////////////////
3851// Helper functions for node fusion
3852//////////////////////////////////////////////////////////////////////////
3853Status MklLayoutRewritePass::FuseTransposeMklOpTranspose(
3854 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3855 std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
3856 string data_format) {
3857 Node* transpose_to_nhwc = nodes[0];
3858 Node* mklop = nodes[1];
3859 Node* transpose_to_nchw = nodes[2];
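  // The matched chain is Transpose -> MklOp -> Transpose: the first Transpose
  // puts the data into the layout the MklOp was created with (e.g. NHWC), and
  // the second converts it back (e.g. to NCHW). The fusion below drops both
  // Transposes and rebuilds the MklOp with `data_format` set directly, wiring
  // the pre-transpose tensor in as input "x".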
3860
3861 const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs();
3862 gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges;
3863 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in(
3864 transpose_nhwc_num_inputs);
3865 FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges,
3866 &transpose_nhwc_in);
3867
3868 const int mklop_num_inputs = mklop->num_inputs();
3869 gtl::InlinedVector<Node*, 4> mklop_control_edges;
3870 gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs);
3871 FillInputs(mklop, &mklop_control_edges, &mklop_in);
3872
3873 const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs();
3874 gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges;
3875 gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in(
3876 transpose_nchw_num_inputs);
3877 FillInputs(transpose_to_nchw, &transpose_nchw_control_edges,
3878 &transpose_nchw_in);
3879
3880 // We use same name as original node, but change the op
3881 // type.
3882 NodeBuilder nb(mklop->name(), mklop->type_string());
3883
  // Wire up the inputs of the new node: for the input that comes from the
  // first Transpose, use that Transpose's own input ("x") instead; all other
  // inputs are passed through unchanged.
3885 for (int i = 0; i < mklop_num_inputs; i++) {
3886 if (mklop_in[i].first == transpose_to_nhwc) {
3887 // Fill "x":
3888 nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second);
3889 } else {
3890 // Fill inputs other than "x":
3891 nb.Input(mklop_in[i].first, mklop_in[i].second);
3892 }
3893 }
3894
3895 copy_attrs(const_cast<const Node*>(mklop), &nb, true);
3896 nb.Attr("data_format", data_format);
3897
3898 // Copy the device assigned to old node to new node.
3899 nb.Device(mklop->def().device());
3900
3901 // Create node.
3902 Node* new_node;
3903 TF_CHECK_OK(nb.Finalize(&**g, &new_node));
3904 // No need to check if new_node is null because it will be null only when
3905 // Finalize fails.
3906
3907 // Fill outputs.
3908 for (const Edge* e : transpose_to_nchw->out_edges()) {
3909 if (!e->IsControlEdge()) {
3910 const int kTransposeWithMklOpOutputSlot = 0;
3911 auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot,
3912 e->dst(), e->dst_input());
3913 DCHECK(new_edge);
3914 }
3915 }
3916
3917 // Copy device assigned to old node to new node.
3918 new_node->set_assigned_device_name(mklop->assigned_device_name());
3919
3920 // Copy requested_device and assigned_device_name_index
3921 new_node->set_requested_device(mklop->requested_device());
3922 new_node->set_assigned_device_name_index(mklop->assigned_device_name_index());
3923
3924 (*g)->RemoveNode(transpose_to_nhwc);
3925 (*g)->RemoveNode(mklop);
3926 (*g)->RemoveNode(transpose_to_nchw);
3927
3928 return OkStatus();
3929}
3930
3931Status MklLayoutRewritePass::FuseNode(
3932 std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
3933 const MklLayoutRewritePass::FusionInfo fi) {
3934 return fi.fuse_func(g, nodes, fi.copy_attrs);
3935}
3936
3937std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
3938MklLayoutRewritePass::CheckForNodeFusion(Node* a) const {
3939 // Stores matched nodes, in the same order as node_checkers.
3940 std::vector<Node*> nodes;
3941
3942 for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) {
3943 //
3944 // Make sure node "a" and its succeeding nodes (b, c ...), match the pattern
3945 // defined in fusion info (ops[0], ops[1], ...),
3946 // a.k.a. "a->b->c" matches "op1->op2->op3"
3947 //
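    // For example, a registered fusion such as transpose elimination matches
    // a chain like Transpose -> Conv2D -> Transpose, with one node_checker
    // validating each node in the chain.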
3948
3949 // Stores the first unvisited outgoing edge of each matched node in "nodes".
3950 std::stack<EdgeSet::const_iterator> current_neighbor_stack;
3951 nodes.clear();
3952
3953 auto node_checker = fi->node_checkers.begin();
3954 if (a != nullptr && (*node_checker)(a)) {
3955 nodes.push_back(a);
3956 current_neighbor_stack.push(a->out_edges().begin());
3957 ++node_checker;
3958 }
3959
3960 while (!nodes.empty()) {
3961 auto& current_neighbor_iter = current_neighbor_stack.top();
3962
3963 if (current_neighbor_iter != nodes.back()->out_edges().end()) {
3964 // Found an unvisited edge. Goes through the edge to get the neighbor.
3965 Node* neighbor_node = (*current_neighbor_iter)->dst();
3966 ++current_neighbor_stack.top(); // Retrieves the next unvisited edge.
3967
3968 if ((*node_checker)(neighbor_node)) {
3969 // Found a match. Stores the node and moves to the next checker.
3970 nodes.push_back(neighbor_node);
3971 current_neighbor_stack.push(neighbor_node->out_edges().begin());
3972 if (++node_checker == fi->node_checkers.end()) {
3973 return make_tuple(true, nodes, *fi);
3974 }
3975 }
3976 } else {
        // Remove the current node since none of its neighbors leads to a
        // further match.
3979 nodes.pop_back();
3980 current_neighbor_stack.pop();
3981 --node_checker;
3982 }
3983 }
3984 }
3985
3986 return make_tuple(false, std::vector<Node*>(), FusionInfo());
3987}
3988
3989///////////////////////////////////////////////////////////////////////////////
3990// Post-rewrite Mkl metadata fixup pass
3991///////////////////////////////////////////////////////////////////////////////
3992bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
3993 const Edge* e_data,
3994 const Edge* e_metadata) {
3995 if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
3996 return false;
3997 }
3998
3999 Node* n_data = e_data->src();
4000 int n_data_op_slot = e_data->src_output();
4001 int n_metadata_op_slot =
4002 GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs());
4003
  // If the source of the metadata edge is a constant node (producing a dummy
  // Mkl metadata tensor), then the edge needs to be fixed.
4006 if (IsConstant(e_metadata->src())) {
4007 Node* e_metadata_dst = e_metadata->dst();
4008 int e_metadata_in_slot = e_metadata->dst_input();
4009 auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst,
4010 e_metadata_in_slot);
4011 DCHECK(new_edge);
4012
4013 (*g)->RemoveEdge(e_metadata);
4014 return true;
4015 }
4016
4017 return false;
4018}
4019
4020bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
4021 Node* n) {
4022 bool result = false;
4023
4024 // If graph node is not Mkl node, then return.
4025 DataType T = DT_INVALID;
4026 if (!TryGetNodeAttr(n->def(), "T", &T) ||
4027 !mkl_op_registry::IsMklOp(n->type_string(), T, false)) {
4028 return result;
4029 }
4030
4031 // If it is Mkl node, then check if the input edges to this node that carry
4032 // Mkl metadata are linked up correctly with the source node.
4033
4034 // For Mkl nodes, we generate twice the number of input tensors (n for Mkl
4035 // data tensors + n for Mkl metadata tensors). We need to check for correct
4036 // connection of n metadata tensors only.
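  // For example, under the contiguous ordering an Mkl node with 2 data inputs
  // has 4 inputs in total: data tensors at input slots 0 and 1 and metadata
  // tensors at slots 2 and 3, i.e. GetTensorMetaDataIndex(idx, 4) yields
  // idx + 2.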
4037 int num_data_inputs = n->num_inputs() / 2;
4038 for (int idx = 0; idx < num_data_inputs; idx++) {
4039 // Get the edge connecting input slot with index (idx).
4040 const Edge* e = nullptr;
4041 TF_CHECK_OK(n->input_edge(idx, &e));
4042
4043 // If e is control edge, then skip.
4044 if (e->IsControlEdge()) {
4045 continue;
4046 }
4047
4048 // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl
4049 // node, then we don't need to do anything.
4050 Node* e_src = e->src();
4051 if (TryGetNodeAttr(e_src->def(), "T", &T) &&
4052 mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) {
4053 // Source node for edge 'e' is Mkl node.
4054 // Destination node and destination input slot of e is node 'n' and 'idx'
4055 // resp.
4056 CHECK_EQ(e->dst(), n);
4057 CHECK_EQ(e->dst_input(), idx);
4058
4059 // Let's get edge that carries Mkl metadata corresponding to Mkl data edge
4060 // 'e'. For that, let's first get the input slot of 'n' where the meta
4061 // edge will feed the value.
4062 int e_meta_in_slot =
4063 GetTensorMetaDataIndex(e->dst_input(), n->num_inputs());
4064 const Edge* e_meta = nullptr;
4065 TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));
4066
4067 // Let's check if we need to fix this meta edge.
4068 if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
4069 result = true;
4070 }
4071 }
4072 }
4073
4074 return result;
4075}
4076
4077///////////////////////////////////////////////////////////////////////////////
4078// Run function for the pass
4079///////////////////////////////////////////////////////////////////////////////
4080
4081bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
4082 bool result = false;
4083 DCHECK(g);
4084
4085 DumpGraph("Before running MklLayoutRewritePass", &**g);
4086
4087 std::vector<Node*> order;
4088 GetReversePostOrder(**g, &order); // This will give us topological sort.
4089 for (Node* n : order) {
4090 // If node is not an op or it cannot run on CPU device, then skip.
4091 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4092 continue;
4093 }
4094
4095 Node* m = nullptr;
4096 if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) {
      // CheckForNodeMerge found a node 'm' with which node 'n' can be merged.
4099 string n1_name = n->name();
4100 string n2_name = m->name();
4101
4102 VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
4103 << n2_name << " for merging";
4104
4105 if (MergeNode(g, n, m) == OkStatus()) {
4106 VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
4107 << n2_name;
4108 result = true;
4109 }
4110 }
4111 }
4112
4113 DumpGraph("After running MklLayoutRewritePass(NodeMerge)", &**g);
4114
4115 order.clear();
4116 GetReversePostOrder(**g, &order); // This will give us topological sort.
4117 for (Node* n : order) {
4118 // If node is not an op or it cannot run on CPU device, then skip.
4119 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4120 continue;
4121 }
4122
4123 auto check_result = CheckForNodeFusion(n);
4124 bool found_pattern = std::get<0>(check_result);
4125 std::vector<Node*> nodes = std::get<1>(check_result);
4126 const FusionInfo fi = std::get<2>(check_result);
4127
4128 // if "found_pattern" is true, we can do the fusion.
4129 if (found_pattern) {
4130 if (FuseNode(g, nodes, fi) == OkStatus()) {
4131 result = true;
4132 }
4133 }
4134 }
4135 DumpGraph("After running MklLayoutRewritePass(NodeFusion)", &**g);
4136
4137 order.clear();
4138 GetReversePostOrder(**g, &order); // This will give us topological sort.
4139 for (Node* n : order) {
4140 // If node is not an op or it cannot run on CPU device, then skip.
4141 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4142 continue;
4143 }
4144
4145 const RewriteInfo* ri = nullptr;
4146 // We will first search if node is to be rewritten.
4147 if ((ri = CheckForNodeRewrite(n)) != nullptr) {
4148 string node_name = n->name();
4149 string op_name = n->type_string();
4150
4151 VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name
4152 << " with op " << op_name << " for rewrite using"
4153 << " layout optimization.";
4154
4155 if (RewriteNode(g, n, ri) == OkStatus()) {
4156 VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
4157 << " with op " << op_name << " for Mkl layout optimization.";
4158 result = true;
4159 }
4160 }
4161 }
4162
4163 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);
4164
4165 if (!NativeFormatEnabled()) {
4166 order.clear();
4167 GetReversePostOrder(**g, &order); // This will give us topological sort.
4168 for (Node* n : order) {
4169 // If node is not an op or it cannot run on CPU device, then skip.
4170 if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
4171 continue;
4172 }
4173 if (FixMklMetaDataEdges(g, n)) {
4174 string node_name = n->name();
4175 string op_name = n->type_string();
4176
4177 VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
4178 << node_name << " with op " << op_name;
4179 result = true;
4180 }
4181 }
4182 DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
4183 &**g);
4184 }
4185
4186 return result;
4187}
4188
4189bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) {
4190 return MklLayoutRewritePass().RunPass(g);
4191}
4192
4193Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
4194 if (options.graph == nullptr && options.partition_graphs == nullptr) {
4195 return OkStatus();
4196 }
4197 if (!IsMKLEnabled()) {
4198 VLOG(2) << "TF-MKL: MKL is not enabled";
4199 return OkStatus();
4200 }
4201
  auto process_graph = [&](std::unique_ptr<Graph>* g) {
    // Run the pass on the graph in place. RunPass mutates the graph through
    // the pointer and does not take ownership of it.
    RunPass(g);
  };
4209
4210 if (kMklLayoutRewritePassGroup !=
4211 OptimizationPassRegistry::POST_PARTITIONING) {
4212 // For any pre-partitioning phase, a graph is stored in options.graph.
4213 process_graph(options.graph);
4214 } else {
4215 // For post partitioning phase, graphs are stored in
4216 // options.partition_graphs.
4217 for (auto& pg : *options.partition_graphs) {
4218 process_graph(&pg.second);
4219 }
4220 }
4221
4222 return OkStatus();
4223}
4224
4225} // namespace tensorflow
4226
4227#endif // INTEL_MKL
4228