1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // TODO(intel): Improve error handling in this file; instead of CHECK failing |
17 | // all over the place, we should log an error and execute the original graph. |
18 | #ifdef INTEL_MKL |
19 | |
20 | #include "tensorflow/core/common_runtime/mkl_layout_pass.h" |
21 | |
22 | #include <algorithm> |
23 | #include <functional> |
24 | #include <memory> |
25 | #include <queue> |
26 | #include <set> |
27 | #include <stack> |
28 | #include <tuple> |
29 | #include <unordered_set> |
30 | #include <utility> |
31 | #include <vector> |
32 | |
33 | #include "absl/base/call_once.h" |
34 | #include "tensorflow/core/common_runtime/function.h" |
35 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
36 | #include "tensorflow/core/framework/node_def_util.h" |
37 | #include "tensorflow/core/framework/tensor.pb.h" |
38 | #include "tensorflow/core/graph/algorithm.h" |
39 | #include "tensorflow/core/graph/graph.h" |
40 | #include "tensorflow/core/graph/mkl_graph_util.h" |
41 | #include "tensorflow/core/graph/node_builder.h" |
42 | #include "tensorflow/core/lib/core/status.h" |
43 | #include "tensorflow/core/lib/gtl/array_slice.h" |
44 | #include "tensorflow/core/lib/gtl/map_util.h" |
45 | #include "tensorflow/core/lib/hash/hash.h" |
46 | #include "tensorflow/core/platform/logging.h" |
47 | #include "tensorflow/core/util/tensor_format.h" |
48 | #include "tensorflow/core/util/util.h" |
49 | |
50 | namespace tensorflow { |
51 | |
52 | // This pass rewrites the graph to support the following scenarios: |
53 | // (A) Merging nodes in the graph |
54 | // (B) Rewriting a node in the graph to a new node |
55 | // Rewrite happens under the following scenario: |
56 | // - Propagating Mkl layout as an additional output tensor |
57 | // (we will loosely call a tensor that carries Mkl layout as Mkl tensor |
58 | // henceforth.) from every Mkl supported NN layer. |
59 | // |
60 | // Example of A : Merging nodes in the graph |
61 | // ----------------------------------------- |
62 | // Currently, we merge Conv2D+BiasAdd together. Consider Conv2D and BiasAdd as: |
63 | // |
64 | // O = Conv2D(A, B) |
65 | // P = BiasAdd(O, C) |
66 | // |
67 | // We merge them into Conv2DWithBias as: |
68 | // P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m) |
69 | // |
70 | // The meaning of A_m, B_m and C_m is explained in B.1. |
71 | // |
72 | // Merge rules: |
73 | // - The merge for Conv2D and BiasAdd happens when the output of Conv2D _only_ |
74 | // goes to BiasAdd. |
75 | // - Also, the intersection of attributes of both the nodes must have same |
76 | // values. |
77 | // - Both the nodes must have been assigned to same device (if any). |
78 | // |
79 | // Example of B.1 : Rewriting nodes to Mkl nodes |
80 | // --------------------------------------------- |
81 | // Consider a Relu node. Current definition of Relu node looks like: |
82 | // |
83 | // O = Relu(A) |
84 | // |
85 | // Relu has 1 input (A), and 1 output (O). |
86 | // |
87 | // This rewrite pass will generate a new graph node for Relu (new node is |
88 | // called MklRelu) as: |
89 | // |
90 | // O, O_m = MklRelu(A, A_m) |
91 | // |
92 | // MklRelu has 2 inputs (A and A_m) and 2 outputs (O and O_m). Here input A is |
93 | // same as input A of Relu; output O is same as output O of Relu. O_m is the |
94 | // additional output tensor that will be set by MklRelu, and it represents |
95 | // Mkl tensor corresponding to O -- in other words, O_m is some kind of |
96 | // metadata for O. A_m is an additional input of MklRelu, and it represents |
97 | // metadata for A - as O_m is metadata for O, A_m is metadata for A. MklRelu |
98 | // receives this metadata from the previous node in the graph. |
99 | // |
100 | // When a previous node in the graph is an Mkl node, A_m will represent a valid |
101 | // Mkl tensor. But when a previous node is not an Mkl node, A_m will represent |
102 | // a dummy Mkl tensor. |
103 | // |
104 | // Rewriting rules: |
105 | // - Selection of a node for rewriting happens by registering the op type of |
106 | // the node with the rewriting pass. If the op type is not registered, then |
107 | // no node of this op type will be rewritten. |
108 | // - Number of inputs after rewriting: |
109 | // Since for every input Tensorflow tensor, the rewritten node gets Mkl |
110 | // tensor(s), rewritten node gets 2*N inputs, where N is the number of |
111 | // inputs for the original node. |
112 | // - Number of outputs after rewriting: |
113 | // Since for every output Tensorflow tensor, the rewritten node generates |
114 | // Mkl tensor(s), the rewritten node generates 2*N outputs, where N is the |
115 | // number of outputs of the original node. |
116 | // - Ordering of Tensorflow tensors and Mkl tensors: |
117 | // Since every rewritten node has twice the number of inputs and |
118 | // outputs, one could imagine various orderings among Tensorflow tensors |
119 | // and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as |
120 | // inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m |
121 | // in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m |
122 | // order. Among 2*N inputs one can get (2*N)! permutations. |
123 | // |
124 | // So the question is: which order do we follow? We support 2 types of |
125 | // orderings: (1) interleaved, and (2) contiguous. Interleaved ordering |
126 | // follows an intuitive order where an Mkl tensor follows the |
127 | // corresponding Tensorflow tensor immediately. In the context of the |
128 | // above example, it will be: A, A_m, B, B_m. Note that the ordering rule |
129 | // applies to both the inputs and outputs. Contiguous ordering means |
130 | // all the Tensorflow tensors are contiguous followed by all the Mkl |
131 | // tensors. We use contiguous ordering as default. |
132 | // |
133 | // Graph rewrite algorithm: |
134 | // Algorithm: Graph Rewrite |
135 | // Input: Graph G, Names of the nodes to rewrite and their new names |
136 | // Output: Modified Graph G' if the nodes are modified, G otherwise. |
137 | // Start: |
138 | // N = Topological_Sort(G) // N is a set of nodes in toposort order. |
139 | // foreach node n in N |
140 | // do |
141 | // if (Is_MKL_Op(n)) // Can this node accept an Mkl layout as input. |
142 | // then |
143 | // E = set of <incoming edge and its src_output slot> of n |
144 | // E' = {} // a new set of edges for rewritten node |
145 | // foreach <e,s> in E |
146 | // do |
147 | // E' U {<e,s>} // First copy edge which generates Tensorflow |
148 | // // tensor as it is |
149 | // m = Source node of edge e |
150 | // if Is_Rewritten(m) // Did we rewrite this node in this pass? |
151 | // then |
152 | // E' U {<m,s+1>} // If yes, then m will generate an Mkl |
153 | // // tensor as an additional output. |
154 | // else |
155 | // d = Generate_Dummy_Mkl_Tensor() // If not, generate a dummy |
156 | // // Mkl tensor. |
157 | // E' U {<d,0>} // The dummy Mkl tensor has only 1 output slot. |
158 | // fi |
159 | // done |
160 | // n' = Build_New_Node(G,new_name,E') |
161 | // Mark_Rewritten(n') // Mark the new node as being rewritten. |
162 | // fi |
163 | // done |
164 | // |
165 | // Explanation: |
166 | // For graph rewrite, we visit nodes of the input graph in the |
167 | // topological sort order. With this ordering, we visit nodes in the |
168 | // top-to-bottom fashion. We need this order because while visiting a |
169 | // node we want that all of its input nodes are visited and rewritten if |
170 | // applicable. This is because if we need to rewrite a given node |
171 | // then all of its input nodes need to be fixed (in other words they |
172 | // cannot be deleted later.) |
173 | // |
174 | // While visiting a node, we first check if the op type of the node is |
175 | // an Mkl op. If it is, then we rewrite that node after constructing |
176 | // new inputs to the node. If the op type of the node is not Mkl op, |
177 | // then we do not rewrite that node. |
178 | // |
179 | // Handling workspace propagation for certain ops: |
180 | // |
181 | // Certain backward ops in MKL (MaxPool, LRN and BatchNorm) require |
182 | // passing of a workspace from their respective forward ops. Workspace |
183 | // tensors provide memory for storing results of intermediate operations |
184 | // which are helpful in backward propagation. TensorFlow does not have |
185 | // a notion of a workspace and as a result does not allow producing |
186 | // additional outputs from these forward ops. For these ops, we need |
187 | // to add 2 extra edges between forward ops and their corresponding |
188 | // backward ops - the first extra edge carries a workspace tensor and |
189 | // the second one carries an Mkl tensor for the workspace tensor. |
190 | // |
191 | // Example: |
192 | // |
193 | // Typical graph for MaxPool and its gradient looks like: |
194 | // |
195 | // A = MaxPool(T) |
196 | // B = MaxPoolGrad(X, A, Y) |
197 | // |
198 | // We will transform this graph to propagate the workspace as: |
199 | // (with the contiguous ordering) |
200 | // |
201 | // A, W, A_m, W_m = MklMaxPool(T, T_m) |
202 | // B, B_m = MklMaxPoolGrad(X, A, Y, W, X_m, A_m, Y_m, W_m) |
203 | // |
204 | // Here W is the workspace tensor. Transformed tensor names with the |
205 | // suffix _m are Mkl tensors, and this transformation has been done |
206 | // using the algorithm discussed earlier. The transformation for |
207 | // workspace propagation only adds extra outputs (W, W_m) for a forward |
208 | // op and connects them to the corresponding backward ops. |
209 | // |
210 | // Terms: |
211 | // |
212 | // Forward op name = name of the op in the forward pass |
213 | // where a workspace tensor originates (MaxPool in this example) |
214 | // Backward op name = name of the op in the backward pass that receives |
215 | // a workspace tensor from the forward op (MaxPoolGrad in the example) |
216 | // Slot = Position of the output or input slot that will be |
217 | // used by the workspace tensor (1 for MklMaxPool as W is the 2nd |
218 | // output of MklMaxPool (0 is 1st); 3 for MklMaxPoolGrad) |
219 | // |
220 | // Question: |
221 | // |
222 | // How do we associate a backward op to a forward op? There can be more |
223 | // than one op with the exact same name. |
224 | // |
225 | // In this example, we associate MaxPoolGrad with MaxPool. But there |
226 | // could be more than one MaxPool op. To solve this problem, we look |
227 | // for a _direct_ edge between a forward op and a backward op (tensor A |
228 | // is flowing along this edge in the example). |
229 | // |
230 | // How do we transform forward and backward ops when there is no direct |
231 | // edge between them? In such a case, we generate dummy tensors for |
232 | // workspace tensors. For the example, the transformation of MaxPool |
233 | // will be exactly the same as it would be when there is a direct edge |
234 | // between the forward and the backward op --- it is just that MaxPool |
235 | // won't generate any workspace tensor. For MaxPoolGrad, the |
236 | // transformation will also be the same, but instead of connecting W and |
237 | // W_m with the outputs of MaxPool, we will produce dummy tensors for |
238 | // them, and we will set the workspace_enabled attribute to false. |
239 | // |
240 | class MklLayoutRewritePass : public GraphOptimizationPass { |
241 | public: |
242 | MklLayoutRewritePass() { |
243 | // NOTE: names are alphabetically sorted. |
244 | csinfo_.addn = "AddN" ; |
245 | csinfo_.avg_pool = "AvgPool" ; |
246 | csinfo_.avg_pool_grad = "AvgPoolGrad" ; |
247 | csinfo_.avg_pool3d = "AvgPool3D" ; |
248 | csinfo_.avg_pool3d_grad = "AvgPool3DGrad" ; |
249 | csinfo_.batch_matmul = "BatchMatMul" ; |
250 | csinfo_.batch_matmul_v2 = "BatchMatMulV2" ; |
251 | csinfo_.bias_add = "BiasAdd" ; |
252 | csinfo_.bias_add_grad = "BiasAddGrad" ; |
253 | csinfo_.concat = "Concat" ; |
254 | csinfo_.concatv2 = "ConcatV2" ; |
255 | csinfo_.conjugate_transpose = "ConjugateTranspose" ; |
256 | csinfo_.conv2d = "Conv2D" ; |
257 | csinfo_.conv2d_with_bias = "__MklDummyConv2DWithBias" ; |
258 | csinfo_.conv2d_grad_input = "Conv2DBackpropInput" ; |
259 | csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter" ; |
260 | csinfo_.conv2d_grad_filter_with_bias = |
261 | "__MklDummyConv2DBackpropFilterWithBias" ; |
262 | csinfo_.conv3d = "Conv3D" ; |
263 | csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2" ; |
264 | csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2" ; |
265 | csinfo_.depthwise_conv2d = "DepthwiseConv2dNative" ; |
266 | csinfo_.depthwise_conv2d_grad_input = "DepthwiseConv2dNativeBackpropInput" ; |
267 | csinfo_.depthwise_conv2d_grad_filter = |
268 | "DepthwiseConv2dNativeBackpropFilter" ; |
269 | csinfo_.dequantize = "Dequantize" ; |
270 | csinfo_.einsum = "Einsum" ; |
271 | csinfo_.fused_batch_norm = "FusedBatchNorm" ; |
272 | csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad" ; |
273 | csinfo_.fused_batch_norm_ex = "_FusedBatchNormEx" ; |
274 | csinfo_.fused_batch_norm_v2 = "FusedBatchNormV2" ; |
275 | csinfo_.fused_batch_norm_grad_v2 = "FusedBatchNormGradV2" ; |
276 | csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3" ; |
277 | csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3" ; |
278 | csinfo_.fused_conv2d = "_FusedConv2D" ; |
279 | csinfo_.fused_conv3d = "_FusedConv3D" ; |
280 | csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative" ; |
281 | csinfo_.fused_matmul = "_FusedMatMul" ; |
282 | csinfo_.identity = "Identity" ; |
283 | csinfo_.leakyrelu = "LeakyRelu" ; |
284 | csinfo_.leakyrelu_grad = "LeakyReluGrad" ; |
285 | csinfo_.lrn = "LRN" ; |
286 | csinfo_.lrn_grad = "LRNGrad" ; |
287 | csinfo_.matmul = "MatMul" ; |
288 | csinfo_.max_pool = "MaxPool" ; |
289 | csinfo_.max_pool_grad = "MaxPoolGrad" ; |
290 | csinfo_.max_pool3d = "MaxPool3D" ; |
291 | csinfo_.max_pool3d_grad = "MaxPool3DGrad" ; |
292 | csinfo_.mkl_conv2d = "_MklConv2D" ; |
293 | csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput" ; |
294 | csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter" ; |
295 | csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias" ; |
296 | csinfo_.mkl_conv2d_grad_filter_with_bias = |
297 | "_MklConv2DBackpropFilterWithBias" ; |
298 | csinfo_.mkl_depthwise_conv2d_grad_input = |
299 | "_MklDepthwiseConv2dNativeBackpropInput" ; |
300 | csinfo_.mkl_depthwise_conv2d_grad_filter = |
301 | "_MklDepthwiseConv2dNativeBackpropFilter" ; |
302 | csinfo_.mkl_fused_batch_norm_ex = "_MklFusedBatchNormEx" ; |
303 | csinfo_.mkl_fused_conv2d = "_MklFusedConv2D" ; |
304 | csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative" ; |
305 | csinfo_.mkl_fused_matmul = "_MklFusedMatMul" ; |
306 | csinfo_.mkl_native_conv2d_with_bias = "_MklNativeConv2DWithBias" ; |
307 | csinfo_.mkl_native_conv2d_grad_filter_with_bias = |
308 | "_MklNativeConv2DBackpropFilterWithBias" ; |
309 | csinfo_.mkl_native_fused_batch_norm_ex = "_MklNativeFusedBatchNormEx" ; |
310 | csinfo_.mkl_native_fused_conv2d = "_MklNativeFusedConv2D" ; |
311 | csinfo_.mkl_native_fused_conv3d = "_MklNativeFusedConv3D" ; |
312 | csinfo_.mkl_native_fused_depthwise_conv2d = |
313 | "_MklNativeFusedDepthwiseConv2dNative" ; |
314 | csinfo_.mkl_native_fused_matmul = "_MklNativeFusedMatMul" ; |
315 | csinfo_.mkl_native_pad_with_conv2d = "_MklNativePadWithConv2D" ; |
316 | csinfo_.mkl_native_pad_with_fused_conv2d = "_MklNativePadWithFusedConv2D" ; |
317 | csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D" ; |
318 | csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D" ; |
319 | csinfo_.pad = "Pad" ; |
320 | csinfo_.pad_with_conv2d = "__MklDummyPadWithConv2D" ; |
321 | csinfo_.pad_with_fused_conv2d = "__MklDummyPadWithFusedConv2D" ; |
322 | csinfo_.quantized_avg_pool = "QuantizedAvgPool" ; |
323 | csinfo_.quantized_concatv2 = "QuantizedConcatV2" ; |
324 | csinfo_.quantized_conv2d = "QuantizedConv2D" ; |
325 | csinfo_.quantized_conv2d_per_channel = "QuantizedConv2DPerChannel" ; |
326 | csinfo_.quantized_conv2d_with_requantize = "QuantizedConv2DAndRequantize" ; |
327 | csinfo_.quantized_conv2d_with_bias = "QuantizedConv2DWithBias" ; |
328 | csinfo_.quantized_conv2d_with_bias_and_requantize = |
329 | "QuantizedConv2DWithBiasAndRequantize" ; |
330 | csinfo_.quantized_conv2d_and_relu = "QuantizedConv2DAndRelu" ; |
331 | csinfo_.quantized_conv2d_and_relu_and_requantize = |
332 | "QuantizedConv2DAndReluAndRequantize" ; |
333 | csinfo_.quantized_conv2d_with_bias_and_relu = |
334 | "QuantizedConv2DWithBiasAndRelu" ; |
335 | csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize = |
336 | "QuantizedConv2DWithBiasAndReluAndRequantize" ; |
337 | csinfo_.quantized_max_pool = "QuantizedMaxPool" ; |
338 | csinfo_.quantized_conv2d_with_bias_sum_and_relu = |
339 | "QuantizedConv2DWithBiasSumAndRelu" ; |
340 | csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize = |
341 | "QuantizedConv2DWithBiasSumAndReluAndRequantize" ; |
342 | csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize = |
343 | "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" ; |
344 | csinfo_.quantized_matmul_with_bias = "QuantizedMatMulWithBias" ; |
345 | csinfo_.quantized_matmul_with_bias_and_relu = |
346 | "QuantizedMatMulWithBiasAndRelu" ; |
347 | csinfo_.quantized_matmul_with_bias_and_relu_and_requantize = |
348 | "QuantizedMatMulWithBiasAndReluAndRequantize" ; |
349 | csinfo_.quantized_matmul_with_bias_and_dequantize = |
350 | "QuantizedMatMulWithBiasAndDequantize" ; |
351 | csinfo_.quantized_matmul_with_bias_and_requantize = |
352 | "QuantizedMatMulWithBiasAndRequantize" ; |
353 | csinfo_.quantized_depthwise_conv2d = "QuantizedDepthwiseConv2D" ; |
354 | csinfo_.quantized_depthwise_conv2d_with_bias = |
355 | "QuantizedDepthwiseConv2DWithBias" ; |
356 | csinfo_.quantized_depthwise_conv2d_with_bias_and_relu = |
357 | "QuantizedDepthwiseConv2DWithBiasAndRelu" ; |
358 | csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize = |
359 | "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" ; |
360 | csinfo_.quantize_v2 = "QuantizeV2" ; |
361 | csinfo_.relu = "Relu" ; |
362 | csinfo_.relu_grad = "ReluGrad" ; |
363 | csinfo_.relu6 = "Relu6" ; |
364 | csinfo_.relu6_grad = "Relu6Grad" ; |
365 | csinfo_.requantize = "Requantize" ; |
366 | csinfo_.tanh = "Tanh" ; |
367 | csinfo_.tanh_grad = "TanhGrad" ; |
368 | csinfo_.reshape = "Reshape" ; |
369 | csinfo_.slice = "Slice" ; |
370 | csinfo_.softmax = "Softmax" ; |
371 | csinfo_.split = "Split" ; |
372 | csinfo_.transpose = "Transpose" ; |
373 | // Element-wise ops. Ensure you also add any new ops to IsOpElementWise |
374 | // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the |
375 | // MklInputConversion op is added before it. |
376 | csinfo_.add = "Add" ; |
377 | csinfo_.add_v2 = "AddV2" ; |
378 | csinfo_.maximum = "Maximum" ; |
379 | csinfo_.mul = "Mul" ; |
380 | csinfo_.squared_difference = "SquaredDifference" ; |
381 | csinfo_.sub = "Sub" ; |
382 | // End - element-wise ops. See note above. |
383 | |
384 | const bool native_fmt = NativeFormatEnabled(); |
385 | // NOTE: names are alphabetically sorted. |
386 | rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), |
387 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
388 | rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), |
389 | CopyAttrsAll, RewriteIfAtleastOneMklInput, |
390 | GetRewriteCause()}); |
391 | rinfo_.push_back( |
392 | {csinfo_.add_v2, mkl_op_registry::GetMklOpName(csinfo_.add_v2), |
393 | CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); |
394 | rinfo_.push_back({csinfo_.avg_pool, |
395 | mkl_op_registry::GetMklOpName(csinfo_.avg_pool), |
396 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
397 | rinfo_.push_back({csinfo_.avg_pool_grad, |
398 | mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad), |
399 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
400 | rinfo_.push_back({csinfo_.avg_pool3d, |
401 | mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d), |
402 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
403 | rinfo_.push_back({csinfo_.avg_pool3d_grad, |
404 | mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad), |
405 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
406 | rinfo_.push_back({csinfo_.batch_matmul, |
407 | mkl_op_registry::GetMklOpName(csinfo_.batch_matmul), |
408 | CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); |
409 | rinfo_.push_back({csinfo_.einsum, |
410 | mkl_op_registry::GetMklOpName(csinfo_.einsum), |
411 | CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); |
412 | rinfo_.push_back({csinfo_.batch_matmul_v2, |
413 | mkl_op_registry::GetMklOpName(csinfo_.batch_matmul_v2), |
414 | CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); |
415 | rinfo_.push_back({csinfo_.concat, |
416 | mkl_op_registry::GetMklOpName(csinfo_.concat), |
417 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
418 | rinfo_.push_back({csinfo_.concatv2, |
419 | mkl_op_registry::GetMklOpName(csinfo_.concatv2), |
420 | CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()}); |
421 | rinfo_.push_back( |
422 | {csinfo_.conjugate_transpose, |
423 | mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), |
424 | CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); |
425 | rinfo_.push_back( |
426 | {csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), |
427 | CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()}); |
428 | rinfo_.push_back({csinfo_.conv2d_with_bias, |
429 | native_fmt ? csinfo_.mkl_native_conv2d_with_bias |
430 | : csinfo_.mkl_conv2d_with_bias, |
431 | CopyAttrsConvCheckConstFilter, AlwaysRewrite, |
432 | GetRewriteCause()}); |
433 | rinfo_.push_back({csinfo_.conv2d_grad_filter, |
434 | mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter), |
435 | CopyAttrsConv, AlwaysRewrite, GetRewriteCause()}); |
436 | rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias, |
437 | native_fmt |
438 | ? csinfo_.mkl_native_conv2d_grad_filter_with_bias |
439 | : csinfo_.mkl_conv2d_grad_filter_with_bias, |
440 | CopyAttrsConv, AlwaysRewrite, GetRewriteCause()}); |
441 | rinfo_.push_back({csinfo_.conv2d_grad_input, |
442 | mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input), |
443 | CopyAttrsConv, AlwaysRewrite, GetRewriteCause()}); |
444 | rinfo_.push_back( |
445 | {csinfo_.conv3d, mkl_op_registry::GetMklOpName(csinfo_.conv3d), |
446 | CopyAttrsConvCheckConstFilter, AlwaysRewrite, GetRewriteCause()}); |
447 | rinfo_.push_back({csinfo_.conv3d_grad_filter, |
448 | mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter), |
449 | CopyAttrsConv, AlwaysRewrite, GetRewriteCause()}); |
450 | rinfo_.push_back({csinfo_.conv3d_grad_input, |
451 | mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input), |
452 | CopyAttrsConv, AlwaysRewrite, GetRewriteCause()}); |
453 | rinfo_.push_back({csinfo_.depthwise_conv2d, |
454 | mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d), |
455 | CopyAttrsConvCheckConstFilter, AlwaysRewrite, |
456 | GetRewriteCause()}); |
457 | rinfo_.push_back( |
458 | {csinfo_.depthwise_conv2d_grad_input, |
459 | mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input), |
460 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
461 | rinfo_.push_back( |
462 | {csinfo_.depthwise_conv2d_grad_filter, |
463 | mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter), |
464 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
465 | rinfo_.push_back( |
466 | {csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize), |
467 | CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange}); |
468 | rinfo_.push_back({csinfo_.fused_batch_norm, |
469 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm), |
470 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
471 | rinfo_.push_back( |
472 | {csinfo_.fused_batch_norm_grad, |
473 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad), |
474 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
475 | rinfo_.push_back( |
476 | {csinfo_.fused_batch_norm_v2, |
477 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v2), |
478 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
479 | rinfo_.push_back( |
480 | {csinfo_.fused_batch_norm_grad_v2, |
481 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v2), |
482 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
483 | |
484 | // Using CopyAttrsAll for V3 on CPU, as there are no additional |
485 | // attributes. |
486 | rinfo_.push_back( |
487 | {csinfo_.fused_batch_norm_v3, |
488 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_v3), |
489 | CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); |
490 | rinfo_.push_back( |
491 | {csinfo_.fused_batch_norm_grad_v3, |
492 | mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), |
493 | CopyAttrsAll, FusedBatchNormV3Rewrite, GetRewriteCause()}); |
494 | rinfo_.push_back({csinfo_.fused_batch_norm_ex, |
495 | native_fmt ? csinfo_.mkl_native_fused_batch_norm_ex |
496 | : csinfo_.mkl_fused_batch_norm_ex, |
497 | CopyAttrsAll, FusedBatchNormExRewrite, |
498 | GetRewriteCause()}); |
499 | rinfo_.push_back({csinfo_.fused_conv2d, |
500 | native_fmt ? csinfo_.mkl_native_fused_conv2d |
501 | : csinfo_.mkl_fused_conv2d, |
502 | CopyAttrsFusedConv2DCheckConstFilter, FusedConv2DRewrite, |
503 | GetRewriteCause()}); |
504 | rinfo_.push_back({csinfo_.fused_conv3d, csinfo_.mkl_native_fused_conv3d, |
505 | CopyAttrsAllCheckConstFilter, AlwaysRewrite, |
506 | kRewriteForOpNameChange}); |
507 | rinfo_.push_back({csinfo_.fused_depthwise_conv2d, |
508 | native_fmt ? csinfo_.mkl_native_fused_depthwise_conv2d |
509 | : csinfo_.mkl_fused_depthwise_conv2d, |
510 | CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite, |
511 | GetRewriteCause()}); |
512 | rinfo_.push_back({csinfo_.fused_matmul, |
513 | native_fmt ? csinfo_.mkl_native_fused_matmul |
514 | : csinfo_.mkl_fused_matmul, |
515 | CopyAttrsAllCheckConstFilter, FusedMatMulRewrite, |
516 | GetRewriteCause()}); |
517 | rinfo_.push_back( |
518 | {csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), |
519 | CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); |
520 | rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), |
521 | CopyAttrsAll, LrnRewrite, GetRewriteCause()}); |
522 | rinfo_.push_back({csinfo_.lrn_grad, |
523 | mkl_op_registry::GetMklOpName(csinfo_.lrn_grad), |
524 | CopyAttrsAll, LrnGradRewrite, GetRewriteCause()}); |
525 | rinfo_.push_back({csinfo_.matmul, |
526 | mkl_op_registry::GetMklOpName(csinfo_.matmul), |
527 | CopyAttrsAll, MatMulRewrite, kRewriteForOpNameChange}); |
528 | rinfo_.push_back({csinfo_.leakyrelu, |
529 | mkl_op_registry::GetMklOpName(csinfo_.leakyrelu), |
530 | CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()}); |
531 | rinfo_.push_back({csinfo_.leakyrelu_grad, |
532 | mkl_op_registry::GetMklOpName(csinfo_.leakyrelu_grad), |
533 | CopyAttrsAll, LeakyReluRewrite, GetRewriteCause()}); |
534 | rinfo_.push_back( |
535 | {csinfo_.max_pool, mkl_op_registry::GetMklOpName(csinfo_.max_pool), |
536 | CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()}); |
537 | rinfo_.push_back({csinfo_.max_pool_grad, |
538 | mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad), |
539 | CopyAttrsAll, MaxpoolGradRewrite, GetRewriteCause()}); |
540 | rinfo_.push_back( |
541 | {csinfo_.max_pool3d, mkl_op_registry::GetMklOpName(csinfo_.max_pool3d), |
542 | CopyAttrsAll, NonDepthBatchWisePoolRewrite, GetRewriteCause()}); |
543 | rinfo_.push_back({csinfo_.max_pool3d_grad, |
544 | mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad), |
545 | CopyAttrsAll, Maxpool3DGradRewrite, GetRewriteCause()}); |
546 | rinfo_.push_back( |
547 | {csinfo_.maximum, mkl_op_registry::GetMklOpName(csinfo_.maximum), |
548 | CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); |
549 | rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul), |
550 | CopyAttrsAll, RewriteIfAtleastOneMklInput, |
551 | GetRewriteCause()}); |
552 | rinfo_.push_back({csinfo_.pad_with_conv2d, |
553 | native_fmt ? csinfo_.mkl_native_pad_with_conv2d |
554 | : csinfo_.mkl_pad_with_conv2d, |
555 | CopyAttrsAllCheckConstFilter, AlwaysRewrite, |
556 | GetRewriteCause()}); |
557 | rinfo_.push_back({csinfo_.pad_with_fused_conv2d, |
558 | native_fmt ? csinfo_.mkl_native_pad_with_fused_conv2d |
559 | : csinfo_.mkl_pad_with_fused_conv2d, |
560 | CopyAttrsAllCheckConstFilter, AlwaysRewrite, |
561 | GetRewriteCause()}); |
562 | rinfo_.push_back({csinfo_.quantized_avg_pool, |
563 | mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool), |
564 | CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); |
565 | rinfo_.push_back({csinfo_.quantized_concatv2, |
566 | mkl_op_registry::GetMklOpName(csinfo_.quantized_concatv2), |
567 | CopyAttrsAll, ConcatV2Rewrite, kRewriteForOpNameChange}); |
568 | rinfo_.push_back({csinfo_.quantized_conv2d, |
569 | mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d), |
570 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
571 | kRewriteForOpNameChange}); |
572 | rinfo_.push_back( |
573 | {csinfo_.quantized_conv2d_per_channel, |
574 | mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_per_channel), |
575 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
576 | rinfo_.push_back({csinfo_.quantized_conv2d_with_requantize, |
577 | mkl_op_registry::GetMklOpName( |
578 | csinfo_.quantized_conv2d_with_requantize), |
579 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
580 | kRewriteForOpNameChange}); |
581 | rinfo_.push_back( |
582 | {csinfo_.quantized_conv2d_with_bias, |
583 | mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_with_bias), |
584 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
585 | rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_requantize, |
586 | mkl_op_registry::GetMklOpName( |
587 | csinfo_.quantized_conv2d_with_bias_and_requantize), |
588 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
589 | kRewriteForOpNameChange}); |
590 | rinfo_.push_back( |
591 | {csinfo_.quantized_conv2d_and_relu, |
592 | mkl_op_registry::GetMklOpName(csinfo_.quantized_conv2d_and_relu), |
593 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
594 | rinfo_.push_back({csinfo_.quantized_conv2d_and_relu_and_requantize, |
595 | mkl_op_registry::GetMklOpName( |
596 | csinfo_.quantized_conv2d_and_relu_and_requantize), |
597 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
598 | kRewriteForOpNameChange}); |
599 | rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_and_relu, |
600 | mkl_op_registry::GetMklOpName( |
601 | csinfo_.quantized_conv2d_with_bias_and_relu), |
602 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
603 | kRewriteForOpNameChange}); |
604 | rinfo_.push_back( |
605 | {csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize, |
606 | mkl_op_registry::GetMklOpName( |
607 | csinfo_.quantized_conv2d_with_bias_and_relu_and_requantize), |
608 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
609 | rinfo_.push_back({csinfo_.quantized_max_pool, |
610 | mkl_op_registry::GetMklOpName(csinfo_.quantized_max_pool), |
611 | CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); |
612 | rinfo_.push_back({csinfo_.quantized_conv2d_with_bias_sum_and_relu, |
613 | mkl_op_registry::GetMklOpName( |
614 | csinfo_.quantized_conv2d_with_bias_sum_and_relu), |
615 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
616 | kRewriteForOpNameChange}); |
617 | rinfo_.push_back( |
618 | {csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize, |
619 | mkl_op_registry::GetMklOpName( |
620 | csinfo_.quantized_conv2d_with_bias_sum_and_relu_and_requantize), |
621 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
622 | rinfo_.push_back( |
623 | {csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize, |
624 | mkl_op_registry::GetMklOpName( |
625 | csinfo_.quant_conv2d_with_bias_signed_sum_and_relu_and_requantize), |
626 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
627 | rinfo_.push_back( |
628 | {csinfo_.quantized_matmul_with_bias, |
629 | mkl_op_registry::GetMklOpName(csinfo_.quantized_matmul_with_bias), |
630 | CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite, |
631 | kRewriteForOpNameChange}); |
632 | rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_relu, |
633 | mkl_op_registry::GetMklOpName( |
634 | csinfo_.quantized_matmul_with_bias_and_relu), |
635 | CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite, |
636 | kRewriteForOpNameChange}); |
637 | rinfo_.push_back( |
638 | {csinfo_.quantized_matmul_with_bias_and_relu_and_requantize, |
639 | mkl_op_registry::GetMklOpName( |
640 | csinfo_.quantized_matmul_with_bias_and_relu_and_requantize), |
641 | CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite, |
642 | kRewriteForOpNameChange}); |
643 | rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_requantize, |
644 | mkl_op_registry::GetMklOpName( |
645 | csinfo_.quantized_matmul_with_bias_and_requantize), |
646 | CopyAttrsQuantizedMatMulWithBias, AlwaysRewrite, |
647 | kRewriteForOpNameChange}); |
648 | rinfo_.push_back({csinfo_.quantized_matmul_with_bias_and_dequantize, |
649 | mkl_op_registry::GetMklOpName( |
650 | csinfo_.quantized_matmul_with_bias_and_dequantize), |
651 | CopyAttrsQuantizedMatMulWithBiasAndDequantize, |
652 | AlwaysRewrite, kRewriteForOpNameChange}); |
653 | rinfo_.push_back( |
654 | {csinfo_.quantized_depthwise_conv2d, |
655 | mkl_op_registry::GetMklOpName(csinfo_.quantized_depthwise_conv2d), |
656 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
657 | rinfo_.push_back({csinfo_.quantized_depthwise_conv2d_with_bias, |
658 | mkl_op_registry::GetMklOpName( |
659 | csinfo_.quantized_depthwise_conv2d_with_bias), |
660 | CopyAttrsQuantizedConv2D, AlwaysRewrite, |
661 | kRewriteForOpNameChange}); |
662 | rinfo_.push_back( |
663 | {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu, |
664 | mkl_op_registry::GetMklOpName( |
665 | csinfo_.quantized_depthwise_conv2d_with_bias_and_relu), |
666 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
667 | rinfo_.push_back( |
668 | {csinfo_.quantized_depthwise_conv2d_with_bias_and_relu_and_requantize, |
669 | mkl_op_registry::GetMklOpName( |
670 | csinfo_ |
671 | .quantized_depthwise_conv2d_with_bias_and_relu_and_requantize), |
672 | CopyAttrsQuantizedConv2D, AlwaysRewrite, kRewriteForOpNameChange}); |
673 | rinfo_.push_back({csinfo_.quantize_v2, |
674 | mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), |
675 | CopyAttrsAll, QuantizeOpRewrite, |
676 | kRewriteForOpNameChange}); |
677 | rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), |
678 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
679 | rinfo_.push_back({csinfo_.relu_grad, |
680 | mkl_op_registry::GetMklOpName(csinfo_.relu_grad), |
681 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
682 | rinfo_.push_back({csinfo_.relu6, |
683 | mkl_op_registry::GetMklOpName(csinfo_.relu6), |
684 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
685 | rinfo_.push_back({csinfo_.relu6_grad, |
686 | mkl_op_registry::GetMklOpName(csinfo_.relu6_grad), |
687 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
688 | rinfo_.push_back({csinfo_.requantize, |
689 | mkl_op_registry::GetMklOpName(csinfo_.requantize), |
690 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
691 | // Optimized TanhGrad support exists only in DNNL 1.x. |
692 | rinfo_.push_back({csinfo_.tanh, mkl_op_registry::GetMklOpName(csinfo_.tanh), |
693 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
694 | rinfo_.push_back({csinfo_.tanh_grad, |
695 | mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), |
696 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
697 | rinfo_.push_back({csinfo_.reshape, |
698 | mkl_op_registry::GetMklOpName(csinfo_.reshape), |
699 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
700 | rinfo_.push_back( |
701 | {csinfo_.slice, mkl_op_registry::GetMklOpName(csinfo_.slice), |
702 | CopyAttrsAll, RewriteIfAtleastOneMklInput, GetRewriteCause()}); |
703 | rinfo_.push_back({csinfo_.softmax, |
704 | mkl_op_registry::GetMklOpName(csinfo_.softmax), |
705 | CopyAttrsAll, AlwaysRewrite, GetRewriteCause()}); |
706 | |
707 | rinfo_.push_back({csinfo_.squared_difference, |
708 | mkl_op_registry::GetMklOpName(csinfo_.squared_difference), |
709 | CopyAttrsAll, RewriteIfAtleastOneMklInput, |
710 | GetRewriteCause()}); |
711 | rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub), |
712 | CopyAttrsAll, RewriteIfAtleastOneMklInput, |
713 | GetRewriteCause()}); |
714 | rinfo_.push_back({csinfo_.transpose, |
715 | mkl_op_registry::GetMklOpName(csinfo_.transpose), |
716 | CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); |
717 | |
718 | // Add info about which ops to add workspace edge to and the slots. |
719 | wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3}); |
720 | wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3}); |
721 | wsinfo_.push_back( |
722 | {csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3}); |
723 | |
724 | // Add a rule for merging nodes |
725 | minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add, |
726 | csinfo_.conv2d_with_bias, GetConv2DOrBiasAdd}); |
727 | |
728 | // Merge Pad and Conv2d, only if the pad op is "Pad" |
729 | // Doesn't merge if pad op is "PadV2" or "MirrorPad" |
730 | minfo_.push_back( |
731 | {csinfo_.pad, csinfo_.conv2d, csinfo_.pad_with_conv2d, GetPadOrConv2D}); |
732 | |
733 | minfo_.push_back({csinfo_.pad, csinfo_.fused_conv2d, |
734 | csinfo_.pad_with_fused_conv2d, GetPadOrFusedConv2D}); |
735 | |
736 | minfo_.push_back({csinfo_.conv2d_grad_filter, csinfo_.bias_add_grad, |
737 | csinfo_.conv2d_grad_filter_with_bias, |
738 | GetConv2DBackpropFilterOrBiasAddGrad}); |
739 | |
// The fusion patterns in "finfo_" that show up first will get applied
// first. For example, for graph "A->B->C->D" with finfo_ being {A->B->C to
// ABC, A->B->C->D to ABCD}, the first pattern gets applied first, so the
// final graph will be ABC->D.
744 | } |
745 | |
// Standard interface to run pass.
// Takes the graph optimization pass options and reports the outcome as a
// Status. Only the declaration lives here; the body is defined elsewhere.
Status Run(const GraphOptimizationPassOptions& options);
748 | |
// Helper function which does most of heavy lifting for rewriting
// Mkl nodes to propagate Mkl tensor as additional output
//
// Extracts common functionality between Run public interface and
// test interface.
//
// @input g - pointer to the owning pointer of the graph to be processed
// @return true, if and only if graph is mutated; false otherwise.
bool RunPass(std::unique_ptr<Graph>* g);
757 | |
/// Reason why a node is being rewritten.
/// Only two causes are supported at the moment: Mkl layout propagation
/// (the most common case), or a plain op-name change (used for ops such as
/// MatMul and Transpose, which do not support Mkl layout).
enum RewriteCause {
  kRewriteForLayoutPropagation = 0,
  kRewriteForOpNameChange = 1
};
763 | |
764 | // Get the op rewrite cause depending on whether native format mode |
765 | // is enabled or not. |
766 | RewriteCause GetRewriteCause() { |
767 | if (NativeFormatEnabled()) { |
768 | return kRewriteForOpNameChange; |
769 | } else { |
770 | return kRewriteForLayoutPropagation; |
771 | } |
772 | } |
773 | |
774 | /// Structure to specify the name of an original node, its new name after |
775 | /// rewrite, the number of inputs to the original node, the function to |
776 | /// be used to copy attributes for the op, and the rule (if any) which |
777 | /// must hold for rewriting the node |
778 | typedef struct { |
779 | string name; // Original name of op of the node in the graph |
780 | string new_name; // New name of the op of the node in the graph |
781 | // A function handler to copy attributes from an old node to a new node. |
782 | std::function<void(const Node*, NodeBuilder*, bool)> copy_attrs; |
783 | // A rule under which to rewrite this node |
784 | std::function<bool(const Node*)> rewrite_rule; |
785 | // Why are we rewriting? |
786 | RewriteCause rewrite_cause; |
787 | } RewriteInfo; |
788 | |
789 | /// Structure to specify a forward op, a backward op, and the slot numbers |
790 | /// in the forward and backward ops where we will add a workspace edge. |
791 | typedef struct { |
792 | string fwd_op; // Name of a forward op in the graph |
793 | string bwd_op; // Name of a backward op in the graph |
794 | int fwd_slot; // Output slot in the forward op node where actual |
795 | // output tensor resides |
796 | int bwd_slot; // Input slot in the backward op node where actual |
797 | // input tensor resides |
798 | int ws_fwd_slot; // Output slot in the forward op node where workspace |
799 | // edge is added |
800 | int ws_bwd_slot; // Input slot in the backward op node where workspace |
801 | // edge is added |
802 | } WorkSpaceInfo; |
803 | |
804 | /// Structure to specify information used in node merge of 2 operators |
805 | typedef struct { |
806 | string op1; // Node string for one operator. |
807 | string op2; // Node string for second operator. |
808 | string new_node; // Name of the node after merge |
809 | // Function that enables user of the node merger to specify how to find |
810 | // second operator given the first operator. |
811 | std::function<Node*(const Node*)> get_node_to_be_merged; |
812 | } MergeInfo; |
813 | |
814 | // Structure to specify information used in node fusion of 3+ operators |
815 | typedef struct { |
816 | std::string pattern_name; // Name to describe this pattern, such as |
817 | // "Transpose_Mklop_Transpose". |
818 | std::vector<std::function<bool(const Node*)> > |
819 | node_checkers; // Extra restriction checker for these ops |
820 | std::function<Status( |
821 | std::unique_ptr<Graph>*, std::vector<Node*>&, |
822 | std::function<void(const Node*, NodeBuilder* nb, bool)>)> |
823 | fuse_func; |
824 | std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs; |
825 | } FusionInfo; |
826 | |
//
// Dimension indices for 2D tensor.
//
// Position of each of the N, C, H and W dimensions in NCHW order.
struct NCHW {
  enum dim { N, C, H, W };
};
833 | |
// Position of each of the N, H, W and C dimensions in NHWC order.
struct NHWC {
  enum dim { N, H, W, C };
};
837 | |
//
// Dimension indices for 3D tensor.
//
// Position of each of the N, C, D, H and W dimensions in NCDHW order.
struct NCDHW {
  enum dim { N, C, D, H, W };
};
844 | |
// Position of each of the N, D, H, W and C dimensions in NDHWC order.
struct NDHWC {
  enum dim { N, D, H, W, C };
};
848 | |
/// Structure to store all constant strings, one member per op name used by
/// this pass.
/// NOTE: names are kept in roughly alphabetical order; a few entries (for
/// example `transpose` before `reshape`, or `addn` before `add`) are out of
/// order. Members are left in their existing positions.
typedef struct {
  string addn;
  string add;
  string add_v2;
  string avg_pool;
  string avg_pool_grad;
  string avg_pool3d;
  string avg_pool3d_grad;
  string batch_matmul;
  string batch_matmul_v2;
  string bias_add;
  string bias_add_grad;
  string concat;
  string concatv2;
  string conjugate_transpose;
  string conv2d;
  string conv2d_with_bias;
  string conv2d_grad_input;
  string conv2d_grad_filter;
  string conv2d_grad_filter_with_bias;
  string conv3d;
  string conv3d_grad_input;
  string conv3d_grad_filter;
  string depthwise_conv2d;
  string depthwise_conv2d_grad_input;
  string depthwise_conv2d_grad_filter;
  string dequantize;
  string einsum;
  string fused_batch_norm;
  string fused_batch_norm_grad;
  string fused_batch_norm_ex;
  string fused_batch_norm_v2;
  string fused_batch_norm_grad_v2;
  string fused_batch_norm_v3;
  string fused_batch_norm_grad_v3;
  string fused_conv2d;
  string fused_conv3d;
  string fused_depthwise_conv2d;
  string fused_matmul;
  string identity;
  string leakyrelu;
  string leakyrelu_grad;
  string lrn;
  string lrn_grad;
  string matmul;
  string max_pool;
  string max_pool_grad;
  string max_pool3d;
  string max_pool3d_grad;
  string maximum;
  string mkl_conv2d;
  string mkl_conv2d_grad_input;
  string mkl_conv2d_grad_filter;
  string mkl_conv2d_grad_filter_with_bias;
  string mkl_conv2d_with_bias;
  string mkl_depthwise_conv2d_grad_input;
  string mkl_depthwise_conv2d_grad_filter;
  string mkl_fused_batch_norm_ex;
  string mkl_fused_conv2d;
  string mkl_fused_depthwise_conv2d;
  string mkl_fused_matmul;
  string mkl_native_conv2d_with_bias;
  string mkl_native_conv2d_grad_filter_with_bias;
  string mkl_native_fused_batch_norm_ex;
  string mkl_native_fused_conv2d;
  string mkl_native_fused_conv3d;
  string mkl_native_fused_depthwise_conv2d;
  string mkl_native_fused_matmul;
  string mkl_native_pad_with_conv2d;
  string mkl_native_pad_with_fused_conv2d;
  string mkl_pad_with_conv2d;
  string mkl_pad_with_fused_conv2d;
  string mul;
  string pad;
  string pad_with_conv2d;
  string pad_with_fused_conv2d;
  string quantized_avg_pool;
  string quantized_conv2d;
  string quantized_conv2d_per_channel;
  string quantized_conv2d_with_requantize;
  string quantized_conv2d_with_bias;
  string quantized_conv2d_with_bias_and_requantize;
  string quantized_conv2d_and_relu;
  string quantized_conv2d_and_relu_and_requantize;
  string quantized_conv2d_with_bias_and_relu;
  string quantized_conv2d_with_bias_and_relu_and_requantize;
  string quantized_concatv2;
  string quantized_max_pool;
  string quantized_conv2d_with_bias_sum_and_relu;
  string quantized_conv2d_with_bias_sum_and_relu_and_requantize;
  string quant_conv2d_with_bias_signed_sum_and_relu_and_requantize;
  string quantized_matmul_with_bias;
  string quantized_matmul_with_bias_and_relu;
  string quantized_matmul_with_bias_and_relu_and_requantize;
  string quantized_matmul_with_bias_and_requantize;
  string quantized_matmul_with_bias_and_dequantize;
  string quantized_depthwise_conv2d;
  string quantized_depthwise_conv2d_with_bias;
  string quantized_depthwise_conv2d_with_bias_and_relu;
  string quantized_depthwise_conv2d_with_bias_and_relu_and_requantize;
  string quantize_v2;
  string relu;
  string relu_grad;
  string relu6;
  string relu6_grad;
  string requantize;
  string tanh;
  string tanh_grad;
  string transpose;
  string reshape;
  string slice;
  string softmax;
  string split;
  string squared_difference;
  string sub;
} ConstStringsInfo;
967 | |
private:
/// Maintain info about nodes to rewrite (one RewriteInfo per rewrite rule).
std::vector<RewriteInfo> rinfo_;

/// Maintain info about nodes to add workspace edge (forward/backward op
/// pairs together with their slots).
std::vector<WorkSpaceInfo> wsinfo_;

/// Maintain info about nodes to be merged (rules for merging 2 operators).
std::vector<MergeInfo> minfo_;

/// Maintain info about nodes to be fused (patterns of 3+ operators).
std::vector<FusionInfo> finfo_;

/// Maintain structure of constant strings. Declared static, so a single
/// instance is shared by all objects of this class.
static ConstStringsInfo csinfo_;
983 | |
984 | private: |
985 | // Is OpDef::ArgDef a list type? It could be N * T or list(type). |
986 | // Refer to opdef.proto for details of list type. |
987 | inline bool ArgIsList(const OpDef::ArgDef& arg) const { |
988 | return !arg.type_list_attr().empty() || !arg.number_attr().empty(); |
989 | } |
990 | |
991 | // Get length of a list in 'n' if 'arg' is of list type. Refer to |
992 | // description of ArgIsList for definition of list type. |
993 | inline int GetTensorListLength(const OpDef::ArgDef& arg, const Node* n) { |
994 | CHECK_EQ(ArgIsList(arg), true); |
995 | int N = 0; |
996 | const string attr_name = !arg.type_list_attr().empty() |
997 | ? arg.type_list_attr() |
998 | : arg.number_attr(); |
999 | if (!arg.type_list_attr().empty()) { |
1000 | std::vector<DataType> value; |
1001 | TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value)); |
1002 | N = value.size(); |
1003 | } else { |
1004 | TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &N)); |
1005 | } |
1006 | return N; |
1007 | } |
1008 | |
1009 | // Can op represented by node 'n' run on DEVICE_CPU? |
1010 | // Op can run on CPU with MKL if the runtime assigned device or the |
1011 | // user requested device contains device CPU, or both are empty. |
1012 | bool CanOpRunOnCPUDevice(const Node* n) { |
1013 | bool result = true; |
1014 | string reason; |
1015 | |
1016 | // Substring that should be checked for in device name for CPU device. |
1017 | const char* const kCPUDeviceSubStr = "CPU" ; |
1018 | const char* const kXLACPUDeviceSubStr = "XLA_CPU" ; |
1019 | |
1020 | // If Op has been specifically assigned to a non-CPU or XLA_CPU device, then |
1021 | // No. |
1022 | if (!n->assigned_device_name().empty() && |
1023 | (!absl::StrContains(n->assigned_device_name(), kCPUDeviceSubStr) || |
1024 | absl::StrContains(n->assigned_device_name(), kXLACPUDeviceSubStr))) { |
1025 | result = false; |
1026 | reason = "Op has been assigned a runtime device that is not CPU." ; |
1027 | } |
1028 | |
1029 | // If user has specifically assigned this op to a non-CPU or XLA_CPU device, |
1030 | // then No. |
1031 | if (!n->def().device().empty() && |
1032 | (!absl::StrContains(n->def().device(), kCPUDeviceSubStr) || |
1033 | absl::StrContains(n->def().device(), kXLACPUDeviceSubStr))) { |
1034 | result = false; |
1035 | reason = "User has assigned a device that is not CPU." ; |
1036 | } |
1037 | |
1038 | if (result == false) { |
1039 | VLOG(1) << "MklLayoutRewritePass: Skipping rewriting of the node " |
1040 | << n->type_string() << ", reason: " << reason; |
1041 | } |
1042 | |
1043 | // Otherwise Yes. |
1044 | return result; |
1045 | } |
1046 | |
// Return a node that can be merged with input node 'n'
//
// @input n - graph node for which a merge partner is looked up
// @return pointer to the node if we can find such a
//         node. Otherwise, it returns nullptr.
Node* CheckForNodeMerge(const Node* n) const;
1052 | |
// Merge node 'm' with node 'n'.
// The merge rules registered in minfo_ pair (1) Conv2D with BiasAdd,
// (2) Pad with Conv2D, (3) Pad with _FusedConv2D, and (4) BiasAddGrad with
// Conv2DBackpropFilter.
// NOTE(review): which of these rules this function dispatches on is not
// visible from the declaration - confirm against the definition.
//
// Input nodes m and n may be deleted if the call to
// this function is successful. Attempt to use the pointers
// after the call to function may result in undefined behaviors.
//
// @input g - input graph, m - graph node, n - graph node to be merged with m
// @return OkStatus(), if merging is successful and supported.
//         Returns appropriate Status error code otherwise.
//         Graph is updated in case nodes are merged. Otherwise, it is
//         not updated.
Status MergeNode(std::unique_ptr<Graph>* g, Node* m, Node* n);
1067 | |
// Helper functions to merge different nodes. Each one takes the graph 'g'
// and the two nodes 'm' and 'n', and reports the outcome as a Status.
Status MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, Node* m, Node* n);
Status MergePadWithConv2D(std::unique_ptr<Graph>* g, Node* m, Node* n);
Status MergeConv2DBackpropFilterWithBiasAddGrad(std::unique_ptr<Graph>* g,
                                                Node* m, Node* n);
1073 | |
1074 | // Find BiasAdd or Conv2D node that can be merged with input node 'm'. |
1075 | // If input 'm' is BiasAdd, then check if there exists Conv2D node that can be |
1076 | // merged with 'm'. If input 'm' is Conv2D, then check if there exists BiasAdd |
1077 | // node that can be merged with 'm'. |
1078 | static Node* GetConv2DOrBiasAdd(const Node* m) { |
1079 | DCHECK(m); |
1080 | Node* n = nullptr; |
1081 | |
1082 | DataType T_m; |
1083 | TF_CHECK_OK(GetNodeAttr(m->def(), "T" , &T_m)); |
1084 | |
1085 | // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 |
1086 | if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; |
1087 | |
1088 | if (m->type_string() == csinfo_.bias_add) { |
1089 | // If a is BiasAdd, then Conv2D is 0th input of BiasAdd. |
1090 | TF_CHECK_OK(m->input_node(0, &n)); |
1091 | } else { |
1092 | CHECK_EQ(m->type_string(), csinfo_.conv2d); |
1093 | // Go over all output edges and search for BiasAdd Node. |
1094 | // 0th input of BiasAdd is Conv2D. |
1095 | for (const Edge* e : m->out_edges()) { |
1096 | if (!e->IsControlEdge() && |
1097 | e->dst()->type_string() == csinfo_.bias_add && |
1098 | e->dst_input() == 0) { |
1099 | n = e->dst(); |
1100 | break; |
1101 | } |
1102 | } |
1103 | } |
1104 | |
1105 | if (n == nullptr) { |
1106 | VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
1107 | << "Conv2D and BiasAdd node for merging. Input node: " |
1108 | << m->DebugString(); |
1109 | } |
1110 | |
1111 | return n; |
1112 | } |
1113 | |
1114 | // Find Pad or Conv2D node that can be merged with input node 'm'. |
1115 | // If input 'm' is Pad, then check if there exists Conv2D node that can be |
1116 | // merged with 'm'. If input 'm' is Conv2D, then check if there exists Pad |
1117 | // node that can be merged with 'm'. |
1118 | static Node* GetPadOrConv2D(const Node* m) { |
1119 | DCHECK(m); |
1120 | Node* n = nullptr; |
1121 | |
1122 | DataType T_m; |
1123 | TF_CHECK_OK(GetNodeAttr(m->def(), "T" , &T_m)); |
1124 | |
1125 | // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 |
1126 | if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; |
1127 | |
1128 | const Node* conv_node; |
1129 | if (m->type_string() == csinfo_.pad) { |
1130 | // If m is Pad, then Conv2D is the output of Pad. |
1131 | for (const Edge* e : m->out_edges()) { |
1132 | if (!e->IsControlEdge() && e->dst()->type_string() == csinfo_.conv2d) { |
1133 | n = e->dst(); |
1134 | conv_node = n; |
1135 | break; |
1136 | } |
1137 | } |
1138 | } else { |
1139 | DCHECK_EQ(m->type_string(), csinfo_.conv2d); |
1140 | // If m is conv2D, Go over all input edges |
1141 | // and search for Pad Node. |
1142 | for (const Edge* e : m->in_edges()) { |
1143 | if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) { |
1144 | n = e->src(); |
1145 | conv_node = m; |
1146 | break; |
1147 | } |
1148 | } |
1149 | } |
1150 | // Check if only VALID type of padding is used |
1151 | // or not. |
1152 | if (n != nullptr) { |
1153 | string padding; |
1154 | TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding" , &padding)); |
1155 | if (padding != "VALID" ) |
1156 | // Then do not merge. |
1157 | // Only VALID type of padding in conv op can be |
1158 | // merged with Pad op. |
1159 | n = nullptr; |
1160 | } else { |
1161 | VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
1162 | << "Pad and Conv2D node for merging. Input node: " |
1163 | << m->DebugString(); |
1164 | } |
1165 | |
1166 | return n; |
1167 | } |
1168 | |
1169 | // Find Pad or _FusedConv2D node that can be merged with input node 'm'. |
1170 | // If input 'm' is Pad, then check if there exists _FusedConv2D node that can |
1171 | // be merged with 'm'. If input 'm' is _FusedConv2D, then check if there |
1172 | // exists Pad node that can be merged with 'm'. |
1173 | static Node* GetPadOrFusedConv2D(const Node* m) { |
1174 | DCHECK(m); |
1175 | Node* n = nullptr; |
1176 | |
1177 | const Node* conv_node; |
1178 | if (m->type_string() == csinfo_.pad) { |
1179 | // If m is Pad, then _FusedConv2D is the output of Pad. |
1180 | for (const Edge* e : m->out_edges()) { |
1181 | if (!e->IsControlEdge() && |
1182 | e->dst()->type_string() == csinfo_.fused_conv2d) { |
1183 | n = e->dst(); |
1184 | conv_node = n; |
1185 | break; |
1186 | } |
1187 | } |
1188 | } else { |
1189 | DCHECK_EQ(m->type_string(), csinfo_.fused_conv2d); |
1190 | // If m is _FusedConv2D, Go over all input edges |
1191 | // and search for Pad node. |
1192 | for (const Edge* e : m->in_edges()) { |
1193 | if (!e->IsControlEdge() && e->src()->type_string() == csinfo_.pad) { |
1194 | n = e->src(); |
1195 | conv_node = m; |
1196 | break; |
1197 | } |
1198 | } |
1199 | } |
1200 | // Check if only VALID type of padding is used or not. |
1201 | if (n != nullptr) { |
1202 | string padding; |
1203 | string data_format; |
1204 | string filter_format; |
1205 | int num_host_args = 0; |
1206 | TF_CHECK_OK(GetNodeAttr(conv_node->def(), "padding" , &padding)); |
1207 | TF_CHECK_OK(GetNodeAttr(conv_node->def(), "data_format" , &data_format)); |
1208 | TF_CHECK_OK( |
1209 | GetNodeAttr(conv_node->def(), "filter_format" , &filter_format)); |
1210 | TF_CHECK_OK( |
1211 | GetNodeAttr(conv_node->def(), "num_host_args" , &num_host_args)); |
1212 | |
1213 | if ((data_format != "NCHW" && data_format != "NHWC" ) || |
1214 | (filter_format != "HWIO" && filter_format != "OIHW" )) { |
1215 | n = nullptr; |
1216 | VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D " |
1217 | << "nodes but cannot merge them. Only conv ops with NCHW or " |
1218 | << "NHWC data format and HWIO or OIHW filter format can be " |
1219 | << "merged with Pad op Input node: " << m->DebugString(); |
1220 | } else if (padding != "VALID" ) { |
1221 | // Then do not merge. |
1222 | n = nullptr; |
1223 | VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D " |
1224 | << "nodes but cannot merge them. Only conv ops with padding " |
1225 | << "type VALID can be merged with Pad op Input node: " |
1226 | << m->DebugString(); |
1227 | } else if (num_host_args != 0) { |
1228 | n = nullptr; |
1229 | VLOG(1) << "MklLayoutRewritePass: Could match Pad and _FusedConv2D " |
1230 | << "nodes but cannot merge them. Only conv ops without host " |
1231 | << "args can be merged with Pad op Input node: " |
1232 | << m->DebugString(); |
1233 | } |
1234 | } else { |
1235 | VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
1236 | << "Pad and _FusedConv2D node for merging. Input node: " |
1237 | << m->DebugString(); |
1238 | } |
1239 | |
1240 | return n; |
1241 | } |
1242 | |
1243 | // Find Conv2DBackpropFilter or BiasAddGrad node that can be merged with input |
1244 | // node 'm'. If input 'm' is Conv2DBackpropFilter, then check if there exists |
1245 | // BiasAddGrad node that can be merged with 'm'. If input 'm' is BiasAddGrad, |
1246 | // then check if there exists Conv2DBackpropFilter node that can be merged |
1247 | // with 'm'. |
1248 | // |
1249 | // Graph that will allow us to connect Conv2DBackpropFilter with BiasAddGrad |
1250 | // would look like: |
1251 | // |
1252 | // _ = Conv2DBackpropFilter(F, _, G) |
1253 | // _ = BiasAddGrad(G) |
1254 | // |
1255 | // So 1st input of BiasAddGrad connects with 3rd input of |
1256 | // Conv2DBackpropFilter and vice versa. |
1257 | static Node* GetConv2DBackpropFilterOrBiasAddGrad(const Node* m) { |
1258 | DCHECK(m); |
1259 | Node* n = nullptr; |
1260 | const Node* conv2d_backprop_filter = nullptr; |
1261 | |
1262 | DataType T_m; |
1263 | TF_CHECK_OK(GetNodeAttr(m->def(), "T" , &T_m)); |
1264 | |
1265 | // Don't try to merge if datatype is not DT_FLOAT or DT_BFLOAT16 |
1266 | if (T_m != DT_FLOAT && T_m != DT_BFLOAT16) return n; |
1267 | |
1268 | if (m->type_string() == csinfo_.bias_add_grad) { |
1269 | // Get 1st input 'g' of BiasAddGrad. |
1270 | Node* g = nullptr; |
1271 | TF_CHECK_OK(m->input_node(0, &g)); |
1272 | // Now traverse all outgoing edges from g that have destination node as |
1273 | // Conv2DBackpropFilter. |
1274 | for (const Edge* e : g->out_edges()) { |
1275 | if (!e->IsControlEdge() && |
1276 | e->dst()->type_string() == csinfo_.conv2d_grad_filter && |
1277 | e->dst_input() == 2 /* 3rd input of BackpropFilter */) { |
1278 | n = e->dst(); |
1279 | conv2d_backprop_filter = n; |
1280 | break; |
1281 | } |
1282 | } |
1283 | } else { |
1284 | conv2d_backprop_filter = m; |
1285 | CHECK_EQ(m->type_string(), csinfo_.conv2d_grad_filter); |
1286 | // Get 3rd input 'g' of Conv2DBackpropFilter. |
1287 | Node* g = nullptr; |
1288 | TF_CHECK_OK(m->input_node(2, &g)); |
1289 | // Now traverse all outgoing edges from g that have destination node as |
1290 | // BiasAddGrad. |
1291 | for (const Edge* e : g->out_edges()) { |
1292 | if (!e->IsControlEdge() && |
1293 | e->dst()->type_string() == csinfo_.bias_add_grad && |
1294 | e->dst_input() == 0 /* 1st input of BiasAddGrad */) { |
1295 | n = e->dst(); |
1296 | break; |
1297 | } |
1298 | } |
1299 | } |
1300 | |
1301 | // Do not merge if padding type is EXPLICIT. |
1302 | // TODO(intel): Support `EXPLICIT` padding for MklConv2DBackpropFilter. |
1303 | if (conv2d_backprop_filter != nullptr) { |
1304 | string padding; |
1305 | TF_CHECK_OK( |
1306 | GetNodeAttr(conv2d_backprop_filter->def(), "padding" , &padding)); |
1307 | if (padding == "EXPLICIT" ) { |
1308 | // Then do not merge. |
1309 | VLOG(1) << "MklLayoutRewritePass: Could match Conv2DBackpropFilter " |
1310 | << "and BiasAddGrad nodes but cannot merge them. " |
1311 | << "EXPLICIT padding is not supported now. " |
1312 | << conv2d_backprop_filter->DebugString(); |
1313 | return nullptr; |
1314 | } |
1315 | } |
1316 | |
1317 | if (n == nullptr) { |
1318 | VLOG(1) << "MklLayoutRewritePass: Could not find matching " |
1319 | << "Conv2DBackpropFilter and BiasAddGrad node for merging. " |
1320 | << "Input node: " << m->DebugString(); |
1321 | } |
1322 | return n; |
1323 | } |
1324 | |
// Look for nodes that can be fused with input node 'n'.
//
// @return tuple. If we can find such nodes, the first
//         element of the tuple is true. Otherwise, it's false.
//         NOTE(review): the second and third elements presumably carry the
//         matched nodes and the matching FusionInfo - confirm against the
//         definition.
std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo>
CheckForNodeFusion(Node* n) const;
1331 | |
// Fuse nodes in the vector "nodes"
//
// @input g - graph holding the nodes, nodes - nodes to be fused,
//        fi - fusion info describing the pattern
// @return Status reporting the outcome of the fusion.
Status FuseNode(std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
                const MklLayoutRewritePass::FusionInfo fi);
1335 | |
// Fuse transpose(to "NHWC") + mklop("NHWC") + transpose(to "NCHW") into
// mklop("NCHW").
// Here "mklop" can be any MKL-DNN supported op, such as Conv2D.
//
// @input g - graph holding the nodes, nodes - nodes of the matched pattern,
//        copy_attrs - attribute-copy function, data_format - data format
//        string (NOTE(review): presumably the format of the fused op -
//        confirm against the definition)
// @return Status reporting the outcome of the fusion.
static Status FuseTransposeMklOpTranspose(
    std::unique_ptr<Graph>* g, std::vector<Node*>& nodes,
    std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs,
    string data_format);
1343 | |
1344 | static bool CheckForTranspose(const Node* node, std::vector<int> perm) { |
1345 | // Check if node's type is "Transpose" |
1346 | if (node->type_string() != "Transpose" ) return false; |
1347 | |
1348 | // If "Transpose" has multiple output data edges, also don't fuse it. |
1349 | if (node->num_outputs() > 1 || node->out_edges().size() > 1) return false; |
1350 | |
1351 | // Check if has out control edge. If true, this is a training graph. |
1352 | // Currently we focus on inference and do no fusion in training. |
1353 | // Note: this constraint will eventually be removed, if we enabled this |
1354 | // fusion for training |
1355 | // in the future. |
1356 | for (const Edge* e : node->out_edges()) { |
1357 | if (e->IsControlEdge()) { |
1358 | return false; |
1359 | } |
1360 | } |
1361 | |
1362 | // If "Transpose" has input control edges, don't fuse on it. |
1363 | for (const Edge* e : node->in_edges()) { |
1364 | if (e->IsControlEdge()) { |
1365 | return false; |
1366 | } |
1367 | } |
1368 | |
1369 | // We compared the tensor containing the permutation order ("perm_node") |
1370 | // with our desired order ("perm"). If they're exactly match, this check |
1371 | // succeed and returns true. |
1372 | for (const Edge* e : node->in_edges()) { |
1373 | if (!e->IsControlEdge()) { |
1374 | const Node* perm_node = e->src(); |
1375 | |
1376 | const int kPermTensorIndex = 1; |
1377 | if (perm_node->type_string() == "Const" && |
1378 | e->dst_input() == kPermTensorIndex) { |
1379 | // we find the "perm" node, now try to retrieve its value. |
1380 | const TensorProto* proto = nullptr; |
1381 | TF_CHECK_OK(GetNodeAttr(perm_node->def(), "value" , &proto)); |
1382 | |
1383 | DataType type; |
1384 | TF_CHECK_OK(GetNodeAttr(perm_node->def(), "dtype" , &type)); |
1385 | |
1386 | Tensor tensor; |
1387 | if (!tensor.FromProto(*proto)) { |
1388 | TF_CHECK_OK(errors::InvalidArgument( |
1389 | "Could not construct Tensor from TensorProto in node: " , |
1390 | node->name())); |
1391 | return false; |
1392 | } |
1393 | // Current fusion only supports 4D or 5D tensors according to `perm` |
1394 | // vector, return false otherwise. |
1395 | if (tensor.dim_size(0) != perm.size()) return false; |
1396 | DCHECK_EQ(tensor.dims(), 1); |
1397 | if (type == DT_INT32) { |
1398 | const auto tensor_content = tensor.flat<int>().data(); |
1399 | for (int i = 0; i < perm.size(); ++i) |
1400 | if (tensor_content[i] != perm[i]) return false; |
1401 | return true; |
1402 | } else if (type == DT_INT64) { |
1403 | const auto tensor_content = tensor.flat<int64_t>().data(); |
1404 | for (int i = 0; i < perm.size(); ++i) |
1405 | if (tensor_content[i] != perm[i]) return false; |
1406 | return true; |
1407 | } |
1408 | return false; |
1409 | } |
1410 | } |
1411 | } |
1412 | return false; |
1413 | } |
1414 | |
1415 | static bool CheckForMklOp(const Node* node, string name = "" ) { |
1416 | if (node == nullptr) return false; |
1417 | |
1418 | if (!name.empty() && node->type_string() != name) { |
1419 | return false; |
1420 | } |
1421 | |
1422 | // if mklop has multiple outputs, don't fuse it. |
1423 | if (node->num_outputs() > 1) return false; |
1424 | |
1425 | if (node->out_edges().size() > 1) return false; |
1426 | |
1427 | DataType T; |
1428 | TF_CHECK_OK(GetNodeAttr(node->def(), "T" , &T)); |
1429 | return mkl_op_registry::IsMklOp( |
1430 | mkl_op_registry::GetMklOpName(node->type_string()), T); |
1431 | } |
1432 | |
1433 | // Check if the node 'n' has any applicable rewrite rule |
1434 | // We check for 2 scenarios for rewrite. |
1435 | // |
1436 | // @return RewriteInfo* for the applicable rewrite rule |
1437 | const RewriteInfo* CheckForNodeRewrite(const Node* n) const; |
1438 | const RewriteInfo* CheckForQuantizedNodeRewrite(const Node* n) const; |
1439 | |
1440 | // Default rewrite rule to be used in scenario 1 for rewrite. |
1441 | // @return - true (since we want to always rewrite) |
1442 | static bool AlwaysRewrite(const Node* n) { return true; } |
1443 | |
1444 | // Rewrite rule which considers "context" of the current node to decide if we |
1445 | // should rewrite. By "context" we currently mean all the inputs of current |
1446 | // node. The idea is if none of the inputs of current node are not MKL nodes, |
1447 | // then rewriting current node to MKL node _may not_ offer any performance |
1448 | // improvement. |
1449 | // |
1450 | // One such case is element-wise ops. For such ops, we reuse the Eigen |
1451 | // implementation and pass the MKL metadata tensor through so we can avoid |
1452 | // conversions. However, if all incoming edges are in TF format, we don't |
1453 | // need all this overhead, so replace the elementwise node only if at least |
1454 | // one of its parents is a MKL node. |
1455 | // |
1456 | // More generally, all memory- or IO-bound ops (such as Identity) may fall |
1457 | // under this category. |
1458 | // |
1459 | // @input - Input graph node to be rewritten |
1460 | // @return - true if node is to be rewritten as MKL node; false otherwise. |
1461 | static bool RewriteIfAtleastOneMklInput(const Node* n) { |
1462 | DataType T; |
1463 | if (GetNodeAttr(n->def(), "T" , &T).ok() && |
1464 | mkl_op_registry::IsMklOp( |
1465 | mkl_op_registry::GetMklOpName(n->type_string()), T)) { |
1466 | for (auto e : n->in_edges()) { |
1467 | if (e->IsControlEdge()) continue; |
1468 | if (mkl_op_registry::IsMklOp(e->src())) { |
1469 | return true; |
1470 | } |
1471 | } |
1472 | } |
1473 | return false; |
1474 | } |
1475 | |
1476 | static bool MatMulRewrite(const Node* n) { |
1477 | DataType T; |
1478 | TF_CHECK_OK(GetNodeAttr(n->def(), "T" , &T)); |
1479 | if ((T == DT_FLOAT) || (T == DT_BFLOAT16)) { |
1480 | VLOG(2) << "Rewriting MatMul to _MklMatMul" ; |
1481 | return true; |
1482 | } |
1483 | return false; |
1484 | } |
1485 | // For oneDNN, only int32 is supported for axis data type |
1486 | static bool ConcatV2Rewrite(const Node* n) { |
1487 | DataType T; |
1488 | TF_CHECK_OK(GetNodeAttr(n->def(), "Tidx" , &T)); |
1489 | return (T == DT_INT32); |
1490 | } |
1491 | |
1492 | static bool DequantizeRewrite(const Node* n) { |
1493 | DCHECK(n); |
1494 | Node* input = nullptr; |
1495 | TF_CHECK_OK(n->input_node(0, &input)); |
1496 | string mode_string; |
1497 | int axis = -1; |
1498 | TF_CHECK_OK(GetNodeAttr(n->def(), "mode" , &mode_string)); |
1499 | TF_CHECK_OK(GetNodeAttr(n->def(), "axis" , &axis)); |
1500 | if (mode_string != "SCALED" ) { |
1501 | VLOG(1) << "DequantizeRewrite: Mode is not SCALED. " |
1502 | << "This case is not optimized by Intel MKL kernel, thus using " |
1503 | "Eigen op for Dequantize op." ; |
1504 | return false; |
1505 | } |
1506 | if (input->IsConstant()) { |
1507 | VLOG(1) << "DequantizeRewrite: Trying to dequantize a Const node which " |
1508 | << "could possibly be a filter. " |
1509 | << "This case is not supported by Intel MKL kernel, thus using " |
1510 | "Eigen op for Dequantize op." ; |
1511 | return false; |
1512 | } |
1513 | |
1514 | if (axis != -1) { |
1515 | VLOG(1) << "DequantizeRewrite: Using Eigen op for Dequantize op because " |
1516 | << "dimension is specified for per slice dequantization. " ; |
1517 | return false; |
1518 | } |
1519 | return true; |
1520 | } |
1521 | |
1522 | // Rewrite rule for _FusedMatMul. |
1523 | // @return - true (no transpose attribute for input 1); |
1524 | // false otherwise. |
1525 | static bool FusedMatMulRewrite(const Node* n) { |
1526 | bool trans_a; |
1527 | |
1528 | // Do not rewrite with transpose attribute because reorder has performance |
1529 | // impact. |
1530 | TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a" , &trans_a)); |
1531 | |
1532 | return !trans_a; |
1533 | } |
1534 | |
1535 | // Check if we are performing pooling on depth or batch. If it is, then we |
1536 | // do not rewrite MaxPool node to Mkl version. |
1537 | // @return - true (if it is not a depth/batch wise pooling case); |
1538 | // false otherwise. |
1539 | static bool NonDepthBatchWisePoolRewrite(const Node* n) { |
1540 | DCHECK(n); |
1541 | |
1542 | string data_format_str; |
1543 | TensorFormat data_format; |
1544 | std::vector<int32> ksize, strides; |
1545 | TF_CHECK_OK(GetNodeAttr(n->def(), "ksize" , &ksize)); |
1546 | TF_CHECK_OK(GetNodeAttr(n->def(), "strides" , &strides)); |
1547 | TF_CHECK_OK(GetNodeAttr(n->def(), "data_format" , &data_format_str)); |
1548 | bool result = FormatFromString(data_format_str, &data_format); |
1549 | DCHECK(result); |
1550 | |
1551 | // Condition that specifies non-batch-wise and non-depth-wise pooling. |
1552 | if (GetTensorDim(ksize, data_format, 'N') == 1 && |
1553 | GetTensorDim(strides, data_format, 'N') == 1 && |
1554 | GetTensorDim(ksize, data_format, 'C') == 1 && |
1555 | GetTensorDim(strides, data_format, 'C') == 1) { |
1556 | return true; |
1557 | } |
1558 | |
1559 | return false; |
1560 | } |
1561 | |
1562 | // If the depth_radius of LRN is not 2, then MKL DNN takes unoptimized |
1563 | // path. The unoptimized path is slow. Thus we don't rewrite the node |
1564 | // and use default Eigen. But for depth_radius=2, MKL DNN optimized |
1565 | // path is taken, i.e., eigen node is rewritten by MKl DNN node. |
1566 | static bool LrnRewrite(const Node* n) { |
1567 | DCHECK(n); |
1568 | |
1569 | int depth_radius; |
1570 | TF_CHECK_OK(GetNodeAttr(n->def(), "depth_radius" , &depth_radius)); |
1571 | |
1572 | // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN |
1573 | // and use eigen node instead |
1574 | if (depth_radius == 2) { |
1575 | return true; |
1576 | } |
1577 | VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which" |
1578 | << "case is not optimized by Intel MKL, thus using Eigen op" |
1579 | << "for LRN " ; |
1580 | |
1581 | return false; |
1582 | } |
1583 | |
1584 | static bool LrnGradRewrite(const Node* n) { |
1585 | DCHECK(n); |
1586 | bool do_rewrite = false; |
1587 | |
1588 | for (const Edge* e : n->in_edges()) { |
1589 | // Rewrite only if there is corresponding LRN, i.e workspace is available |
1590 | if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 && |
1591 | e->src()->type_string() == |
1592 | mkl_op_registry::GetMklOpName(csinfo_.lrn) && |
1593 | e->src_output() == 0) { |
1594 | do_rewrite = true; |
1595 | break; |
1596 | } |
1597 | } |
1598 | return do_rewrite; |
1599 | } |
1600 | |
1601 | // MKL-DNN's LeakyRelu(feature) = feature (if feature > 0), or |
1602 | // feature * alpha (otherwise), |
1603 | // while TensorFlow's LeakyRelu(feature) = max(feature, feature * alpha). |
1604 | // These two algorithms are not consistent when alpha > 1, |
1605 | // so we only rewrite LeakyRelu to MKL OP when alpha <= 1. |
1606 | static bool LeakyReluRewrite(const Node* n) { |
1607 | DCHECK(n); |
1608 | |
1609 | float alpha; |
1610 | bool has_attr = TryGetNodeAttr(n->def(), "alpha" , &alpha); |
1611 | DCHECK(has_attr); |
1612 | |
1613 | // If the alpha of LeakyRelu is less than 1, rewrite the node. |
1614 | // Otherwise eigen node is used instead. |
1615 | if (alpha <= 1) { |
1616 | return true; |
1617 | } |
1618 | VLOG(1) << "LeakyReluRewrite: The model sets alpha is greater than 1 " |
1619 | << "which case is not optimized by Intel MKL, thus using Eigen op" |
1620 | << "for LeakyRelu " ; |
1621 | |
1622 | return false; |
1623 | } |
1624 | |
1625 | static bool QuantizeOpRewrite(const Node* n) { |
1626 | DCHECK(n); |
1627 | Node* filter_node = nullptr; |
1628 | TF_CHECK_OK(n->input_node(0, &filter_node)); |
1629 | bool narrow_range = false; |
1630 | int axis = -1; |
1631 | string mode_string; |
1632 | string round_mode_string; |
1633 | DataType type; |
1634 | TryGetNodeAttr(n->def(), "narrow_range" , &narrow_range); |
1635 | TryGetNodeAttr(n->def(), "axis" , &axis); |
1636 | TF_CHECK_OK(GetNodeAttr(n->def(), "mode" , &mode_string)); |
1637 | TF_CHECK_OK(GetNodeAttr(n->def(), "round_mode" , &round_mode_string)); |
1638 | TF_CHECK_OK(GetNodeAttr(n->def(), "T" , &type)); |
1639 | |
1640 | if (narrow_range) { |
1641 | VLOG(1) << "QuantizeOpRewrite: narrow range is enabled for quantization." |
1642 | << "This case is not optimized by Intel MKL, " |
1643 | << "thus using Eigen op for Quantize op " ; |
1644 | return false; |
1645 | } |
1646 | if (axis != -1) { |
1647 | VLOG(1) << "QuantizeOpRewrite: dimension is specified for " |
1648 | << "per slice quantization." |
1649 | << "This case is not optimized by Intel MKL, " |
1650 | << "thus using Eigen op for Quantize op " ; |
1651 | return false; |
1652 | } |
1653 | if (!((mode_string == "SCALED" && round_mode_string == "HALF_TO_EVEN" ) || |
1654 | (mode_string == "MIN_FIRST" ))) { |
1655 | VLOG(1) << "QuantizeOpRewrite: Mode is not SCALED or MIN_FIRST and/or" |
1656 | << "rounding mode is not HALF_TO_EVEN. " |
1657 | << "This case is not optimized by Intel MKL, thus using Eigen op" |
1658 | << "for Quantize op " ; |
1659 | return false; |
1660 | } |
1661 | if (filter_node->IsConstant()) { |
1662 | VLOG(1) << "QuantizeOpRewrite: Trying to quantize a node which " |
1663 | << "is a constant. " |
1664 | << "This case is not supported by the kernel, thus using Eigen op" |
1665 | << "for Quantize op " ; |
1666 | |
1667 | return false; |
1668 | } |
1669 | if (mode_string == "MIN_FIRST" ) { |
1670 | if (type != DT_QUINT8) { |
1671 | VLOG(1) << "QuantizeOpRewrite: For MIN_FIRST mode the data type is " |
1672 | << "not DT_UINT8. This case is not optimized by Intel MKL, " |
1673 | << "thus using Eigen op for Quantize op " ; |
1674 | return false; |
1675 | } |
1676 | } |
1677 | return true; |
1678 | } |
1679 | |
1680 | static bool MaxpoolGradRewrite(const Node* n) { |
1681 | DCHECK(n); |
1682 | bool do_rewrite = false; |
1683 | for (const Edge* e : n->in_edges()) { |
1684 | // Rewrite only if there is corresponding Maxpool, i.e workspace is |
1685 | // available |
1686 | if (e->dst()->type_string() == csinfo_.max_pool_grad && |
1687 | e->dst_input() == 1 && |
1688 | e->src()->type_string() == |
1689 | mkl_op_registry::GetMklOpName(csinfo_.max_pool) && |
1690 | e->src_output() == 0) { |
1691 | do_rewrite = true; |
1692 | break; |
1693 | } |
1694 | } |
1695 | return do_rewrite; |
1696 | } |
1697 | |
1698 | static bool Maxpool3DGradRewrite(const Node* n) { |
1699 | DCHECK(n); |
1700 | for (const Edge* e : n->in_edges()) { |
1701 | // Rewrite only if there is corresponding Maxpool3D, i.e., workspace is |
1702 | // available |
1703 | if (e->dst()->type_string() == csinfo_.max_pool3d_grad && |
1704 | e->dst_input() == 1 && |
1705 | e->src()->type_string() == |
1706 | mkl_op_registry::GetMklOpName(csinfo_.max_pool3d) && |
1707 | e->src_output() == 0) { |
1708 | return true; |
1709 | } |
1710 | } |
1711 | return false; |
1712 | } |
1713 | |
1714 | static bool FusedBatchNormV3Rewrite(const Node* n) { |
1715 | DCHECK(n); |
1716 | if (Check5DFormat(n->def())) { |
1717 | VLOG(1) << "Graph Rewrite: FusedBatchNorm(Grad)V3 op currently does not " |
1718 | << "support 5D tensors." ; |
1719 | return false; |
1720 | } |
1721 | return true; |
1722 | } |
1723 | |
1724 | static bool FusedBatchNormExRewrite(const Node* n) { |
1725 | DCHECK(n); |
1726 | |
1727 | int num_side_inputs; |
1728 | TF_CHECK_OK(GetNodeAttr(n->def(), "num_side_inputs" , &num_side_inputs)); |
1729 | string activation_mode; |
1730 | TF_CHECK_OK(GetNodeAttr(n->def(), "activation_mode" , &activation_mode)); |
1731 | |
1732 | // if the num_side_inputs is not 0, don't rewrite the node. |
1733 | if (num_side_inputs != 0) { |
1734 | VLOG(1) << "FusedBatchNormExRewrite: The model sets num_side_inputs" |
1735 | << "larger than 0 is not optimized by Intel MKL." ; |
1736 | return false; |
1737 | } |
1738 | |
1739 | // if the activation_mode is not 'Relu', don't rewrite the node. |
1740 | if (activation_mode != "Relu" ) { |
1741 | VLOG(1) << "FusedBatchNormExRewrite: Only Relu activation mode is" |
1742 | << "supported by Intel MKL." ; |
1743 | return false; |
1744 | } |
1745 | |
1746 | return true; |
1747 | } |
1748 | |
1749 | static bool FusedConv2DRewrite(const Node* n) { |
1750 | // MKL DNN currently doesn't support all fusions that grappler fuses |
1751 | // together with Conv2D (ex. batchnorm). We rewrite _FusedConv2D only if |
1752 | // it includes those we support. |
1753 | DataType T; |
1754 | if (!TryGetNodeAttr(n->def(), "T" , &T) || |
1755 | !mkl_op_registry::IsMklOp(NativeFormatEnabled() |
1756 | ? csinfo_.mkl_native_fused_conv2d |
1757 | : csinfo_.mkl_fused_conv2d, |
1758 | T)) { |
1759 | return false; |
1760 | } |
1761 | |
1762 | string data_format; |
1763 | string filter_format; |
1764 | int num_host_args = 0; |
1765 | TF_CHECK_OK(GetNodeAttr(n->def(), "data_format" , &data_format)); |
1766 | TF_CHECK_OK(GetNodeAttr(n->def(), "filter_format" , &filter_format)); |
1767 | TF_CHECK_OK(GetNodeAttr(n->def(), "num_host_args" , &num_host_args)); |
1768 | if ((data_format != "NCHW" && data_format != "NHWC" ) || |
1769 | (filter_format != "HWIO" && filter_format != "OIHW" ) || |
1770 | (num_host_args != 0)) { |
1771 | return false; |
1772 | } |
1773 | |
1774 | std::vector<string> fused_ops; |
1775 | TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops" , &fused_ops)); |
1776 | return (fused_ops == std::vector<string>{"BiasAdd" } || |
1777 | fused_ops == std::vector<string>{"Relu" } || |
1778 | fused_ops == std::vector<string>{"Relu6" } || |
1779 | fused_ops == std::vector<string>{"Elu" } || |
1780 | fused_ops == std::vector<string>{"BiasAdd" , "Relu" } || |
1781 | fused_ops == std::vector<string>{"BiasAdd" , "Relu6" } || |
1782 | fused_ops == std::vector<string>{"BiasAdd" , "Elu" } || |
1783 | fused_ops == std::vector<string>{"BiasAdd" , "Add" } || |
1784 | fused_ops == std::vector<string>{"BiasAdd" , "Add" , "Relu" } || |
1785 | fused_ops == std::vector<string>{"BiasAdd" , "Add" , "Relu6" } || |
1786 | fused_ops == std::vector<string>{"BiasAdd" , "Add" , "Elu" } || |
1787 | fused_ops == std::vector<string>{"LeakyRelu" } || |
1788 | fused_ops == std::vector<string>{"BiasAdd" , "LeakyRelu" } || |
1789 | fused_ops == std::vector<string>{"BiasAdd" , "Add" , "LeakyRelu" } || |
1790 | fused_ops == std::vector<string>{"FusedBatchNorm" } || |
1791 | fused_ops == std::vector<string>{"FusedBatchNorm" , "Relu" } || |
1792 | fused_ops == std::vector<string>{"FusedBatchNorm" , "Relu6" } || |
1793 | fused_ops == std::vector<string>{"FusedBatchNorm" , "Elu" } || |
1794 | fused_ops == std::vector<string>{"FusedBatchNorm" , "LeakyRelu" }); |
1795 | } |
1796 | |
1797 | static bool FusedDepthwiseConv2DRewrite(const Node* n) { |
1798 | // MKL DNN currently doesn't support all fusions that grappler fuses |
1799 | // together with DepthwiseConv2D (ex. batchnorm). We rewrite |
1800 | // _FusedDepthwiseConv2DNative only if it includes those we support. |
1801 | DataType T; |
1802 | if (!TryGetNodeAttr(n->def(), "T" , &T) || |
1803 | !mkl_op_registry::IsMklOp( |
1804 | NativeFormatEnabled() ? csinfo_.mkl_native_fused_depthwise_conv2d |
1805 | : csinfo_.mkl_fused_depthwise_conv2d, |
1806 | T)) { |
1807 | return false; |
1808 | } |
1809 | |
1810 | std::vector<string> fused_ops; |
1811 | TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops" , &fused_ops)); |
1812 | return (fused_ops == std::vector<string>{"BiasAdd" } || |
1813 | fused_ops == std::vector<string>{"BiasAdd" , "Relu" } || |
1814 | fused_ops == std::vector<string>{"BiasAdd" , "Relu6" } || |
1815 | fused_ops == std::vector<string>{"BiasAdd" , "Elu" }); |
1816 | } |
1817 | |
1818 | // Rewrites input node to a new node specified by its matching rewrite info. |
1819 | // |
1820 | // Method first searches matching rewrite info for input node and then |
1821 | // uses that info to rewrite. |
1822 | // |
1823 | // Input node may be deleted in case of rewrite. Attempt to use the node |
1824 | // after the call can result in undefined behaviors. |
1825 | // |
1826 | // @input g - input graph, n - Node to be rewritten, |
1827 | // ri - matching rewriteinfo |
1828 | // @return OkStatus(), if the input node is rewritten; |
1829 | // Returns appropriate Status error code otherwise. |
1830 | // Graph is updated in case the input node is rewritten. |
1831 | // Otherwise, it is not updated. |
1832 | Status RewriteNode(std::unique_ptr<Graph>* g, Node* n, const RewriteInfo* ri); |
1833 | |
1834 | // Rewrites input node to just change its operator name. The number of |
1835 | // inputs to the node and the number of outputs remain the same. Attributes |
1836 | // of the new node could be copied from attributes of the old node or |
1837 | // modified. copy_attrs field of RewriteInfo controls this. |
1838 | // |
1839 | // Conceptually, it allows us to rewrite: |
1840 | // |
1841 | // f[a=v1,b=v2](x,y) -> g[a'=v3,b'=v4](x,y) |
1842 | // |
1843 | // Attributes can be altered without any restrictions --- they could be |
1844 | // copied, modified, or deleted completely. |
1845 | // |
1846 | // @input g - input graph, orig_node - Node to be rewritten, |
1847 | // ri - matching rewriteinfo |
1848 | // @output new_node - points to newly created node |
1849 | // @return OkStatus(), if the input node is rewritten; |
1850 | // Returns appropriate Status error code otherwise. |
1851 | // Graph is only updated when the input node is rewritten. |
1852 | Status RewriteNodeForJustOpNameChange(std::unique_ptr<Graph>* g, |
1853 | const Node* orig_node, Node** new_node, |
1854 | const RewriteInfo* ri); |
1855 | |
1856 | // Rewrites input node to enable MKL layout propagation. Please also refer to |
1857 | // documentation for the function RewriteNodeForJustOpNameChange() to |
1858 | // understand what it means. |
1859 | // |
1860 | // @input g - input graph, orig_node - Node to be rewritten, |
1861 | // ri - matching rewriteinfo |
1862 | // @output new_node - points to newly created node |
1863 | // @return OkStatus(), if the input node is rewritten; |
1864 | // Returns appropriate Status error code otherwise. |
1865 | // Graph is updated in case the input node is rewritten. |
1866 | // Otherwise, it is not updated. |
1867 | Status RewriteNodeForLayoutPropagation(std::unique_ptr<Graph>* g, |
1868 | const Node* orig_node, Node** new_node, |
1869 | const RewriteInfo* ri); |
1870 | |
// Get nodes that will feed a list of TF tensors to the new
// node that we are constructing.
//
// @input inputs - inputs to old node that we are using for constructing
//                 new inputs,
// @input input_idx - the index in the 'inputs' vector pointing to the
//                    current input that we have processed so far
// @output input_idx - index will be incremented by the number of nodes
//                     from 'inputs' that are processed
// @input list_length - The expected length of list of TF tensors
// @output output_nodes - the list of new nodes creating TF tensors
//
// @return None
1885 | void GetNodesProducingTFTensorList( |
1886 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
1887 | int* input_idx, int list_length, |
1888 | std::vector<NodeBuilder::NodeOut>* output_nodes); |
1889 | |
1890 | // Get nodes that will feed a list of Mkl tensors to the new |
1891 | // node that we are constructing. |
1892 | // |
1893 | // @input g - input graph, |
1894 | // @input orig_node - Original node that we are rewriting |
1895 | // @input inputs - inputs to old node that we are using for constructing |
1896 | // new inputs, |
1897 | // @input input_idx - the index in the 'inputs' vector pointing to the |
1898 | // current input that we have processed so far |
1899 | // @output input_idx - index will be incremented by the number of nodes |
1900 | // from 'inputs' that are processed |
1901 | // @input list_length - The expected length of list of Mkl tensors |
1902 | // @output output_nodes - the list of new nodes creating Mkl tensors |
1903 | // |
1904 | // @return None |
1905 | void GetNodesProducingMklTensorList( |
1906 | std::unique_ptr<Graph>* g, const Node* orig_node, |
1907 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
1908 | int* input_idx, int list_length, |
1909 | std::vector<NodeBuilder::NodeOut>* output_nodes); |
1910 | |
1911 | // Get a node that will feed an Mkl tensor to the new |
1912 | // node that we are constructing. The output node could be (1) 'n' |
1913 | // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor |
1914 | // if 'n' is not an Mkl layer. |
1915 | // |
1916 | // @input g - input graph, |
1917 | // @input orig_node - Original node that we are rewriting, |
1918 | // @input n - Node based on which we are creating Mkl node, |
1919 | // @input n_output_slot - the output slot of node 'n' |
1920 | // which is feeding to the node that we are constructing |
1921 | // @output mkl_node - the new node that will feed Mkl tensor |
1922 | // @output mkl_node_output_slot - the slot number of mkl_node that |
1923 | // will feed the tensor |
1924 | // @return None |
1925 | void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, |
1926 | const Node* orig_node, Node* n, |
1927 | int n_output_slot, Node** mkl_node, |
1928 | int* mkl_node_output_slot); |
1929 | |
1930 | // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
1931 | // in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are |
1932 | // set up in contiguous fashion. 'workspace_tensors' carry graph nodes |
1933 | // producing workspace edges if 'are_workspace_tensors_available' is true. |
1934 | // Otherwise, 'workspace_tensors' is empty vector. |
1935 | // |
1936 | // For details, refer to 'Ordering of inputs after rewriting' section in the |
1937 | // documentation above. |
1938 | // |
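// NOTE(review): the declaration below returns 'int', not Status, so the
// earlier "Returns OkStatus()" wording does not match the signature.
// Presumably the integer reports how many inputs were set up; confirm
// against the definition.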
1941 | int SetUpContiguousInputs( |
1942 | std::unique_ptr<Graph>* g, |
1943 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
1944 | NodeBuilder* nb, const Node* old_node, |
1945 | std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
1946 | bool are_workspace_tensors_available); |
1947 | |
1948 | // Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb' |
1949 | // in graph 'g'. Original node is input in 'orig_node'. |
1950 | // |
1951 | // For details, refer to 'Ordering of Tensorflow tensors and Mkl tensors' |
1952 | // section in the documentation above. |
1953 | // |
1954 | // Returns OkStatus() if setting up inputs is successful, otherwise |
1955 | // returns appropriate status code. |
1956 | Status SetUpInputs(std::unique_ptr<Graph>* g, |
1957 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
1958 | NodeBuilder* nb, const Node* orig_node); |
1959 | |
1960 | // Create new inputs by copying old inputs 'inputs' for the rewritten node |
1961 | // in 'nb' in graph 'g'. Original node is input in 'orig_node'. This is mostly |
1962 | // used in the context of rewrite for just operator name change in which |
1963 | // inputs of old operator and new operator are same. |
1964 | // |
1965 | // Returns OkStatus() if setting up inputs is successful, otherwise |
1966 | // returns appropriate status code. |
1967 | Status CopyInputs(const Node* orig_node, |
1968 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, |
1969 | NodeBuilder* nb); |
1970 | |
1971 | // Add workspace edge on the input or output side of Node 'orig_node' by using |
1972 | // NodeBuilder 'nb' for the new node provided. If 'orig_node' does not dictate |
1973 | // adding workspace edge then do not add it. Workspace Tensorflow and Mkl |
1974 | // tensors, if they need to be added, will be set into these tensors. |
1975 | // If we set workspace tensors, then are_ws_tensors_added should be true. |
1976 | void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, |
1977 | const Node* orig_node, NodeBuilder* nb, |
1978 | std::vector<NodeBuilder::NodeOut>* ws_tensors, |
1979 | bool* are_ws_tensors_added); |
1980 | |
1981 | // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge |
1982 | // pointed by 'e_metadata' corresponding to the data edge 'e_data' in graph |
1983 | // 'g'. Returns true if fixup was done; otherwise, it returns false. |
1984 | bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, const Edge* e_data, |
1985 | const Edge* e_metadata); |
1986 | |
1987 | // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly |
1988 | // connected? If not, then fix them. This is needed because a graph may have |
1989 | // some input Mkl metadata edges incorrectly setup after node merge and |
1990 | // rewrite passes. This could happen because GetReversePostOrder function may |
1991 | // not provide topologically sorted order if a graph contains cycles. The |
1992 | // function returns true if at least one Mkl metadata edge for node 'n' was |
1993 | // fixed. Otherwise, it returns false. |
1994 | // |
1995 | // Example: |
1996 | // |
1997 | // X = MklConv2D(_, _, _) |
1998 | // Y = MklConv2DWithBias(_, _, _, _, _, _) |
1999 | // Z = MklAdd(X, Y, DummyMklTensor, Y:1) |
2000 | // |
2001 | // For a graph such as shown above, note that 3rd argument of MklAdd contains |
2002 | // DummyMklTensor. Actually, it should be getting the Mkl metadata from |
2003 | // MklConv2D op (specifically, X:2). This incorrect plumbing could be possible |
2004 | // (although rare) if the Mkl NodeMerge + NodeRewrite passes visit Z before X |
2005 | // (possible if X, Y, Z are part of a loop.) This function fixes the Mkl |
2006 | // metadata edges only - it does not rewrite nodes nor does it modify the Mkl |
2007 | // data edges (1st and 2nd arguments of MklAdd). |
2008 | bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n); |
2009 | |
2010 | // Functions specific to operators to copy attributes |
2011 | // We need operator-specific function to copy attributes because the framework |
2012 | // does not provide any generic function for it. |
2013 | // NOTE: names are alphabetically sorted. |
2014 | static void CopyAttrsAll(const Node* orig_node, NodeBuilder* nb, |
2015 | bool change_format = false); |
2016 | static void CopyAttrsAllCheckConstFilter(const Node* orig_node, |
2017 | NodeBuilder* nb, |
2018 | bool change_format = false); |
2019 | |
2020 | static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb, |
2021 | bool change_format = false); |
2022 | static void CopyAttrsConvCheckConstFilter(const Node* orig_node, |
2023 | NodeBuilder* nb, |
2024 | bool change_format = false); |
2025 | static void CopyAttrsFusedConv2DCheckConstFilter(const Node* orig_node, |
2026 | NodeBuilder* nb, |
2027 | bool change_format = false); |
2028 | static void CopyAttrsFromPadAndConv2D(const Node* orig_node1, |
2029 | const Node* orig_node2, NodeBuilder* nb, |
2030 | bool change_format = false); |
2031 | static void CopyAttrsFromPadAndFusedConv2D(const Node* orig_node1, |
2032 | const Node* orig_node2, |
2033 | NodeBuilder* nb, |
2034 | bool change_format = false); |
2035 | static void CopyAttrsQuantizedConv2D(const Node* orig_node, NodeBuilder* nb, |
2036 | bool change_format = false); |
2037 | static void CopyFormatAttrsConv(const Node* orig_node, NodeBuilder* nb, |
2038 | const std::vector<int32>& strides, |
2039 | const std::vector<int32>& dilations, |
2040 | bool change_format = false); |
2041 | |
2042 | static void CopyAttrsQuantizedMatMulWithBias(const Node* orig_node, |
2043 | NodeBuilder* nb, |
2044 | bool change_format = false); |
2045 | static void CopyAttrsQuantizedMatMulWithBiasAndDequantize( |
2046 | const Node* orig_node, NodeBuilder* nb, bool change_format = false); |
2047 | static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb, |
2048 | bool change_format = false); |
2049 | |
2050 | // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, |
2051 | // using node for original node 'orig_node' and return it in '*out'. |
2052 | // TODO(nhasabni) We should move this to mkl_util.h |
2053 | void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out, |
2054 | const Node* orig_node); |
2055 | void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out, |
2056 | const Node* orig_node); |
2057 | }; |
2058 | |
// Out-of-class definition of the static member 'csinfo_', whose string
// fields (e.g. 'csinfo_.lrn', 'csinfo_.max_pool_grad') are compared against
// node op types by the rewrite rules of this pass.
MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;

// We register Mkl rewrite pass for phase 1 in post partitioning group.
// We register it here so that we get a complete picture of all users of Mkl
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
    OptimizationPassRegistry::POST_PARTITIONING;
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
2067 | |
2068 | ////////////////////////////////////////////////////////////////////////// |
2069 | // Helper functions for creating new node |
2070 | ////////////////////////////////////////////////////////////////////////// |
2071 | |
2072 | static void FillInputs(const Node* n, |
2073 | gtl::InlinedVector<Node*, 4>* control_edges, |
2074 | gtl::InlinedVector<std::pair<Node*, int>, 4>* in) { |
2075 | control_edges->clear(); |
2076 | for (const Edge* e : n->in_edges()) { |
2077 | if (e->IsControlEdge()) { |
2078 | control_edges->push_back(e->src()); |
2079 | } else { |
2080 | (*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output()); |
2081 | } |
2082 | } |
2083 | std::sort(control_edges->begin(), control_edges->end()); |
2084 | } |
2085 | |
2086 | void MklLayoutRewritePass::GetNodesProducingTFTensorList( |
2087 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
2088 | int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
2089 | CHECK_LT(*input_idx, inputs.size()); |
2090 | CHECK_GT(list_length, 0); |
2091 | DCHECK(output_nodes); |
2092 | output_nodes->reserve(list_length); |
2093 | |
2094 | while (list_length != 0) { |
2095 | CHECK_GT(list_length, 0); |
2096 | CHECK_LT(*input_idx, inputs.size()); |
2097 | Node* n = inputs[*input_idx].first; |
2098 | int slot = inputs[*input_idx].second; |
2099 | // If input node 'n' is just producing a single tensor at |
2100 | // output slot 'slot' then we just add that single node. |
2101 | output_nodes->push_back(NodeBuilder::NodeOut(n, slot)); |
2102 | (*input_idx)++; |
2103 | list_length--; |
2104 | } |
2105 | } |
2106 | |
2107 | // TODO(nhasabni) We should move this to mkl_util.h. |
2108 | void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g, |
2109 | Node** out, |
2110 | const Node* orig_node) { |
2111 | // We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent |
2112 | // dummy Mkl tensor. 8 = 2*size_t. |
2113 | const DataType dt = DataTypeToEnum<uint8>::v(); |
2114 | TensorProto proto; |
2115 | proto.set_dtype(dt); |
2116 | uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0}; |
2117 | proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 8)); |
2118 | TensorShape dummy_shape({8}); |
2119 | dummy_shape.AsProto(proto.mutable_tensor_shape()); |
2120 | TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT" ), "Const" ) |
2121 | .Attr("value" , proto) |
2122 | .Attr("dtype" , dt) |
2123 | .Device(orig_node->def().device()) // We place this node on |
2124 | // the same device as the |
2125 | // device of the original |
2126 | // node. |
2127 | .Finalize(&**g, out)); |
2128 | DCHECK(*out); // Make sure we got a valid object before using it |
2129 | |
2130 | // If number of inputs to the original node is > 0, then we add |
2131 | // control dependency between 1st input (index 0) of the original node and |
2132 | // the dummy Mkl node. This is needed because control-flow ops such as Enter, |
2133 | // Merge, etc, require frame_name of the dummy Mkl node to be same as the |
2134 | // rewritten node. Adding control edge between 1st input of the original node |
2135 | // and the dummy Mkl node ensures that the dummy node is in the same frame |
2136 | // as the original node. Choosing 1st input is not necessary - any input of |
2137 | // the original node is fine because all the inputs of a node are always in |
2138 | // the same frame. |
2139 | if (orig_node->num_inputs() > 0) { |
2140 | Node* orig_input0 = nullptr; |
2141 | TF_CHECK_OK( |
2142 | orig_node->input_node(0, const_cast<const Node**>(&orig_input0))); |
2143 | auto edge = (*g)->AddControlEdge(orig_input0, *out, false); |
2144 | DCHECK(edge != nullptr || DoesControlEdgeExist(orig_input0, *out)); |
2145 | } |
2146 | |
2147 | (*out)->set_assigned_device_name(orig_node->assigned_device_name()); |
2148 | } |
2149 | |
2150 | void MklLayoutRewritePass::GetNodesProducingMklTensorList( |
2151 | std::unique_ptr<Graph>* g, const Node* orig_node, |
2152 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx, |
2153 | int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) { |
2154 | CHECK_LT(*input_idx, inputs.size()); |
2155 | CHECK_GT(list_length, 0); |
2156 | DCHECK(output_nodes); |
2157 | output_nodes->reserve(list_length); |
2158 | |
2159 | while (list_length != 0) { |
2160 | CHECK_GT(list_length, 0); |
2161 | CHECK_LT(*input_idx, inputs.size()); |
2162 | Node* n = inputs[*input_idx].first; |
2163 | int slot = inputs[*input_idx].second; |
2164 | // If 'n' is producing a single tensor, then create a single Mkl tensor |
2165 | // node. |
2166 | Node* mkl_node = nullptr; |
2167 | int mkl_node_output_slot = 0; |
2168 | GetNodeProducingMklTensor(g, orig_node, n, slot, &mkl_node, |
2169 | &mkl_node_output_slot); |
2170 | output_nodes->push_back( |
2171 | NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot)); |
2172 | (*input_idx)++; |
2173 | list_length--; |
2174 | } |
2175 | } |
2176 | |
2177 | // Get an input node that will feed Mkl tensor to the new |
2178 | // node that we are constructing. An input node could be (1) 'n' |
2179 | // if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor |
2180 | // if 'n' is not an Mkl layer. |
2181 | void MklLayoutRewritePass::GetNodeProducingMklTensor( |
2182 | std::unique_ptr<Graph>* g, const Node* orig_node, Node* n, |
2183 | int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) { |
2184 | DCHECK(n); |
2185 | DCHECK(mkl_node); |
2186 | DCHECK(mkl_node_output_slot); |
2187 | |
2188 | // If this is an MKL op, then it will create extra output for MKL layout. |
2189 | DataType T; |
2190 | if (TryGetNodeAttr(n->def(), "T" , &T) && |
2191 | mkl_op_registry::IsMklOp(n->type_string(), T, false)) { |
2192 | // If this is an MKL op, then it will generate an edge that will receive |
2193 | // Mkl tensor from a node. |
2194 | // output slot number for Mkl tensor would be N+slot number of TensorFlow |
2195 | // tensor, where N is total number of TensorFlow tensors. |
2196 | *mkl_node = n; |
2197 | *mkl_node_output_slot = |
2198 | GetTensorMetaDataIndex(n_output_slot, n->num_outputs()); |
2199 | } else { |
2200 | // If we have not visited the node and rewritten it, then we need |
2201 | // to create a dummy node that will feed a dummy Mkl tensor to this node. |
2202 | // DummyMklTensor node has no input and generates only 1 output |
2203 | // (dummy Mkl tensor) as output slot number 0. |
2204 | GetDummyMklTensorNode(g, mkl_node, orig_node); |
2205 | DCHECK(*mkl_node); |
2206 | *mkl_node_output_slot = 0; |
2207 | } |
2208 | } |
2209 | |
2210 | int MklLayoutRewritePass::SetUpContiguousInputs( |
2211 | std::unique_ptr<Graph>* g, |
2212 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
2213 | NodeBuilder* nb, const Node* old_node, |
2214 | std::vector<NodeBuilder::NodeOut>* workspace_tensors, |
2215 | bool are_workspace_tensors_available) { |
2216 | DCHECK(workspace_tensors); |
2217 | CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
2218 | |
2219 | // TODO(nhasabni): Temporary solution to connect filter input of |
2220 | // BackpropInput with the converted filter from Conv2D. |
2221 | bool do_connect_conv2d_backprop_input_filter = false; |
2222 | Node* conv2d_node = nullptr; |
2223 | // Filter node is 2nd input (slot index 1) of Conv2D. |
2224 | int kConv2DFilterInputSlotIdx = 1; |
2225 | int kConv2DBackpropInputFilterInputSlotIdx = 1; |
2226 | int kConv2DFilterOutputSlotIdx = 1; |
2227 | if (old_node->type_string() == csinfo_.conv2d_grad_input) { |
2228 | // We need to find Conv2D node from Conv2DBackpropInput. |
2229 | // For that let's first find filter node that is 2nd input (slot 1) |
2230 | // of BackpropInput. |
2231 | Node* filter_node = nullptr; |
2232 | TF_CHECK_OK(old_node->input_node(kConv2DBackpropInputFilterInputSlotIdx, |
2233 | &filter_node)); |
2234 | DCHECK(filter_node); |
2235 | |
2236 | // Now check which nodes receive from filter_node. Filter feeds as |
2237 | // 2nd input (slot 1) of _MklConv2D, _MklConv2DWithBias, and |
2238 | // _MklFusedConv2D. |
2239 | for (const Edge* e : filter_node->out_edges()) { |
2240 | if ((e->dst()->type_string() == csinfo_.mkl_conv2d || |
2241 | e->dst()->type_string() == csinfo_.mkl_pad_with_conv2d || |
2242 | e->dst()->type_string() == csinfo_.mkl_pad_with_fused_conv2d || |
2243 | e->dst()->type_string() == csinfo_.mkl_conv2d_with_bias || |
2244 | e->dst()->type_string() == csinfo_.mkl_fused_conv2d) && |
2245 | e->dst_input() == kConv2DFilterInputSlotIdx |
2246 | /* filter is 2nd input of Conv2D and _MklConv2D. */) { |
2247 | if (conv2d_node != nullptr) { |
2248 | VLOG(1) << "MklLayoutRewritePass: unusual case of same filter" |
2249 | << " feeding multiple Conv2D nodes: " |
2250 | << filter_node->DebugString(); |
2251 | // We will not connect filter input of Conv2DBackpropInput |
2252 | // to be safe here. |
2253 | do_connect_conv2d_backprop_input_filter = false; |
2254 | break; |
2255 | } else { |
2256 | conv2d_node = e->dst(); |
2257 | do_connect_conv2d_backprop_input_filter = true; |
2258 | } |
2259 | } |
2260 | } |
2261 | } |
2262 | |
2263 | // Number of input slots to original op |
2264 | // Input slots are represented by .Input() calls in REGISTER_OP. |
2265 | int old_node_input_slots = old_node->op_def().input_arg_size(); |
2266 | int nn_slot_idx = 0; // slot index for inputs of new node |
2267 | |
2268 | // Let's copy all inputs (TF tensors) of original node to new node. |
2269 | int iidx = 0; |
2270 | for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
2271 | // An input slot could be a single tensor or a list. We need |
2272 | // to handle this case accordingly. |
2273 | const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
2274 | if (ArgIsList(arg)) { |
2275 | std::vector<NodeBuilder::NodeOut> new_node_inputs; |
2276 | int tensor_list_length = GetTensorListLength(arg, old_node); |
2277 | if (tensor_list_length != 0) { |
2278 | GetNodesProducingTFTensorList(old_node_inputs, &iidx, |
2279 | tensor_list_length, &new_node_inputs); |
2280 | } |
2281 | nb->Input(new_node_inputs); |
2282 | nn_slot_idx++; |
2283 | } else { |
2284 | // Special case for connecting filter input of Conv2DBackpropInput |
2285 | if (do_connect_conv2d_backprop_input_filter && |
2286 | iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
2287 | nb->Input(conv2d_node, kConv2DFilterOutputSlotIdx); |
2288 | } else { |
2289 | nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); |
2290 | } |
2291 | iidx++; |
2292 | nn_slot_idx++; |
2293 | } |
2294 | } |
2295 | |
2296 | // If workspace tensors are available for this op and we are using |
2297 | // contiguous ordering then we need to add Tensorflow tensor for |
2298 | // workspace here because Tensorflow tensor for workspace is the |
2299 | // last tensor in the list of Tensorflow tensors. |
2300 | if (are_workspace_tensors_available) { |
2301 | CHECK_EQ(workspace_tensors->size(), 2); |
2302 | // Tensorflow tensor |
2303 | nb->Input((*workspace_tensors)[0].node, (*workspace_tensors)[0].index); |
2304 | nn_slot_idx++; |
2305 | } |
2306 | |
2307 | // Let's now setup all Mkl inputs to a new node. |
2308 | // Number of Mkl inputs must be same as number of TF inputs. |
2309 | iidx = 0; |
2310 | for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
2311 | // An input slot could be a single tensor or a list. We need |
2312 | // to handle this case accordingly. |
2313 | const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
2314 | if (ArgIsList(arg)) { |
2315 | std::vector<NodeBuilder::NodeOut> new_node_inputs; |
2316 | int tensor_list_length = GetTensorListLength(arg, old_node); |
2317 | if (tensor_list_length != 0) { |
2318 | GetNodesProducingMklTensorList(g, old_node, old_node_inputs, &iidx, |
2319 | tensor_list_length, &new_node_inputs); |
2320 | } |
2321 | nb->Input(new_node_inputs); |
2322 | nn_slot_idx++; |
2323 | } else { |
2324 | Node* mkl_node = nullptr; |
2325 | int mkl_node_output_slot = 0; |
2326 | // Special case for connecting filter input of Conv2DBackpropInput |
2327 | if (do_connect_conv2d_backprop_input_filter && |
2328 | iidx == kConv2DBackpropInputFilterInputSlotIdx) { |
2329 | GetNodeProducingMklTensor(g, old_node, conv2d_node, |
2330 | kConv2DFilterOutputSlotIdx, &mkl_node, |
2331 | &mkl_node_output_slot); |
2332 | } else { |
2333 | GetNodeProducingMklTensor(g, old_node, old_node_inputs[iidx].first, |
2334 | old_node_inputs[iidx].second, &mkl_node, |
2335 | &mkl_node_output_slot); |
2336 | } |
2337 | nb->Input(mkl_node, mkl_node_output_slot); |
2338 | iidx++; |
2339 | nn_slot_idx++; |
2340 | } |
2341 | } |
2342 | |
2343 | // If workspace tensors are available for this op and we are using |
2344 | // contiguous ordering then we need to add Mkl tensor for |
2345 | // workspace here because Mkl tensor for workspace is the |
2346 | // last tensor in the list of Mkl tensors. |
2347 | if (are_workspace_tensors_available) { |
2348 | CHECK_EQ(workspace_tensors->size(), 2); |
2349 | // Mkl tensor |
2350 | nb->Input((*workspace_tensors)[1].node, (*workspace_tensors)[1].index); |
2351 | nn_slot_idx++; |
2352 | } |
2353 | |
2354 | return nn_slot_idx; |
2355 | } |
2356 | |
2357 | // This method finds out if checking workspace is needed or not. Workspace is |
2358 | // not used in quantized ops, so checking that would fail as quantized ops |
2359 | // don't have attribute: "T". |
2360 | bool IsWorkspaceCheckNeeded(const Node* node) { |
2361 | std::vector<string> quant_ops{ |
2362 | "Dequantize" , |
2363 | "QuantizeV2" , |
2364 | "QuantizedConv2D" , |
2365 | "QuantizedConv2DWithBias" , |
2366 | "QuantizedConv2DAndRelu" , |
2367 | "QuantizedConv2DWithBiasAndRelu" , |
2368 | "QuantizedConv2DWithBiasSumAndRelu" , |
2369 | "QuantizedConv2DPerChannel" , |
2370 | "QuantizedConv2DAndRequantize" , |
2371 | "QuantizedConv2DWithBiasAndRequantize" , |
2372 | "QuantizedConv2DAndReluAndRequantize" , |
2373 | "QuantizedConv2DWithBiasAndReluAndRequantize" , |
2374 | "QuantizedConv2DWithBiasSumAndReluAndRequantize" , |
2375 | "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" , |
2376 | "QuantizedMatMulWithBias" , |
2377 | "QuantizedMatMulWithBiasAndRequantize" , |
2378 | "QuantizedMatMulWithBiasAndDequantize" , |
2379 | "QuantizedMatMulWithBiasAndRelu" , |
2380 | "QuantizedMatMulWithBiasAndReluAndRequantize" , |
2381 | "QuantizedDepthwiseConv2D" , |
2382 | "QuantizedDepthwiseConv2DWithBias" , |
2383 | "QuantizedDepthwiseConv2DWithBiasAndRelu" , |
2384 | "QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" }; |
2385 | return std::find(std::begin(quant_ops), std::end(quant_ops), |
2386 | node->type_string()) == std::end(quant_ops); |
2387 | } |
2388 | |
2389 | Status MklLayoutRewritePass::SetUpInputs( |
2390 | std::unique_ptr<Graph>* g, |
2391 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
2392 | NodeBuilder* nb, const Node* old_node) { |
2393 | // Let's check if we need to add workspace tensors for this node. |
2394 | // We add workspace edge only for MaxPool, LRN and BatchNorm. |
2395 | std::vector<NodeBuilder::NodeOut> workspace_tensors; |
2396 | bool are_workspace_tensors_available = false; |
2397 | |
2398 | if (IsWorkspaceCheckNeeded(old_node)) { |
2399 | AddWorkSpaceEdgeIfNeeded(g, old_node, nb, &workspace_tensors, |
2400 | &are_workspace_tensors_available); |
2401 | } |
2402 | |
2403 | int new_node_input_slots = 0; |
2404 | if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) { |
2405 | // TODO(nhasabni): implement this function just for same of completion. |
2406 | // We do not use interleaved ordering right now. |
2407 | return Status( |
2408 | error::Code::UNIMPLEMENTED, |
2409 | "Interleaved ordering of tensors is currently not supported." ); |
2410 | } else { |
2411 | CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS); |
2412 | new_node_input_slots = SetUpContiguousInputs( |
2413 | g, old_node_inputs, nb, old_node, &workspace_tensors, |
2414 | are_workspace_tensors_available); |
2415 | } |
2416 | |
2417 | // Sanity check |
2418 | int old_node_input_slots = old_node->op_def().input_arg_size(); |
2419 | if (!are_workspace_tensors_available) { |
2420 | // If we are not adding workspace tensors for this op, then the total |
2421 | // number of input slots to the new node _must_ be 2 times the number |
2422 | // of input slots to the original node: N original Tensorflow tensors and |
2423 | // N for Mkl tensors corresponding to each Tensorflow tensors. |
2424 | CHECK_EQ(new_node_input_slots, old_node_input_slots * 2); |
2425 | } else { |
2426 | // If we are adding workspace tensors for this op, then the total |
2427 | // The total number of input slots to new node _must_ be 2 times the number |
2428 | // of input slots to the original node: N original Tensorflow tensors and |
2429 | // N for Mkl tensors corresponding to each Tensorflow tensors plus 2 |
2430 | // (for workspace Tensorflow tensor and workspace Mkl tensor). |
2431 | CHECK_EQ(new_node_input_slots, old_node_input_slots * 2 + 2); |
2432 | } |
2433 | |
2434 | return OkStatus(); |
2435 | } |
2436 | |
2437 | Status MklLayoutRewritePass::CopyInputs( |
2438 | const Node* old_node, |
2439 | const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs, |
2440 | NodeBuilder* nb) { |
2441 | // Number of input slots to old node |
2442 | // Input slots are represented by .Input() calls in REGISTER_OP. |
2443 | int old_node_input_slots = old_node->op_def().input_arg_size(); |
2444 | // Actual number of inputs can be greater than or equal to number |
2445 | // of Input slots because inputs of type list could be unfolded. |
2446 | auto old_node_input_size = old_node_inputs.size(); |
2447 | |
2448 | if (old_node->type_string() == "_FusedConv2D" ) { |
2449 | // [TODO(intel-tf)] |
2450 | // commit 5be9a5 updates _FusedConv2D with additional host_args, |
2451 | // but mkl version currently doesn't have this input arg, needs to |
2452 | // remove this extra input when replace node with mkl node. |
2453 | old_node_input_slots--; |
2454 | } |
2455 | |
2456 | DCHECK_GE(old_node_input_size, old_node_input_slots); |
2457 | |
2458 | // Let's copy all inputs of old node to new node. |
2459 | int iidx = 0; |
2460 | for (int on_slot_idx = 0; on_slot_idx < old_node_input_slots; on_slot_idx++) { |
2461 | // An input slot could be a single tensor or a list. We need |
2462 | // to handle this case accordingly. |
2463 | DCHECK_LT(iidx, old_node_input_size); |
2464 | const OpDef::ArgDef& arg = old_node->op_def().input_arg(on_slot_idx); |
2465 | if (ArgIsList(arg)) { |
2466 | std::vector<NodeBuilder::NodeOut> new_node_inputs; |
2467 | int N = GetTensorListLength(arg, old_node); |
2468 | if (N != 0) { |
2469 | GetNodesProducingTFTensorList(old_node_inputs, &iidx, N, |
2470 | &new_node_inputs); |
2471 | } |
2472 | nb->Input(new_node_inputs); |
2473 | } else { |
2474 | nb->Input(old_node_inputs[iidx].first, old_node_inputs[iidx].second); |
2475 | iidx++; |
2476 | } |
2477 | } |
2478 | return OkStatus(); |
2479 | } |
2480 | |
2481 | ////////////////////////////////////////////////////////////////////////// |
2482 | // Helper functions related to workspace pass |
2483 | ////////////////////////////////////////////////////////////////////////// |
2484 | |
2485 | // TODO(nhasabni) We should move this to mkl_util.h. |
2486 | void MklLayoutRewritePass::GetDummyWorkspaceTensorNode( |
2487 | std::unique_ptr<Graph>* g, Node** out, const Node* orig_node) { |
2488 | // We use uint8 tensor of shape 8 with content {0,0,0,0,0,0,0,0} to represent |
2489 | // workspace tensor. |
2490 | GetDummyMklTensorNode(g, out, orig_node); |
2491 | } |
2492 | |
2493 | void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded( |
2494 | std::unique_ptr<Graph>* g, const Node* orig_node, NodeBuilder* nb, |
2495 | std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) { |
2496 | bool workspace_edge_added = false; // Default initializer |
2497 | DCHECK(are_ws_tensors_added); |
2498 | *are_ws_tensors_added = false; // Default initializer |
2499 | |
2500 | DataType T; |
2501 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T" , &T)); |
2502 | for (auto ws : wsinfo_) { |
2503 | if (orig_node->type_string() == ws.fwd_op && |
2504 | mkl_op_registry::IsMklOp( |
2505 | mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) { |
2506 | // If this op is a fwd op, then we need to check if there is an |
2507 | // edge from this node's fwd_slot to bwdop's bwd_slot. If there is |
2508 | // an edge, then we just add an attribute on this node for setting |
2509 | // workspace_passed to true. We don't add actual workspace edge |
2510 | // in this node. Actual workspace edge gets added in the backward |
2511 | // op for this node. |
2512 | for (const Edge* e : orig_node->out_edges()) { |
2513 | if (e->src_output() == ws.fwd_slot && |
2514 | e->dst()->type_string() == ws.bwd_op && |
2515 | e->dst_input() == ws.bwd_slot) { |
2516 | nb->Attr("workspace_enabled" , true); |
2517 | VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
2518 | << orig_node->type_string(); |
2519 | workspace_edge_added = true; |
2520 | // We found the edge that we were looking for, so break. |
2521 | break; |
2522 | } |
2523 | } |
2524 | |
2525 | if (!workspace_edge_added) { |
2526 | // If we are here, then we did not find backward operator for this |
2527 | // node. |
2528 | nb->Attr("workspace_enabled" , false); |
2529 | } |
2530 | } else if (orig_node->type_string() == ws.bwd_op && |
2531 | mkl_op_registry::IsMklOp( |
2532 | mkl_op_registry::GetMklOpName(orig_node->type_string()), |
2533 | T)) { |
2534 | // If this op is a bwd op, then we need to add workspace edge and |
2535 | // it's Mkl tensor edge between its corresponding fwd op and this |
2536 | // op. Corresponding fwd op is specified in 'fwd_op' field of |
2537 | // workspace info. fwd_slot and bwd_slot in workspace info specify |
2538 | // an edge between which slots connect forward and backward op. |
2539 | // Once all these criteria match, we add a workspace edge between |
2540 | // ws_fwd_slot and ws_bwd_slot. Its corresponding Mkl tensor is |
2541 | // determined by interleaved/contiguous ordering. Function |
2542 | // DataIndexToMetaDataIndex tells us the location of Mkl tensor |
2543 | // from the location of the Tensorflow tensor. |
2544 | for (const Edge* e : orig_node->in_edges()) { |
2545 | if (e->src_output() == ws.fwd_slot && |
2546 | // We would have rewritten the forward op, so we need to use |
2547 | // GetMklOpName call to get its Mkl name. |
2548 | e->src()->type_string() == |
2549 | mkl_op_registry::GetMklOpName(ws.fwd_op) && |
2550 | e->dst_input() == ws.bwd_slot) { |
2551 | nb->Attr("workspace_enabled" , true); |
2552 | DCHECK(ws_tensors); |
2553 | // Add workspace edge between fwd op and bwd op. |
2554 | ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot)); |
2555 | // Check if we are running in native format mode. If so, |
2556 | // we don't need to have an Mkl metadata tensor for the workspace. |
2557 | if (!NativeFormatEnabled()) { |
2558 | // Add Mkl tensor edge for workspace edge between fwd op and bwd op. |
2559 | ws_tensors->push_back(NodeBuilder::NodeOut( |
2560 | e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot, |
2561 | e->src()->num_outputs()))); |
2562 | } |
2563 | *are_ws_tensors_added = true; |
2564 | // In terms of input ordering, we add these calls to add Input |
2565 | // here because workspace edge (and its Mkl tensor) is the last |
2566 | // edge in the fwdop and bwdop. So all inputs before workspace |
2567 | // tensor have been added by SetUpInputs function. |
2568 | VLOG(1) << "MklLayoutRewritePass: workspace_enabled for " |
2569 | << orig_node->type_string(); |
2570 | workspace_edge_added = true; |
2571 | // We found the edge that we were looking for, so break. |
2572 | break; |
2573 | } |
2574 | } |
2575 | |
2576 | // If we are here means we did not find fwd op that feeds to this |
2577 | // bwd op. So in this case, we need to generate dummy tensors for |
2578 | // workspace input and Mkl tensor for workspace, and set |
2579 | // workspace_enabled to false. |
2580 | if (!workspace_edge_added) { |
2581 | nb->Attr("workspace_enabled" , false); |
2582 | Node* dmt_ws = nullptr; // Dummy tensor for workspace |
2583 | Node* dmt_mkl_ws = nullptr; // Dummy Mkl tensor for workspace |
2584 | GetDummyWorkspaceTensorNode(g, &dmt_ws, orig_node); |
2585 | GetDummyMklTensorNode(g, &dmt_mkl_ws, orig_node); |
2586 | DCHECK(dmt_ws); |
2587 | DCHECK(dmt_mkl_ws); |
2588 | DCHECK(ws_tensors); |
2589 | // We add dummy tensor as workspace tensor. |
2590 | ws_tensors->push_back(NodeBuilder::NodeOut(dmt_ws, 0)); |
2591 | // We add dummy tensor as Mkl tensor for workspace tensor. |
2592 | ws_tensors->push_back(NodeBuilder::NodeOut(dmt_mkl_ws, 0)); |
2593 | *are_ws_tensors_added = true; |
2594 | VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for " |
2595 | << orig_node->type_string(); |
2596 | } |
2597 | } else { |
2598 | // If this node does not match any workspace info, then we do not |
2599 | // do anything special for workspace propagation for it. |
2600 | } |
2601 | } |
2602 | } |
2603 | |
2604 | ////////////////////////////////////////////////////////////////////////// |
2605 | // Op-specific functions to copy attributes from old node to new node |
2606 | ////////////////////////////////////////////////////////////////////////// |
2607 | |
2608 | // Generic function to copy all attributes from original node to target. |
2609 | void MklLayoutRewritePass::CopyAttrsAll(const Node* orig_node, NodeBuilder* nb, |
2610 | bool change_format) { |
2611 | string name; |
2612 | AttrSlice attr_list(orig_node->def()); |
2613 | |
2614 | auto iter = attr_list.begin(); |
2615 | while (iter != attr_list.end()) { |
2616 | name = iter->first; |
2617 | auto attr = iter->second; |
2618 | nb->Attr(name, attr); |
2619 | ++iter; |
2620 | } |
2621 | } |
2622 | |
2623 | // Generic function to copy all attributes and check if filter is const. |
2624 | void MklLayoutRewritePass::CopyAttrsAllCheckConstFilter(const Node* orig_node, |
2625 | NodeBuilder* nb, |
2626 | bool change_format) { |
2627 | CopyAttrsAll(orig_node, nb, change_format); |
2628 | |
2629 | // Check and set filter attribute. |
2630 | Node* filter_node = nullptr; |
2631 | TF_CHECK_OK(orig_node->input_node(1, &filter_node)); |
2632 | |
2633 | bool is_filter_const = false; |
2634 | if (HasNodeAttr(orig_node->def(), "is_filter_const" )) { |
2635 | (void)GetNodeAttr(orig_node->def(), "is_filter_const" , &is_filter_const); |
2636 | } |
2637 | |
2638 | // In case that (1) orig_node does not have attribute 'is_filter_const', |
2639 | // or (2) it has the attribute but with the false value, then we set the |
2640 | // attribute for 'nb' with a value based on filter_node being const or not. |
2641 | // If is_filter_const == true, then there is no need to call nb->Attr() as |
2642 | // CopyAttrsAll() has already copied the attribute from orig_node to nb. |
2643 | if (!is_filter_const) { |
2644 | nb->Attr("is_filter_const" , filter_node->IsConstant()); |
2645 | } |
2646 | } |
2647 | |
2648 | void MklLayoutRewritePass::CopyAttrsConvCheckConstFilter(const Node* orig_node, |
2649 | NodeBuilder* nb, |
2650 | bool change_format) { |
2651 | CopyAttrsConv(orig_node, nb, change_format); |
2652 | |
2653 | // Check and set filter attribute. |
2654 | Node* filter_node = nullptr; |
2655 | TF_CHECK_OK(orig_node->input_node(1, &filter_node)); |
2656 | nb->Attr("is_filter_const" , filter_node->IsConstant()); |
2657 | } |
2658 | |
2659 | void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node, NodeBuilder* nb, |
2660 | bool change_format) { |
2661 | DataType T; |
2662 | string padding; |
2663 | std::vector<int32> strides; |
2664 | std::vector<int32> dilations; |
2665 | std::vector<int32> explicit_paddings; |
2666 | |
2667 | // Get all attributes from old node. |
2668 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T" , &T)); |
2669 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides" , &strides)); |
2670 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations" , &dilations)); |
2671 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding" , &padding)); |
2672 | |
2673 | // Check `explicit_paddings` first because some Conv ops don't have |
2674 | // this attribute. |
2675 | if (TryGetNodeAttr(orig_node->def(), "explicit_paddings" , |
2676 | &explicit_paddings) && |
2677 | !explicit_paddings.empty()) { |
2678 | nb->Attr("explicit_paddings" , explicit_paddings); |
2679 | } |
2680 | |
2681 | // Add attributes to new node. |
2682 | nb->Attr("T" , T); |
2683 | nb->Attr("padding" , padding); |
2684 | |
2685 | // Add attributes related to `data_format`. |
2686 | CopyFormatAttrsConv(orig_node, nb, strides, dilations, change_format); |
2687 | } |
2688 | |
2689 | // Used with MergePadWithConv2D |
2690 | void MklLayoutRewritePass::CopyAttrsFromPadAndConv2D(const Node* orig_node1, |
2691 | const Node* orig_node2, |
2692 | NodeBuilder* nb, |
2693 | bool change_format) { |
2694 | DataType Tpaddings; |
2695 | DataType T; |
2696 | string data_format; |
2697 | string padding; |
2698 | std::vector<int32> strides; |
2699 | std::vector<int32> dilations; |
2700 | bool use_cudnn_on_gpu; |
2701 | |
2702 | // Get all attributes from old node 1. |
2703 | TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "T" , &T)); |
2704 | TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "strides" , &strides)); |
2705 | TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "dilations" , &dilations)); |
2706 | TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "padding" , &padding)); |
2707 | TF_CHECK_OK(GetNodeAttr(orig_node1->def(), "data_format" , &data_format)); |
2708 | TF_CHECK_OK( |
2709 | GetNodeAttr(orig_node1->def(), "use_cudnn_on_gpu" , &use_cudnn_on_gpu)); |
2710 | // Get all attributes from old node 2. |
2711 | TF_CHECK_OK(GetNodeAttr(orig_node2->def(), "Tpaddings" , &Tpaddings)); |
2712 | |
2713 | // Add attributes to new node. |
2714 | nb->Attr("T" , T); |
2715 | nb->Attr("strides" , strides); |
2716 | nb->Attr("dilations" , dilations); |
2717 | nb->Attr("padding" , padding); |
2718 | nb->Attr("data_format" , data_format); |
2719 | nb->Attr("use_cudnn_on_gpu" , use_cudnn_on_gpu); |
2720 | nb->Attr("Tpaddings" , Tpaddings); |
2721 | } |
2722 | |
2723 | void MklLayoutRewritePass::CopyAttrsFusedConv2DCheckConstFilter( |
2724 | const Node* orig_node, NodeBuilder* nb, bool change_format) { |
2725 | DataType T; |
2726 | int num_args; |
2727 | string data_format; |
2728 | string padding; |
2729 | std::vector<int32> strides; |
2730 | std::vector<int32> dilations; |
2731 | std::vector<int32> explicit_paddings; |
2732 | float epsilon; |
2733 | std::vector<string> fused_ops; |
2734 | float leakyrelu_alpha; |
2735 | bool use_cudnn_on_gpu; |
2736 | |
2737 | // Get all attributes from old node. |
2738 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T" , &T)); |
2739 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "num_args" , &num_args)); |
2740 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides" , &strides)); |
2741 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding" , &padding)); |
2742 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format" , &data_format)); |
2743 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations" , &dilations)); |
2744 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "fused_ops" , &fused_ops)); |
2745 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "epsilon" , &epsilon)); |
2746 | TF_CHECK_OK( |
2747 | GetNodeAttr(orig_node->def(), "leakyrelu_alpha" , &leakyrelu_alpha)); |
2748 | TF_CHECK_OK( |
2749 | GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu" , &use_cudnn_on_gpu)); |
2750 | |
2751 | // Add attributes to new node. |
2752 | nb->Attr("T" , T); |
2753 | nb->Attr("num_args" , num_args); |
2754 | nb->Attr("strides" , strides); |
2755 | nb->Attr("padding" , padding); |
2756 | nb->Attr("data_format" , data_format); |
2757 | nb->Attr("dilations" , dilations); |
2758 | nb->Attr("epsilon" , epsilon); |
2759 | nb->Attr("fused_ops" , fused_ops); |
2760 | nb->Attr("leakyrelu_alpha" , leakyrelu_alpha); |
2761 | nb->Attr("use_cudnn_on_gpu" , use_cudnn_on_gpu); |
2762 | |
2763 | // Check `explicit_paddings` first because some Conv ops don't have |
2764 | // this attribute. |
2765 | if (TryGetNodeAttr(orig_node->def(), "explicit_paddings" , |
2766 | &explicit_paddings) && |
2767 | !explicit_paddings.empty()) { |
2768 | nb->Attr("explicit_paddings" , explicit_paddings); |
2769 | } |
2770 | |
2771 | // Check and set filter attribute. |
2772 | Node* filter_node = nullptr; |
2773 | TF_CHECK_OK(orig_node->input_node(1, &filter_node)); |
2774 | nb->Attr("is_filter_const" , filter_node->IsConstant()); |
2775 | } |
2776 | |
2777 | void MklLayoutRewritePass::CopyAttrsFromPadAndFusedConv2D( |
2778 | const Node* fused_conv2d, const Node* pad, NodeBuilder* nb, |
2779 | bool change_format) { |
2780 | DataType T; |
2781 | int num_args; |
2782 | string data_format; |
2783 | string padding; |
2784 | std::vector<int32> strides; |
2785 | std::vector<int32> dilations; |
2786 | float epsilon; |
2787 | std::vector<string> fused_ops; |
2788 | DataType Tpaddings; |
2789 | float leakyrelu_alpha; |
2790 | |
2791 | // Get all attributes from old node. |
2792 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "T" , &T)); |
2793 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "num_args" , &num_args)); |
2794 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "strides" , &strides)); |
2795 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "padding" , &padding)); |
2796 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "data_format" , &data_format)); |
2797 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "dilations" , &dilations)); |
2798 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "fused_ops" , &fused_ops)); |
2799 | TF_CHECK_OK(GetNodeAttr(fused_conv2d->def(), "epsilon" , &epsilon)); |
2800 | TF_CHECK_OK( |
2801 | GetNodeAttr(fused_conv2d->def(), "leakyrelu_alpha" , &leakyrelu_alpha)); |
2802 | TF_CHECK_OK(GetNodeAttr(pad->def(), "Tpaddings" , &Tpaddings)); |
2803 | |
2804 | // Add attributes to new node. |
2805 | nb->Attr("T" , T); |
2806 | nb->Attr("num_args" , num_args); |
2807 | nb->Attr("strides" , strides); |
2808 | nb->Attr("padding" , padding); |
2809 | nb->Attr("data_format" , data_format); |
2810 | nb->Attr("dilations" , dilations); |
2811 | nb->Attr("epsilon" , epsilon); |
2812 | nb->Attr("Tpaddings" , Tpaddings); |
2813 | nb->Attr("fused_ops" , fused_ops); |
2814 | nb->Attr("leakyrelu_alpha" , leakyrelu_alpha); |
2815 | } |
2816 | |
2817 | void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, |
2818 | NodeBuilder* nb, |
2819 | bool change_format) { |
2820 | DataType Tinput, Tfilter, out_type; |
2821 | string padding; |
2822 | string data_format("NHWC" ); |
2823 | std::vector<int32> strides, dilations, padding_list; |
2824 | bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list" ); |
2825 | |
2826 | // Get all attributes from old node. |
2827 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput" , &Tinput)); |
2828 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tfilter" , &Tfilter)); |
2829 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "out_type" , &out_type)); |
2830 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding" , &padding)); |
2831 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides" , &strides)); |
2832 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations" , &dilations)); |
2833 | if (has_padding_list) { |
2834 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list" , &padding_list)); |
2835 | } |
2836 | |
2837 | Node* filter_node = nullptr; |
2838 | TF_CHECK_OK(orig_node->input_node(1, &filter_node)); |
2839 | |
2840 | // Add attributes to new node. |
2841 | nb->Attr("Tinput" , Tinput); |
2842 | nb->Attr("Tfilter" , Tfilter); |
2843 | nb->Attr("out_type" , out_type); |
2844 | nb->Attr("padding" , padding); |
2845 | nb->Attr("is_filter_const" , filter_node->IsConstant()); |
2846 | nb->Attr("strides" , strides); |
2847 | nb->Attr("dilations" , dilations); |
2848 | nb->Attr("data_format" , data_format); |
2849 | if (has_padding_list) { |
2850 | nb->Attr("padding_list" , padding_list); |
2851 | } |
2852 | |
2853 | // Requantization attr Tbias. |
2854 | DataType Tbias; |
2855 | Status bias_status = GetNodeAttr(orig_node->def(), "Tbias" , &Tbias); |
2856 | if (bias_status.ToString() == "OK" ) nb->Attr("Tbias" , Tbias); |
2857 | } |
2858 | |
2859 | void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBiasAndDequantize( |
2860 | const Node* orig_node, NodeBuilder* nb, bool change_format) { |
2861 | CopyAttrsAll(orig_node, nb, change_format); |
2862 | |
2863 | // Check and set filter attribute. |
2864 | Node* filter_node = nullptr; |
2865 | TF_CHECK_OK(orig_node->input_node(1, &filter_node)); |
2866 | nb->Attr("is_weight_const" , filter_node->IsConstant()); |
2867 | } |
2868 | |
2869 | void MklLayoutRewritePass::CopyAttrsQuantizedMatMulWithBias( |
2870 | const Node* orig_node, NodeBuilder* nb, bool change_format) { |
2871 | DataType T1, T2, Toutput; |
2872 | |
2873 | // Get all attributes from old node. |
2874 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T1" , &T1)); |
2875 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T2" , &T2)); |
2876 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Toutput" , &Toutput)); |
2877 | |
2878 | Node* weight_node = nullptr; |
2879 | TF_CHECK_OK(orig_node->input_node(1, &weight_node)); |
2880 | |
2881 | // Add attributes to new node. |
2882 | nb->Attr("T1" , T1); |
2883 | nb->Attr("T2" , T2); |
2884 | nb->Attr("Toutput" , Toutput); |
2885 | nb->Attr("is_weight_const" , weight_node->IsConstant()); |
2886 | |
2887 | // Requantization attr Tbias |
2888 | DataType Tbias; |
2889 | Status bias_status = GetNodeAttr(orig_node->def(), "Tbias" , &Tbias); |
2890 | if (bias_status.ToString() == "OK" ) nb->Attr("Tbias" , Tbias); |
2891 | } |
2892 | |
2893 | void MklLayoutRewritePass::CopyFormatAttrsConv( |
2894 | const Node* orig_node, NodeBuilder* nb, const std::vector<int32>& strides, |
2895 | const std::vector<int32>& dilations, bool change_format) { |
2896 | string data_format; |
2897 | |
2898 | if (!change_format) { |
2899 | nb->Attr("strides" , strides); |
2900 | nb->Attr("dilations" , dilations); |
2901 | |
2902 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format" , &data_format)); |
2903 | nb->Attr("data_format" , data_format); |
2904 | } else { |
2905 | std::vector<int32> new_strides; |
2906 | std::vector<int32> new_dilations; |
2907 | if (strides.size() == 5) { |
2908 | // `strides` and `dilations` also need to be changed according to |
2909 | // `data_format`. In this case, from `NDHWC` to `NCDHW`. |
2910 | new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C], |
2911 | strides[NDHWC::dim::D], strides[NDHWC::dim::H], |
2912 | strides[NDHWC::dim::W]}; |
2913 | |
2914 | new_dilations = {dilations[NDHWC::dim::N], dilations[NDHWC::dim::C], |
2915 | dilations[NDHWC::dim::D], dilations[NDHWC::dim::H], |
2916 | dilations[NDHWC::dim::W]}; |
2917 | } else { |
2918 | // `strides` and `dilations` also need to be changed according to |
2919 | // `data_format`. In this case, from `NHWC` to `NCHW`. |
2920 | |
2921 | new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C], |
2922 | strides[NHWC::dim::H], strides[NHWC::dim::W]}; |
2923 | |
2924 | new_dilations = {dilations[NHWC::dim::N], dilations[NHWC::dim::C], |
2925 | dilations[NHWC::dim::H], dilations[NHWC::dim::W]}; |
2926 | } |
2927 | nb->Attr("strides" , new_strides); |
2928 | nb->Attr("dilations" , new_dilations); |
2929 | } |
2930 | } |
2931 | |
2932 | void MklLayoutRewritePass::CopyAttrsPooling(const Node* orig_node, |
2933 | NodeBuilder* nb, |
2934 | bool change_format) { |
2935 | DataType T; |
2936 | string data_format; |
2937 | string padding; |
2938 | std::vector<int32> ksize, strides; |
2939 | |
2940 | // Get all attributes from old node. |
2941 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T" , &T)); |
2942 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "ksize" , &ksize)); |
2943 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides" , &strides)); |
2944 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding" , &padding)); |
2945 | TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format" , &data_format)); |
2946 | |
2947 | // Add attributes to new node. |
2948 | nb->Attr("T" , T); |
2949 | nb->Attr("padding" , padding); |
2950 | |
2951 | if (!change_format) { |
2952 | nb->Attr("strides" , strides); |
2953 | nb->Attr("ksize" , ksize); |
2954 | |
2955 | nb->Attr("data_format" , data_format); |
2956 | } else { |
2957 | std::vector<int32> new_strides; |
2958 | std::vector<int32> new_ksize; |
2959 | if (strides.size() == 5) { |
2960 | DCHECK(data_format == "NCDHW" ); |
2961 | // `strides` and `ksize` also need to be changed according to |
2962 | // `data_format`. In this case, from `NDHWC` to `NCDHW`. |
2963 | new_strides = {strides[NDHWC::dim::N], strides[NDHWC::dim::C], |
2964 | strides[NDHWC::dim::D], strides[NDHWC::dim::H], |
2965 | strides[NDHWC::dim::W]}; |
2966 | |
2967 | new_ksize = {ksize[NDHWC::dim::N], ksize[NDHWC::dim::C], |
2968 | ksize[NDHWC::dim::D], ksize[NDHWC::dim::H], |
2969 | ksize[NDHWC::dim::W]}; |
2970 | |
2971 | } else { |
2972 | // `strides` and `ksize` also need to be changed according to |
2973 | // `data_format`. In this case, from `NHWC` to `NCHW`. |
2974 | DCHECK(data_format == "NCHW" ); |
2975 | new_strides = {strides[NHWC::dim::N], strides[NHWC::dim::C], |
2976 | strides[NHWC::dim::H], strides[NHWC::dim::W]}; |
2977 | |
2978 | new_ksize = {ksize[NHWC::dim::N], ksize[NHWC::dim::C], |
2979 | ksize[NHWC::dim::H], ksize[NHWC::dim::W]}; |
2980 | } |
2981 | nb->Attr("strides" , new_strides); |
2982 | nb->Attr("ksize" , new_ksize); |
2983 | } |
2984 | } |
2985 | |
2986 | ////////////////////////////////////////////////////////////////////////// |
2987 | // Helper functions related to node merge pass |
2988 | ////////////////////////////////////////////////////////////////////////// |
2989 | |
2990 | Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const { |
2991 | // TODO(nhasabni) Add check for type of node similar to CheckForNodeRewrite |
2992 | // once we support BiasAddGrad as Mkl layer. |
2993 | |
2994 | // Search for all matching mergeinfo. |
2995 | // We allow more than one match for extensibility. |
2996 | std::vector<const MergeInfo*> matching_mi; |
2997 | for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) { |
2998 | if (a->type_string() == mi->op1 || a->type_string() == mi->op2) { |
2999 | matching_mi.push_back(&*mi); |
3000 | } |
3001 | } |
3002 | |
3003 | for (const MergeInfo* mi : matching_mi) { |
3004 | // Get the operand with which 'a' can be merged. |
3005 | Node* b = nullptr; |
3006 | if ((b = mi->get_node_to_be_merged(a)) == nullptr) { |
3007 | continue; |
3008 | } |
3009 | |
3010 | // Get the control edges and input of node |
3011 | const int N_in = a->num_inputs(); |
3012 | gtl::InlinedVector<Node*, 4> a_control_edges; |
3013 | gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in); |
3014 | FillInputs(a, &a_control_edges, &a_in); |
3015 | |
3016 | const int B_in = b->num_inputs(); |
3017 | gtl::InlinedVector<Node*, 4> b_control_edges; |
3018 | gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(B_in); |
3019 | FillInputs(b, &b_control_edges, &b_in); |
3020 | |
3021 | // Shouldn't merge if a and b have different control edges. |
3022 | if (a_control_edges != b_control_edges) { |
3023 | continue; |
3024 | } else { |
3025 | // We found a match. |
3026 | return b; |
3027 | } |
3028 | } |
3029 | |
3030 | return nullptr; |
3031 | } |
3032 | |
3033 | Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g, |
3034 | Node* m, Node* n) { |
3035 | CHECK_EQ(((m->type_string() == csinfo_.bias_add && |
3036 | n->type_string() == csinfo_.conv2d)) || |
3037 | ((n->type_string() == csinfo_.bias_add && |
3038 | m->type_string() == csinfo_.conv2d)), |
3039 | true); |
3040 | |
3041 | // If 'm' is BiasAdd, then 'n' is Conv2D. Since Conv2D feeds BiasAdd, |
3042 | // BiasAdd is successor node, and Conv2D predecessor node. |
3043 | Node* pred = m->type_string() == csinfo_.bias_add ? n : m; |
3044 | Node* succ = m->type_string() == csinfo_.bias_add ? m : n; |
3045 | |
3046 | // 1. Get all attributes from input nodes. |
3047 | DataType T_pred, T_succ; |
3048 | string padding; |
3049 | std::vector<int32> strides; |
3050 | std::vector<int32> dilations; |
3051 | string data_format_pred, data_format_succ; |
3052 | bool use_cudnn_on_gpu; |
3053 | TF_CHECK_OK(GetNodeAttr(pred->def(), "T" , &T_pred)); |
3054 | TF_CHECK_OK(GetNodeAttr(succ->def(), "T" , &T_succ)); |
3055 | TF_CHECK_OK(GetNodeAttr(pred->def(), "padding" , &padding)); |
3056 | TF_CHECK_OK(GetNodeAttr(pred->def(), "strides" , &strides)); |
3057 | TF_CHECK_OK(GetNodeAttr(pred->def(), "dilations" , &dilations)); |
3058 | TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format" , &data_format_pred)); |
3059 | TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format" , &data_format_succ)); |
3060 | TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu" , &use_cudnn_on_gpu)); |
3061 | // We check to ensure that data formats of both succ and pred are same. |
3062 | // We expect them to be same, so we can enforce this as assert. |
3063 | // But assert can be too strict, so we enforce this as a check. |
3064 | // If the check fails, then we do not merge two nodes. |
3065 | // We also do same check for devices. |
3066 | if (data_format_pred != data_format_succ || T_pred != T_succ || |
3067 | pred->assigned_device_name() != succ->assigned_device_name() || |
3068 | pred->def().device() != succ->def().device()) { |
3069 | return Status(error::Code::INVALID_ARGUMENT, |
3070 | "data_format or T attribute or devices of Conv2D and " |
3071 | "BiasAdd do not match. Will skip node merge optimization" ); |
3072 | } |
3073 | |
3074 | const int succ_num = succ->num_inputs(); |
3075 | gtl::InlinedVector<Node*, 4> succ_control_edges; |
3076 | gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); |
3077 | FillInputs(succ, &succ_control_edges, &succ_in); |
3078 | |
3079 | const int pred_num = pred->num_inputs(); |
3080 | gtl::InlinedVector<Node*, 4> pred_control_edges; |
3081 | gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); |
3082 | FillInputs(pred, &pred_control_edges, &pred_in); |
3083 | |
3084 | // We need to ensure that Conv2D only feeds to BiasAdd (some other operator is |
3085 | // not expecting output of Conv2D). If this is not the case, then we cannot |
3086 | // merge Conv2D with BiasAdd. |
3087 | const int kFirstOutputSlot = 0; |
3088 | for (const Edge* e : pred->out_edges()) { |
3089 | if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { |
3090 | return Status(error::Code::INVALID_ARGUMENT, |
3091 | "Conv2D does not feed to BiasAdd, or " |
3092 | "it feeds BiasAdd but has multiple outputs. " |
3093 | "Will skip node merge optimization" ); |
3094 | } |
3095 | } |
3096 | |
3097 | // 2. Get inputs from both the nodes. |
3098 | // Find the 2 inputs from the conv and the bias from the add Bias. |
3099 | // Get operand 0, 1 of conv2D. |
3100 | CHECK_EQ(pred->in_edges().size(), 2); // Conv2D must have 2 inputs. |
3101 | // Get operand 1 of add_bias |
3102 | // BiasAdd must have 2 inputs: Conv, bias |
3103 | CHECK_EQ(succ->in_edges().size(), 2); |
3104 | |
3105 | // We will use the node name of BiasAdd as the name of new node |
3106 | // Build new node. We use same name as original node, but change the op |
3107 | // name. |
3108 | NodeBuilder nb(succ->name(), csinfo_.conv2d_with_bias); |
3109 | nb.Input(pred_in[0].first, pred_in[0].second); // In1 of Conv2D |
3110 | // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. |
3111 | nb.Input(pred_in[1].first, pred_in[1].second); // In2 of Conv2D |
3112 | // In1 of BiasAdd is same as output of Conv2D. |
3113 | nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd |
3114 | |
3115 | // Copy attributes from Conv2D to Conv2DWithBias. |
3116 | CopyAttrsConvCheckConstFilter(const_cast<const Node*>(pred), &nb); |
3117 | |
3118 | // Copy the device assigned to old node to new node. |
3119 | nb.Device(succ->def().device()); |
3120 | |
3121 | // Create node. |
3122 | Node* new_node; |
3123 | TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
3124 | |
3125 | // In the following code of this function, an unsorted set is used to make |
3126 | // sure no duplicated edges be added into the new node. Therefore, we can |
3127 | // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges) |
3128 | // check in the routine. |
3129 | |
3130 | // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' |
3131 | // node are already copied in BuildNode. We handle control edges now. |
3132 | std::unordered_set<Node*> unique_node; |
3133 | for (const Edge* e : pred->in_edges()) { |
3134 | if (e->IsControlEdge()) { |
3135 | auto result = unique_node.insert(e->src()); |
3136 | if (result.second) { |
3137 | (*g)->AddControlEdge(e->src(), new_node, true); |
3138 | } |
3139 | } |
3140 | } |
3141 | unique_node.clear(); |
3142 | |
3143 | for (const Edge* e : succ->in_edges()) { |
3144 | if (e->IsControlEdge()) { |
3145 | auto result = unique_node.insert(e->src()); |
3146 | if (result.second) { |
3147 | (*g)->AddControlEdge(e->src(), new_node, true); |
3148 | } |
3149 | } |
3150 | } |
3151 | unique_node.clear(); |
3152 | |
3153 | // Incoming edges are fixed, we will fix the outgoing edges now. |
3154 | // First, we will fix outgoing control edges from 'pred' node. |
3155 | for (const Edge* e : pred->out_edges()) { |
3156 | if (e->IsControlEdge()) { |
3157 | auto result = unique_node.insert(e->dst()); |
3158 | if (result.second) { |
3159 | (*g)->AddControlEdge(new_node, e->dst(), true); |
3160 | } |
3161 | } |
3162 | } |
3163 | unique_node.clear(); |
3164 | |
3165 | // Second, we will fix outgoing control and data edges from 'succ' node. |
3166 | for (const Edge* e : succ->out_edges()) { |
3167 | if (e->IsControlEdge()) { |
3168 | auto result = unique_node.insert(e->dst()); |
3169 | if (result.second) { |
3170 | (*g)->AddControlEdge(new_node, e->dst(), true); |
3171 | } |
3172 | } else { |
3173 | // BiasAdd has only 1 output (at slot 0) and merged node also has only 1 |
3174 | // output (at slot 0). |
3175 | const int kConv2DWithBiasOutputSlot = 0; |
3176 | auto new_edge = (*g)->AddEdge(new_node, kConv2DWithBiasOutputSlot, |
3177 | e->dst(), e->dst_input()); |
3178 | DCHECK(new_edge); |
3179 | } |
3180 | } |
3181 | |
3182 | // Copy device assigned to old node to new node. |
3183 | // It's ok to use pred or succ as we have enforced a check that |
3184 | // both have same device assigned. |
3185 | new_node->set_assigned_device_name(pred->assigned_device_name()); |
3186 | |
3187 | VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() |
3188 | << ", and node: " << succ->DebugString() |
3189 | << ", into node:" << new_node->DebugString(); |
3190 | |
3191 | (*g)->RemoveNode(succ); |
3192 | (*g)->RemoveNode(pred); |
3193 | |
3194 | return OkStatus(); |
3195 | } |
3196 | |
3197 | Status MklLayoutRewritePass::MergePadWithConv2D(std::unique_ptr<Graph>* g, |
3198 | Node* m, Node* n) { |
3199 | DCHECK((m->type_string() == csinfo_.pad && |
3200 | (n->type_string() == csinfo_.conv2d || |
3201 | n->type_string() == csinfo_.fused_conv2d)) || |
3202 | (n->type_string() == csinfo_.pad && |
3203 | (m->type_string() == csinfo_.conv2d || |
3204 | m->type_string() == csinfo_.fused_conv2d))); |
3205 | |
3206 | bool is_fused_conv2d = n->type_string() == csinfo_.fused_conv2d || |
3207 | m->type_string() == csinfo_.fused_conv2d; |
3208 | // Conv2D is successor node, and Pad predecessor node. |
3209 | Node* pred = m->type_string() == csinfo_.pad ? m : n; |
3210 | Node* succ = m->type_string() == csinfo_.pad ? n : m; |
3211 | |
3212 | // 1. Get all attributes from input nodes. |
3213 | DataType T_pred, T_succ; |
3214 | string padding; |
3215 | std::vector<int32> strides; |
3216 | std::vector<int32> dilations; |
3217 | string data_format_pred, data_format_succ; |
3218 | |
3219 | TF_CHECK_OK(GetNodeAttr(pred->def(), "T" , &T_pred)); |
3220 | TF_CHECK_OK(GetNodeAttr(succ->def(), "T" , &T_succ)); |
3221 | TF_CHECK_OK(GetNodeAttr(succ->def(), "padding" , &padding)); |
3222 | TF_CHECK_OK(GetNodeAttr(succ->def(), "strides" , &strides)); |
3223 | TF_CHECK_OK(GetNodeAttr(succ->def(), "dilations" , &dilations)); |
3224 | // Check if the devices of both succ and pred are the same. |
3225 | // Assert is not used because it can be too strict. |
3226 | // Don't need to check for data formats because it is not available in Pad. |
3227 | if (T_pred != T_succ || |
3228 | pred->assigned_device_name() != succ->assigned_device_name() || |
3229 | pred->def().device() != succ->def().device()) { |
3230 | return Status(error::Code::INVALID_ARGUMENT, |
3231 | "T attribute or devices of Conv2D and " |
3232 | "Pad do not match. Will skip node merge optimization" ); |
3233 | } |
3234 | |
3235 | const int succ_num = succ->num_inputs(); |
3236 | gtl::InlinedVector<Node*, 4> succ_control_edges; |
3237 | gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num); |
3238 | FillInputs(succ, &succ_control_edges, &succ_in); |
3239 | |
3240 | const int pred_num = pred->num_inputs(); |
3241 | gtl::InlinedVector<Node*, 4> pred_control_edges; |
3242 | gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num); |
3243 | FillInputs(pred, &pred_control_edges, &pred_in); |
3244 | |
3245 | // We need to ensure that Pad only feeds to Conv2D (some other operator is |
3246 | // not expecting output of Pad). If this is not the case, then we cannot |
3247 | // merge Conv2D with Pad. |
3248 | const int kFirstOutputSlot = 0; |
3249 | for (const Edge* e : pred->out_edges()) { |
3250 | if (e->src_output() == kFirstOutputSlot && e->dst() != succ) { |
3251 | return Status(error::Code::INVALID_ARGUMENT, |
3252 | "Pad does not feed to Conv2D, or " |
3253 | "it feeds Conv2D but has multiple outputs. " |
3254 | "Will skip node merge optimization" ); |
3255 | } |
3256 | } |
3257 | |
3258 | // 2. Get inputs from both the nodes. |
3259 | |
3260 | // Pad must have 2 data inputs: "input" and paddings. |
3261 | int PadDataInputEdges = 0; |
3262 | for (const Edge* e : pred->in_edges()) { |
3263 | if (!e->IsControlEdge()) { |
3264 | PadDataInputEdges++; |
3265 | } |
3266 | } |
3267 | DCHECK_EQ(PadDataInputEdges, 2); |
3268 | |
3269 | // Conv2D must have 2 data inputs: Pad output and Filter |
3270 | // FusedConv2D have 3 data inputs: Pad output, Filter and Args; |
3271 | int ConvDataInputEdges = 0; |
3272 | for (const Edge* e : succ->in_edges()) { |
3273 | if (!e->IsControlEdge()) { |
3274 | ConvDataInputEdges++; |
3275 | } |
3276 | } |
3277 | |
3278 | DCHECK_EQ(ConvDataInputEdges, is_fused_conv2d ? 3 : 2); |
3279 | |
3280 | // We will use the node name of Conv2D as the name of new node |
3281 | // Build new node. We use same name as original node, but change the op |
3282 | // name. |
3283 | |
3284 | NodeBuilder nb(succ->name(), is_fused_conv2d ? csinfo_.pad_with_fused_conv2d |
3285 | : csinfo_.pad_with_conv2d); |
3286 | nb.Input(pred_in[0].first, pred_in[0].second); // In1 (input data) of Pad |
3287 | // pred_in[1] will be 2nd Tensorflow tensor for Conv2D. |
3288 | nb.Input(succ_in[1].first, succ_in[1].second); // In2 (filter) of conv2d |
3289 | // In1 of Conv2D is same as output of Pad. |
3290 | // Thus, only need to add In2 of Conv2D |
3291 | |
3292 | if (is_fused_conv2d) { |
3293 | // FusedConv2D has one additional input, args |
3294 | std::vector<NodeBuilder::NodeOut> args; |
3295 | int num_args = 1; |
3296 | TF_CHECK_OK(GetNodeAttr(succ->def(), "num_args" , &num_args)); |
3297 | for (int i = 0; i < num_args; i++) { |
3298 | args.emplace_back(succ_in[2 + i].first, succ_in[2 + i].second); |
3299 | } |
3300 | nb.Input(gtl::ArraySlice<NodeBuilder::NodeOut>{ |
3301 | args}); // In3 (args) of FusedConv2D |
3302 | nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad |
3303 | // Copy attributes from Pad and FusedConv2D to PadWithFusedConv2D. |
3304 | CopyAttrsFromPadAndFusedConv2D(const_cast<const Node*>(succ), |
3305 | const_cast<const Node*>(pred), &nb); |
3306 | } else { |
3307 | nb.Input(pred_in[1].first, pred_in[1].second); // In2 (paddings) of Pad |
3308 | // Copy attributes from Pad and conv2D to PadWithConv2D. |
3309 | CopyAttrsFromPadAndConv2D(const_cast<const Node*>(succ), |
3310 | const_cast<const Node*>(pred), &nb); |
3311 | } |
3312 | |
3313 | // Copy the device assigned to old node to new node. |
3314 | nb.Device(succ->def().device()); |
3315 | |
3316 | // Create node. |
3317 | Node* new_node; |
3318 | TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
3319 | // No need to check if new_node is null because it will be null only when |
3320 | // Finalize fails. |
3321 | |
3322 | // Incoming data edges from 'pred' node and 'succ' node to new 'new_node' |
3323 | // node are already copied in BuildNode. |
3324 | // We handle control edges now. |
3325 | for (const Edge* e : pred->in_edges()) { |
3326 | if (e->IsControlEdge()) { |
3327 | // Don't allow duplicate edge |
3328 | (*g)->AddControlEdge(e->src(), new_node, false); |
3329 | } |
3330 | } |
3331 | for (const Edge* e : succ->in_edges()) { |
3332 | if (e->IsControlEdge()) { |
3333 | // Don't allow duplicate edge |
3334 | (*g)->AddControlEdge(e->src(), new_node, false); |
3335 | } |
3336 | } |
3337 | |
3338 | // Incoming edges are fixed, we will fix the outgoing edges now. |
3339 | // First, we will fix outgoing control edges from 'pred' node. |
3340 | for (const Edge* e : pred->out_edges()) { |
3341 | if (e->IsControlEdge()) { |
3342 | // Don't allow duplicate edge |
3343 | (*g)->AddControlEdge(new_node, e->dst(), false); |
3344 | } |
3345 | } |
3346 | |
3347 | // Second, we will fix outgoing control and data edges from 'succ' node. |
3348 | for (const Edge* e : succ->out_edges()) { |
3349 | if (e->IsControlEdge()) { |
3350 | // Allow duplicate while adding control edge as it would fail (return |
3351 | // NULL) if we try to add duplicate edge. |
3352 | (*g)->AddControlEdge(new_node, e->dst(), false); |
3353 | } else { |
3354 | // Conv2D has only 1 output (at slot 0) and merged node also has only 1 |
3355 | // output (at slot 0). |
3356 | const int kPadWithConv2DOutputSlot = 0; |
3357 | (*g)->AddEdge(new_node, kPadWithConv2DOutputSlot, e->dst(), |
3358 | e->dst_input()); |
3359 | } |
3360 | } |
3361 | |
3362 | // Copy device assigned to old node to new node. |
3363 | // It's ok to use pred or succ as we have enforced a check that |
3364 | // both have same device assigned. |
3365 | new_node->set_assigned_device_name(pred->assigned_device_name()); |
3366 | |
3367 | VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString() |
3368 | << ", and node: " << succ->DebugString() |
3369 | << ", into node:" << new_node->DebugString(); |
3370 | |
3371 | (*g)->RemoveNode(succ); |
3372 | (*g)->RemoveNode(pred); |
3373 | |
3374 | return OkStatus(); |
3375 | } |
3376 | |
3377 | Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad( |
3378 | std::unique_ptr<Graph>* g, Node* m, Node* n) { |
3379 | CHECK_EQ(((m->type_string() == csinfo_.bias_add_grad && |
3380 | n->type_string() == csinfo_.conv2d_grad_filter)) || |
3381 | ((n->type_string() == csinfo_.bias_add_grad && |
3382 | m->type_string() == csinfo_.conv2d_grad_filter)), |
3383 | true); |
3384 | |
3385 | // If 'm' is BiasAddGrad, then 'n' is BackpropFilter. |
3386 | Node* badd = m->type_string() == csinfo_.bias_add_grad ? m : n; |
3387 | Node* fltr = m->type_string() == csinfo_.bias_add_grad ? n : m; |
3388 | |
3389 | // Sanity check for attributes from input nodes. |
3390 | DataType T_b, T_f; |
3391 | string data_format_b, data_format_f; |
3392 | TF_CHECK_OK(GetNodeAttr(badd->def(), "T" , &T_b)); |
3393 | TF_CHECK_OK(GetNodeAttr(fltr->def(), "T" , &T_f)); |
3394 | TF_CHECK_OK(GetNodeAttr(badd->def(), "data_format" , &data_format_b)); |
3395 | TF_CHECK_OK(GetNodeAttr(fltr->def(), "data_format" , &data_format_f)); |
3396 | if (data_format_b != data_format_f || T_b != T_f || |
3397 | badd->assigned_device_name() != fltr->assigned_device_name() || |
3398 | badd->def().device() != fltr->def().device()) { |
3399 | return Status(error::Code::INVALID_ARGUMENT, |
3400 | "data_format or T attribute or devices of " |
3401 | "Conv2DBackpropFilter and BiasAddGrad do not match. " |
3402 | "Will skip node merge optimization" ); |
3403 | } |
3404 | |
3405 | // We will use the node name of Conv2DBackpropFilter as the name of new node. |
3406 | // This is because BackpropFilterWithBias is going to emit bias output also. |
3407 | NodeBuilder nb(fltr->name(), csinfo_.conv2d_grad_filter_with_bias); |
3408 | // Since Conv2DBackpropFilterWithBias has same number of inputs as |
3409 | // Conv2DBackpropFilter, we can just copy input edges directly. We don't need |
3410 | // to copy any data input of BiasAddGrad because that input also goes to |
3411 | // Conv2DBackpropFilter. |
3412 | const int fltr_ins = fltr->num_inputs(); |
3413 | gtl::InlinedVector<Node*, 4> fltr_control_edges; |
3414 | gtl::InlinedVector<std::pair<Node*, int>, 4> fltr_in_edges(fltr_ins); |
3415 | FillInputs(fltr, &fltr_control_edges, &fltr_in_edges); |
3416 | for (int idx = 0; idx < fltr_ins; idx++) { |
3417 | nb.Input(fltr_in_edges[idx].first, fltr_in_edges[idx].second); |
3418 | } |
3419 | |
3420 | // Copy attributes from Conv2DBackpropFilter. |
3421 | CopyAttrsConv(const_cast<const Node*>(fltr), &nb); |
3422 | |
3423 | // Copy the device assigned to old node to new node. |
3424 | nb.Device(fltr->def().device()); |
3425 | |
3426 | // Create node. |
3427 | Node* new_node; |
3428 | TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
3429 | |
3430 | // In the following code of this function, an unsorted set is used to make |
3431 | // sure no duplicated edges be added into the new node. Therefore, we can |
3432 | // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges) |
3433 | // check in the routine. |
3434 | |
3435 | // Incoming data edges from BiasAddGrad node and Conv2DBackpropFilter node to |
3436 | // new 'new_node' node are already copied in BuildNode. We handle control |
3437 | // edges now. |
3438 | std::unordered_set<Node*> unique_node; |
3439 | for (const Edge* e : badd->in_edges()) { |
3440 | if (e->IsControlEdge()) { |
3441 | auto result = unique_node.insert(e->src()); |
3442 | if (result.second) { |
3443 | (*g)->AddControlEdge(e->src(), new_node, true); |
3444 | } |
3445 | } |
3446 | } |
3447 | unique_node.clear(); |
3448 | for (const Edge* e : fltr->in_edges()) { |
3449 | if (e->IsControlEdge()) { |
3450 | auto result = unique_node.insert(e->src()); |
3451 | if (result.second) { |
3452 | (*g)->AddControlEdge(e->src(), new_node, true); |
3453 | } |
3454 | } |
3455 | } |
3456 | unique_node.clear(); |
3457 | |
3458 | // Incoming edges are fixed, we will fix the outgoing edges now. |
3459 | // First, we will fix outgoing control edges from 'badd' node. |
3460 | // Conv2DBackpropFilter has 1 output -- filter_grad. |
3461 | // Conv2DBackpropFilterWithBias has 2 outputs -- filter_grad and |
3462 | // bias_grad. But filter_grad is at same slot number (0) in both the |
3463 | // nodes. bias_grad is at slot number 1 in Conv2DBackpropFilterWithBias, while |
3464 | // it is at slot number 0 in BiasAddGrad. |
3465 | const int kMergedNodeFilterGradOutputIdx = 0; |
3466 | const int kMergedNodeBiasGradOutputIdx = 1; |
3467 | |
3468 | for (const Edge* e : badd->out_edges()) { |
3469 | if (e->IsControlEdge()) { |
3470 | auto result = unique_node.insert(e->dst()); |
3471 | if (result.second) { |
3472 | (*g)->AddControlEdge(new_node, e->dst(), true); |
3473 | } |
3474 | } else { |
3475 | auto new_edge = (*g)->AddEdge(new_node, kMergedNodeBiasGradOutputIdx, |
3476 | e->dst(), e->dst_input()); |
3477 | DCHECK(new_edge); |
3478 | } |
3479 | } |
3480 | unique_node.clear(); |
3481 | |
3482 | // Second, we will fix outgoing control and data edges from 'fltr' node. |
3483 | for (const Edge* e : fltr->out_edges()) { |
3484 | if (e->IsControlEdge()) { |
3485 | auto result = unique_node.insert(e->dst()); |
3486 | if (result.second) { |
3487 | (*g)->AddControlEdge(new_node, e->dst(), true); |
3488 | } |
3489 | } else { |
3490 | auto new_edge = (*g)->AddEdge(new_node, kMergedNodeFilterGradOutputIdx, |
3491 | e->dst(), e->dst_input()); |
3492 | DCHECK(new_edge); |
3493 | } |
3494 | } |
3495 | |
3496 | // Copy device assigned to old node to new node. |
3497 | // It's ok to use badd or fltr as we have enforced a check that |
3498 | // both have same device assigned. |
3499 | new_node->set_assigned_device_name(badd->assigned_device_name()); |
3500 | |
3501 | VLOG(1) << "MklLayoutRewritePass: Merged old node:" << badd->DebugString() |
3502 | << ", and node: " << fltr->DebugString() |
3503 | << ", into node:" << new_node->DebugString(); |
3504 | |
3505 | (*g)->RemoveNode(badd); |
3506 | (*g)->RemoveNode(fltr); |
3507 | |
3508 | return OkStatus(); |
3509 | } |
3510 | |
3511 | Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* m, |
3512 | Node* n) { |
3513 | DCHECK(m); |
3514 | DCHECK(n); |
3515 | |
3516 | if (((m->type_string() == csinfo_.bias_add && |
3517 | n->type_string() == csinfo_.conv2d)) || |
3518 | ((n->type_string() == csinfo_.bias_add && |
3519 | m->type_string() == csinfo_.conv2d))) { |
3520 | return this->MergeConv2DWithBiasAdd(g, m, n); |
3521 | } |
3522 | if ((m->type_string() == csinfo_.pad && |
3523 | (n->type_string() == csinfo_.conv2d || |
3524 | (n->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(n)))) || |
3525 | (n->type_string() == csinfo_.pad && |
3526 | (m->type_string() == csinfo_.conv2d || |
3527 | (m->type_string() == csinfo_.fused_conv2d && FusedConv2DRewrite(m))))) { |
3528 | return this->MergePadWithConv2D(g, m, n); |
3529 | } |
3530 | |
3531 | if (((m->type_string() == csinfo_.bias_add_grad && |
3532 | n->type_string() == csinfo_.conv2d_grad_filter)) || |
3533 | ((n->type_string() == csinfo_.bias_add_grad && |
3534 | m->type_string() == csinfo_.conv2d_grad_filter))) { |
3535 | return this->MergeConv2DBackpropFilterWithBiasAddGrad(g, m, n); |
3536 | } |
3537 | |
3538 | return Status(error::Code::UNIMPLEMENTED, |
3539 | "Unimplemented case for node merge optimization." ); |
3540 | } |
3541 | |
3542 | ////////////////////////////////////////////////////////////////////////// |
3543 | // Helper functions for node rewrite |
3544 | ////////////////////////////////////////////////////////////////////////// |
3545 | |
3546 | Status MklLayoutRewritePass::RewriteNodeForLayoutPropagation( |
3547 | std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node, |
3548 | const RewriteInfo* ri) { |
3549 | // Get all data inputs. |
3550 | int num_data_inputs = orig_node->in_edges().size(); |
3551 | // Drop count for control edges from inputs |
3552 | for (const Edge* e : orig_node->in_edges()) { |
3553 | if (e->IsControlEdge()) { |
3554 | num_data_inputs--; |
3555 | } |
3556 | } |
3557 | |
3558 | gtl::InlinedVector<Node*, 4> control_edges; |
3559 | gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs); |
3560 | FillInputs(orig_node, &control_edges, &inputs); |
3561 | |
3562 | // Build new node. We use same name as original node, but change the op name. |
3563 | NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); |
3564 | // Copy user-specified device assigned to original node to new node. |
3565 | nb.Device(orig_node->def().device()); |
3566 | // Set up new inputs to the rewritten node. |
3567 | Status s = SetUpInputs(g, inputs, &nb, orig_node); |
3568 | if (s != OkStatus()) { |
3569 | return s; |
3570 | } |
3571 | |
3572 | const bool kPartialCopyAttrs = false; |
3573 | ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, kPartialCopyAttrs); |
3574 | |
3575 | // Set the Mkl layer label for this op. |
3576 | if (DataTypeIsQuantized(orig_node->input_type(0)) || |
3577 | DataTypeIsQuantized(orig_node->output_type(0))) { |
3578 | nb.Attr("_kernel" , mkl_op_registry::kMklQuantizedOpLabel); |
3579 | } else { |
3580 | nb.Attr("_kernel" , mkl_op_registry::kMklLayoutDependentOpLabel); |
3581 | } |
3582 | // Finalize graph and get new node. |
3583 | s = nb.Finalize(&**g, new_node); |
3584 | if (s != OkStatus()) { |
3585 | return s; |
3586 | } |
3587 | |
3588 | // In the following code of this function, an unsorted set is used to make |
3589 | // sure no duplicated edges be added into the new node. Therefore, we can |
3590 | // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges) |
3591 | // check in the routine. |
3592 | |
3593 | // Incoming data edges from 'orig_node' node to new 'new_node' node are |
3594 | // already copied in BuildNode. We need to handle control edges now. |
3595 | std::unordered_set<Node*> unique_node; |
3596 | for (const Edge* e : orig_node->in_edges()) { |
3597 | if (e->IsControlEdge()) { |
3598 | auto result = unique_node.insert(e->src()); |
3599 | if (result.second) { |
3600 | (*g)->AddControlEdge(e->src(), *new_node, true); |
3601 | } |
3602 | } |
3603 | } |
3604 | unique_node.clear(); |
3605 | |
3606 | // Copy outgoing edges from 'orig_node' node to new |
3607 | // 'new_node' node, since the output also follows same ordering among |
3608 | // Tensorflow tensors and Mkl tensors. We need to connect Tensorflow |
3609 | // tensors appropriately. Specifically, nth output of the original node |
3610 | // will become 2*nth output of the Mkl node for the interleaved ordering |
3611 | // of the tensors. For the contiguous ordering of the tensors, it will be n. |
3612 | // GetTensorDataIndex provides this mapping function. |
3613 | for (const Edge* e : orig_node->out_edges()) { |
3614 | if (e->IsControlEdge()) { |
3615 | auto result = unique_node.insert(e->dst()); |
3616 | if (result.second) { |
3617 | (*g)->AddControlEdge(*new_node, e->dst(), true); |
3618 | } |
3619 | } else { |
3620 | auto new_edge = (*g)->AddEdge( |
3621 | *new_node, |
3622 | GetTensorDataIndex(e->src_output(), e->src()->num_outputs()), |
3623 | e->dst(), e->dst_input()); |
3624 | DCHECK(new_edge); |
3625 | } |
3626 | } |
3627 | return OkStatus(); |
3628 | } |
3629 | |
3630 | Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange( |
3631 | std::unique_ptr<Graph>* g, const Node* orig_node, Node** new_node, |
3632 | const RewriteInfo* ri) { |
3633 | // Get all data inputs. |
3634 | int num_data_inputs = orig_node->in_edges().size(); |
3635 | // Drop count for control edges from inputs |
3636 | for (const Edge* e : orig_node->in_edges()) { |
3637 | if (e->IsControlEdge()) { |
3638 | num_data_inputs--; |
3639 | } |
3640 | } |
3641 | gtl::InlinedVector<Node*, 4> control_edges; |
3642 | gtl::InlinedVector<std::pair<Node*, int>, 4> inputs(num_data_inputs); |
3643 | FillInputs(orig_node, &control_edges, &inputs); |
3644 | |
3645 | // Build new node. We use same name as original node, but change the op name. |
3646 | NodeBuilder nb(orig_node->name().c_str(), ri->new_name.c_str()); |
3647 | // Copy user-specified device assigned to original node to new node. |
3648 | nb.Device(orig_node->def().device()); |
3649 | |
3650 | Status s = CopyInputs(orig_node, inputs, &nb); |
3651 | if (s != OkStatus()) { |
3652 | return s; |
3653 | } |
3654 | |
3655 | std::vector<NodeBuilder::NodeOut> workspace_tensors; |
3656 | bool are_workspace_tensors_available = false; |
3657 | if (IsWorkspaceCheckNeeded(orig_node)) { |
3658 | AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors, |
3659 | &are_workspace_tensors_available); |
3660 | if (are_workspace_tensors_available) { |
3661 | DCHECK_EQ(workspace_tensors.size(), 1); |
3662 | nb.Input(workspace_tensors[0].node, workspace_tensors[0].index); |
3663 | } |
3664 | } |
3665 | |
3666 | if (!NativeFormatEnabled()) { |
3667 | ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true); |
3668 | } else { |
3669 | ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, false); |
3670 | } |
3671 | |
3672 | if (DataTypeIsQuantized(orig_node->input_type(0)) || |
3673 | DataTypeIsQuantized(orig_node->output_type(0))) { |
3674 | nb.Attr("_kernel" , mkl_op_registry::kMklQuantizedOpLabel); |
3675 | } else { |
3676 | nb.Attr("_kernel" , mkl_op_registry::kMklNameChangeOpLabel); |
3677 | } |
3678 | |
3679 | // Finalize graph and get new node. |
3680 | s = nb.Finalize(&**g, new_node); |
3681 | if (s != OkStatus()) { |
3682 | return s; |
3683 | } |
3684 | |
3685 | // In the following code of this function, an unsorted set is used to make |
3686 | // sure no duplicated edges be added into the new node. Therefore, we can |
3687 | // pass allow_duplicates = true in AddControlEdge call to skip the O(#edges) |
3688 | // check in the routine. |
3689 | |
3690 | // Incoming data edges from 'orig_node' node to new 'new_node' node are |
3691 | // already copied in BuildNode. We need to handle control edges now. |
3692 | std::unordered_set<Node*> unique_node; |
3693 | for (const Edge* e : orig_node->in_edges()) { |
3694 | if (e->IsControlEdge()) { |
3695 | auto result = unique_node.insert(e->src()); |
3696 | if (result.second) { |
3697 | (*g)->AddControlEdge(e->src(), *new_node, true); |
3698 | } |
3699 | } |
3700 | } |
3701 | unique_node.clear(); |
3702 | |
3703 | // Transfer outgoing edges from 'orig_node' node to new 'new_node' node. |
3704 | for (const Edge* e : orig_node->out_edges()) { |
3705 | if (e->IsControlEdge()) { |
3706 | auto result = unique_node.insert(e->dst()); |
3707 | if (result.second) { |
3708 | (*g)->AddControlEdge(*new_node, e->dst(), true); |
3709 | } |
3710 | } else { |
3711 | auto result = |
3712 | (*g)->AddEdge(*new_node, e->src_output(), e->dst(), e->dst_input()); |
3713 | DCHECK(result != nullptr); |
3714 | } |
3715 | } |
3716 | |
3717 | return OkStatus(); |
3718 | } |
3719 | |
3720 | Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, |
3721 | Node* orig_node, |
3722 | const RewriteInfo* ri) { |
3723 | DCHECK(ri != nullptr); |
3724 | DCHECK(orig_node != nullptr); |
3725 | |
3726 | VLOG(1) << "MklLayoutRewritePass: Original node:" << orig_node->DebugString(); |
3727 | |
3728 | Status ret_status = OkStatus(); |
3729 | Node* new_node = nullptr; |
3730 | if (ri->rewrite_cause == kRewriteForLayoutPropagation) { |
3731 | ret_status = RewriteNodeForLayoutPropagation(g, orig_node, &new_node, ri); |
3732 | } else if (ri->rewrite_cause == kRewriteForOpNameChange) { |
3733 | ret_status = RewriteNodeForJustOpNameChange(g, orig_node, &new_node, ri); |
3734 | } else { |
3735 | ret_status = Status(error::Code::INVALID_ARGUMENT, |
3736 | "Unsupported rewrite cause found." |
3737 | "RewriteNode will fail." ); |
3738 | } |
3739 | TF_CHECK_OK(ret_status); |
3740 | |
3741 | // Copy the runtime device assigned from original code to new node. |
3742 | new_node->set_assigned_device_name(orig_node->assigned_device_name()); |
3743 | |
3744 | // Delete original node and mark new node as rewritten. |
3745 | (*g)->RemoveNode(orig_node); |
3746 | |
3747 | VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString(); |
3748 | return ret_status; |
3749 | } |
3750 | |
3751 | // TODO(mdfaijul): Is there any other elegant way to check for quantized ops |
3752 | // having attributes other than "T"? |
3753 | // Current implementation reflects only QuantizedConv2D and its fused Ops. |
3754 | const MklLayoutRewritePass::RewriteInfo* |
3755 | MklLayoutRewritePass::CheckForQuantizedNodeRewrite(const Node* n) const { |
3756 | DataType T1, T2; |
3757 | DataType Tinput, Tfilter; |
3758 | bool type_attrs_present = false; |
3759 | |
3760 | if (TryGetNodeAttr(n->def(), "Tinput" , &Tinput) && |
3761 | TryGetNodeAttr(n->def(), "Tfilter" , &Tfilter) && |
3762 | mkl_op_registry::IsMklQuantizedOp( |
3763 | mkl_op_registry::GetMklOpName(n->type_string()), Tinput, Tfilter)) { |
3764 | type_attrs_present = true; |
3765 | } else if (TryGetNodeAttr(n->def(), "T1" , &T1) && |
3766 | TryGetNodeAttr(n->def(), "T2" , &T2) && |
3767 | mkl_op_registry::IsMklQuantizedOp( |
3768 | mkl_op_registry::GetMklOpName(n->type_string()), T1, T2)) { |
3769 | type_attrs_present = true; |
3770 | } |
3771 | |
3772 | if (type_attrs_present) { |
3773 | for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { |
3774 | if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { |
3775 | return &*ri; |
3776 | } |
3777 | } |
3778 | } |
3779 | |
3780 | return nullptr; |
3781 | } |
3782 | |
3783 | const MklLayoutRewritePass::RewriteInfo* |
3784 | MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { |
3785 | DCHECK(n); |
3786 | |
3787 | // QuantizedOps may have attributes other than "T", so decoupled the check |
3788 | // with a function, CheckForQuantizedNodeRewrite(const Node*). |
3789 | const RewriteInfo* ri = CheckForQuantizedNodeRewrite(n); |
3790 | if (ri != nullptr) return ri; |
3791 | |
3792 | // First check if node along with its type is supported by MKL layer. |
3793 | // We do not want to rewrite an op into Mkl op if types are not supported. |
3794 | // E.g., MklRelu does not support INT32. So we cannot rewrite Relu to |
3795 | // MklRelu if type is INT32. |
3796 | DataType T; |
3797 | if (!TryGetNodeAttr(n->def(), "T" , &T)) { |
3798 | return nullptr; |
3799 | } |
3800 | |
3801 | // We make an exception for Conv2DGrad and MaxPool related ops as |
3802 | // the corresponding MKL ops currently do not support the case |
3803 | // of padding == EXPLICIT yet. |
3804 | // TODO(intel): support `EXPLICIT` padding for ConvGrad |
3805 | if (n->type_string() == csinfo_.conv2d_grad_input || |
3806 | n->type_string() == csinfo_.conv2d_grad_filter || |
3807 | n->type_string() == csinfo_.depthwise_conv2d_grad_filter || |
3808 | n->type_string() == csinfo_.depthwise_conv2d_grad_input || |
3809 | n->type_string() == csinfo_.conv3d_grad_filter || |
3810 | n->type_string() == csinfo_.conv3d_grad_filter || |
3811 | n->type_string() == csinfo_.max_pool || |
3812 | n->type_string() == csinfo_.max_pool_grad || |
3813 | n->type_string() == csinfo_.max_pool3d || |
3814 | n->type_string() == csinfo_.max_pool3d_grad) { |
3815 | string padding; |
3816 | TF_CHECK_OK(GetNodeAttr(n->def(), "padding" , &padding)); |
3817 | if (padding == "EXPLICIT" ) return nullptr; |
3818 | } |
3819 | |
3820 | // We make an exception for __MklDummyConv2DWithBias, |
3821 | // __MklConv2DBackpropFilterWithBias, and __MklDummyPadWithConv2D since their |
3822 | // names do not match Mkl node names. |
3823 | if (n->type_string() != csinfo_.conv2d_with_bias && |
3824 | n->type_string() != csinfo_.pad_with_conv2d && |
3825 | n->type_string() != csinfo_.pad_with_fused_conv2d && |
3826 | n->type_string() != csinfo_.conv2d_grad_filter_with_bias && |
3827 | n->type_string() != csinfo_.fused_batch_norm_ex && |
3828 | n->type_string() != csinfo_.fused_conv2d && |
3829 | n->type_string() != csinfo_.fused_depthwise_conv2d && |
3830 | n->type_string() != csinfo_.fused_matmul && |
3831 | n->type_string() != csinfo_.fused_conv3d && |
3832 | !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), |
3833 | T)) { |
3834 | return nullptr; |
3835 | } |
3836 | |
3837 | // We now check if rewrite rule applies for this op. If rewrite rule passes |
3838 | // for this op, then we rewrite it to Mkl op. |
3839 | // Find matching RewriteInfo and then check that rewrite rule applies. |
3840 | for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) { |
3841 | if (n->type_string().compare(ri->name) == 0 && ri->rewrite_rule(n)) { |
3842 | return &*ri; |
3843 | } |
3844 | } |
3845 | |
3846 | // Else return not found. |
3847 | return nullptr; |
3848 | } |
3849 | |
3850 | ////////////////////////////////////////////////////////////////////////// |
3851 | // Helper functions for node fusion |
3852 | ////////////////////////////////////////////////////////////////////////// |
3853 | Status MklLayoutRewritePass::FuseTransposeMklOpTranspose( |
3854 | std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, |
3855 | std::function<void(const Node*, NodeBuilder* nb, bool)> copy_attrs, |
3856 | string data_format) { |
3857 | Node* transpose_to_nhwc = nodes[0]; |
3858 | Node* mklop = nodes[1]; |
3859 | Node* transpose_to_nchw = nodes[2]; |
3860 | |
3861 | const int transpose_nhwc_num_inputs = transpose_to_nhwc->num_inputs(); |
3862 | gtl::InlinedVector<Node*, 4> transpose_nhwc_control_edges; |
3863 | gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nhwc_in( |
3864 | transpose_nhwc_num_inputs); |
3865 | FillInputs(transpose_to_nhwc, &transpose_nhwc_control_edges, |
3866 | &transpose_nhwc_in); |
3867 | |
3868 | const int mklop_num_inputs = mklop->num_inputs(); |
3869 | gtl::InlinedVector<Node*, 4> mklop_control_edges; |
3870 | gtl::InlinedVector<std::pair<Node*, int>, 4> mklop_in(mklop_num_inputs); |
3871 | FillInputs(mklop, &mklop_control_edges, &mklop_in); |
3872 | |
3873 | const int transpose_nchw_num_inputs = transpose_to_nchw->num_inputs(); |
3874 | gtl::InlinedVector<Node*, 4> transpose_nchw_control_edges; |
3875 | gtl::InlinedVector<std::pair<Node*, int>, 4> transpose_nchw_in( |
3876 | transpose_nchw_num_inputs); |
3877 | FillInputs(transpose_to_nchw, &transpose_nchw_control_edges, |
3878 | &transpose_nchw_in); |
3879 | |
3880 | // We use same name as original node, but change the op |
3881 | // type. |
3882 | NodeBuilder nb(mklop->name(), mklop->type_string()); |
3883 | |
3884 | // Storing the output slots of the input nodes. |
3885 | for (int i = 0; i < mklop_num_inputs; i++) { |
3886 | if (mklop_in[i].first == transpose_to_nhwc) { |
3887 | // Fill "x": |
3888 | nb.Input(transpose_nhwc_in[0].first, transpose_nhwc_in[0].second); |
3889 | } else { |
3890 | // Fill inputs other than "x": |
3891 | nb.Input(mklop_in[i].first, mklop_in[i].second); |
3892 | } |
3893 | } |
3894 | |
3895 | copy_attrs(const_cast<const Node*>(mklop), &nb, true); |
3896 | nb.Attr("data_format" , data_format); |
3897 | |
3898 | // Copy the device assigned to old node to new node. |
3899 | nb.Device(mklop->def().device()); |
3900 | |
3901 | // Create node. |
3902 | Node* new_node; |
3903 | TF_CHECK_OK(nb.Finalize(&**g, &new_node)); |
3904 | // No need to check if new_node is null because it will be null only when |
3905 | // Finalize fails. |
3906 | |
3907 | // Fill outputs. |
3908 | for (const Edge* e : transpose_to_nchw->out_edges()) { |
3909 | if (!e->IsControlEdge()) { |
3910 | const int kTransposeWithMklOpOutputSlot = 0; |
3911 | auto new_edge = (*g)->AddEdge(new_node, kTransposeWithMklOpOutputSlot, |
3912 | e->dst(), e->dst_input()); |
3913 | DCHECK(new_edge); |
3914 | } |
3915 | } |
3916 | |
3917 | // Copy device assigned to old node to new node. |
3918 | new_node->set_assigned_device_name(mklop->assigned_device_name()); |
3919 | |
3920 | // Copy requested_device and assigned_device_name_index |
3921 | new_node->set_requested_device(mklop->requested_device()); |
3922 | new_node->set_assigned_device_name_index(mklop->assigned_device_name_index()); |
3923 | |
3924 | (*g)->RemoveNode(transpose_to_nhwc); |
3925 | (*g)->RemoveNode(mklop); |
3926 | (*g)->RemoveNode(transpose_to_nchw); |
3927 | |
3928 | return OkStatus(); |
3929 | } |
3930 | |
3931 | Status MklLayoutRewritePass::FuseNode( |
3932 | std::unique_ptr<Graph>* g, std::vector<Node*>& nodes, |
3933 | const MklLayoutRewritePass::FusionInfo fi) { |
3934 | return fi.fuse_func(g, nodes, fi.copy_attrs); |
3935 | } |
3936 | |
3937 | std::tuple<bool, std::vector<Node*>, const MklLayoutRewritePass::FusionInfo> |
3938 | MklLayoutRewritePass::CheckForNodeFusion(Node* a) const { |
3939 | // Stores matched nodes, in the same order as node_checkers. |
3940 | std::vector<Node*> nodes; |
3941 | |
3942 | for (auto fi = finfo_.begin(); fi != finfo_.end(); ++fi) { |
3943 | // |
3944 | // Make sure node "a" and its succeeding nodes (b, c ...), match the pattern |
3945 | // defined in fusion info (ops[0], ops[1], ...), |
3946 | // a.k.a. "a->b->c" matches "op1->op2->op3" |
3947 | // |
3948 | |
3949 | // Stores the first unvisited outgoing edge of each matched node in "nodes". |
3950 | std::stack<EdgeSet::const_iterator> current_neighbor_stack; |
3951 | nodes.clear(); |
3952 | |
3953 | auto node_checker = fi->node_checkers.begin(); |
3954 | if (a != nullptr && (*node_checker)(a)) { |
3955 | nodes.push_back(a); |
3956 | current_neighbor_stack.push(a->out_edges().begin()); |
3957 | ++node_checker; |
3958 | } |
3959 | |
3960 | while (!nodes.empty()) { |
3961 | auto& current_neighbor_iter = current_neighbor_stack.top(); |
3962 | |
3963 | if (current_neighbor_iter != nodes.back()->out_edges().end()) { |
3964 | // Found an unvisited edge. Goes through the edge to get the neighbor. |
3965 | Node* neighbor_node = (*current_neighbor_iter)->dst(); |
3966 | ++current_neighbor_stack.top(); // Retrieves the next unvisited edge. |
3967 | |
3968 | if ((*node_checker)(neighbor_node)) { |
3969 | // Found a match. Stores the node and moves to the next checker. |
3970 | nodes.push_back(neighbor_node); |
3971 | current_neighbor_stack.push(neighbor_node->out_edges().begin()); |
3972 | if (++node_checker == fi->node_checkers.end()) { |
3973 | return make_tuple(true, nodes, *fi); |
3974 | } |
3975 | } |
3976 | } else { |
3977 | // Removes the current node since none of its neighbor leads to a |
3978 | // further match. |
3979 | nodes.pop_back(); |
3980 | current_neighbor_stack.pop(); |
3981 | --node_checker; |
3982 | } |
3983 | } |
3984 | } |
3985 | |
3986 | return make_tuple(false, std::vector<Node*>(), FusionInfo()); |
3987 | } |
3988 | |
3989 | /////////////////////////////////////////////////////////////////////////////// |
3990 | // Post-rewrite Mkl metadata fixup pass |
3991 | /////////////////////////////////////////////////////////////////////////////// |
3992 | bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g, |
3993 | const Edge* e_data, |
3994 | const Edge* e_metadata) { |
3995 | if (g == nullptr || e_data == nullptr || e_metadata == nullptr) { |
3996 | return false; |
3997 | } |
3998 | |
3999 | Node* n_data = e_data->src(); |
4000 | int n_data_op_slot = e_data->src_output(); |
4001 | int n_metadata_op_slot = |
4002 | GetTensorMetaDataIndex(n_data_op_slot, n_data->num_outputs()); |
4003 | |
4004 | // If the source of meta edge is a constant node (producing dummy Mkl metadata |
4005 | // tensor), then we will need to fix. |
4006 | if (IsConstant(e_metadata->src())) { |
4007 | Node* e_metadata_dst = e_metadata->dst(); |
4008 | int e_metadata_in_slot = e_metadata->dst_input(); |
4009 | auto new_edge = (*g)->AddEdge(n_data, n_metadata_op_slot, e_metadata_dst, |
4010 | e_metadata_in_slot); |
4011 | DCHECK(new_edge); |
4012 | |
4013 | (*g)->RemoveEdge(e_metadata); |
4014 | return true; |
4015 | } |
4016 | |
4017 | return false; |
4018 | } |
4019 | |
4020 | bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g, |
4021 | Node* n) { |
4022 | bool result = false; |
4023 | |
4024 | // If graph node is not Mkl node, then return. |
4025 | DataType T = DT_INVALID; |
4026 | if (!TryGetNodeAttr(n->def(), "T" , &T) || |
4027 | !mkl_op_registry::IsMklOp(n->type_string(), T, false)) { |
4028 | return result; |
4029 | } |
4030 | |
4031 | // If it is Mkl node, then check if the input edges to this node that carry |
4032 | // Mkl metadata are linked up correctly with the source node. |
4033 | |
4034 | // For Mkl nodes, we generate twice the number of input tensors (n for Mkl |
4035 | // data tensors + n for Mkl metadata tensors). We need to check for correct |
4036 | // connection of n metadata tensors only. |
4037 | int num_data_inputs = n->num_inputs() / 2; |
4038 | for (int idx = 0; idx < num_data_inputs; idx++) { |
4039 | // Get the edge connecting input slot with index (idx). |
4040 | const Edge* e = nullptr; |
4041 | TF_CHECK_OK(n->input_edge(idx, &e)); |
4042 | |
4043 | // If e is control edge, then skip. |
4044 | if (e->IsControlEdge()) { |
4045 | continue; |
4046 | } |
4047 | |
4048 | // Check that the source node for edge 'e' is Mkl node. If it is not an Mkl |
4049 | // node, then we don't need to do anything. |
4050 | Node* e_src = e->src(); |
4051 | if (TryGetNodeAttr(e_src->def(), "T" , &T) && |
4052 | mkl_op_registry::IsMklOp(e_src->type_string(), T, false)) { |
4053 | // Source node for edge 'e' is Mkl node. |
4054 | // Destination node and destination input slot of e is node 'n' and 'idx' |
4055 | // resp. |
4056 | CHECK_EQ(e->dst(), n); |
4057 | CHECK_EQ(e->dst_input(), idx); |
4058 | |
4059 | // Let's get edge that carries Mkl metadata corresponding to Mkl data edge |
4060 | // 'e'. For that, let's first get the input slot of 'n' where the meta |
4061 | // edge will feed the value. |
4062 | int e_meta_in_slot = |
4063 | GetTensorMetaDataIndex(e->dst_input(), n->num_inputs()); |
4064 | const Edge* e_meta = nullptr; |
4065 | TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta)); |
4066 | |
4067 | // Let's check if we need to fix this meta edge. |
4068 | if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) { |
4069 | result = true; |
4070 | } |
4071 | } |
4072 | } |
4073 | |
4074 | return result; |
4075 | } |
4076 | |
4077 | /////////////////////////////////////////////////////////////////////////////// |
4078 | // Run function for the pass |
4079 | /////////////////////////////////////////////////////////////////////////////// |
4080 | |
4081 | bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) { |
4082 | bool result = false; |
4083 | DCHECK(g); |
4084 | |
4085 | DumpGraph("Before running MklLayoutRewritePass" , &**g); |
4086 | |
4087 | std::vector<Node*> order; |
4088 | GetReversePostOrder(**g, &order); // This will give us topological sort. |
4089 | for (Node* n : order) { |
4090 | // If node is not an op or it cannot run on CPU device, then skip. |
4091 | if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
4092 | continue; |
4093 | } |
4094 | |
4095 | Node* m = nullptr; |
4096 | if ((m = CheckForNodeMerge(n)) != nullptr && CanOpRunOnCPUDevice(m)) { |
4097 | // Check if the node 'n' can be merged with any other node. If it can |
4098 | // be 'm' contains the node with which it can be merged. |
4099 | string n1_name = n->name(); |
4100 | string n2_name = m->name(); |
4101 | |
4102 | VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and " |
4103 | << n2_name << " for merging" ; |
4104 | |
4105 | if (MergeNode(g, n, m) == OkStatus()) { |
4106 | VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and " |
4107 | << n2_name; |
4108 | result = true; |
4109 | } |
4110 | } |
4111 | } |
4112 | |
4113 | DumpGraph("After running MklLayoutRewritePass(NodeMerge)" , &**g); |
4114 | |
4115 | order.clear(); |
4116 | GetReversePostOrder(**g, &order); // This will give us topological sort. |
4117 | for (Node* n : order) { |
4118 | // If node is not an op or it cannot run on CPU device, then skip. |
4119 | if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
4120 | continue; |
4121 | } |
4122 | |
4123 | auto check_result = CheckForNodeFusion(n); |
4124 | bool found_pattern = std::get<0>(check_result); |
4125 | std::vector<Node*> nodes = std::get<1>(check_result); |
4126 | const FusionInfo fi = std::get<2>(check_result); |
4127 | |
4128 | // if "found_pattern" is true, we can do the fusion. |
4129 | if (found_pattern) { |
4130 | if (FuseNode(g, nodes, fi) == OkStatus()) { |
4131 | result = true; |
4132 | } |
4133 | } |
4134 | } |
4135 | DumpGraph("After running MklLayoutRewritePass(NodeFusion)" , &**g); |
4136 | |
4137 | order.clear(); |
4138 | GetReversePostOrder(**g, &order); // This will give us topological sort. |
4139 | for (Node* n : order) { |
4140 | // If node is not an op or it cannot run on CPU device, then skip. |
4141 | if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
4142 | continue; |
4143 | } |
4144 | |
4145 | const RewriteInfo* ri = nullptr; |
4146 | // We will first search if node is to be rewritten. |
4147 | if ((ri = CheckForNodeRewrite(n)) != nullptr) { |
4148 | string node_name = n->name(); |
4149 | string op_name = n->type_string(); |
4150 | |
4151 | VLOG(1) << "MklLayoutRewritePass: Scheduled node " << node_name |
4152 | << " with op " << op_name << " for rewrite using" |
4153 | << " layout optimization." ; |
4154 | |
4155 | if (RewriteNode(g, n, ri) == OkStatus()) { |
4156 | VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name |
4157 | << " with op " << op_name << " for Mkl layout optimization." ; |
4158 | result = true; |
4159 | } |
4160 | } |
4161 | } |
4162 | |
4163 | DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)" , &**g); |
4164 | |
4165 | if (!NativeFormatEnabled()) { |
4166 | order.clear(); |
4167 | GetReversePostOrder(**g, &order); // This will give us topological sort. |
4168 | for (Node* n : order) { |
4169 | // If node is not an op or it cannot run on CPU device, then skip. |
4170 | if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) { |
4171 | continue; |
4172 | } |
4173 | if (FixMklMetaDataEdges(g, n)) { |
4174 | string node_name = n->name(); |
4175 | string op_name = n->type_string(); |
4176 | |
4177 | VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node " |
4178 | << node_name << " with op " << op_name; |
4179 | result = true; |
4180 | } |
4181 | } |
4182 | DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)" , |
4183 | &**g); |
4184 | } |
4185 | |
4186 | return result; |
4187 | } |
4188 | |
4189 | bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g) { |
4190 | return MklLayoutRewritePass().RunPass(g); |
4191 | } |
4192 | |
4193 | Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { |
4194 | if (options.graph == nullptr && options.partition_graphs == nullptr) { |
4195 | return OkStatus(); |
4196 | } |
4197 | if (!IsMKLEnabled()) { |
4198 | VLOG(2) << "TF-MKL: MKL is not enabled" ; |
4199 | return OkStatus(); |
4200 | } |
4201 | |
4202 | auto process_graph = [&](std::unique_ptr<Graph>* g) { |
4203 | // Get the ownership of a graph |
4204 | std::unique_ptr<Graph>* ng = std::move(g); |
4205 | RunPass(ng); |
4206 | // Return the ownership of a graph back |
4207 | g->reset(ng->release()); |
4208 | }; |
4209 | |
4210 | if (kMklLayoutRewritePassGroup != |
4211 | OptimizationPassRegistry::POST_PARTITIONING) { |
4212 | // For any pre-partitioning phase, a graph is stored in options.graph. |
4213 | process_graph(options.graph); |
4214 | } else { |
4215 | // For post partitioning phase, graphs are stored in |
4216 | // options.partition_graphs. |
4217 | for (auto& pg : *options.partition_graphs) { |
4218 | process_graph(&pg.second); |
4219 | } |
4220 | } |
4221 | |
4222 | return OkStatus(); |
4223 | } |
4224 | |
4225 | } // namespace tensorflow |
4226 | |
4227 | #endif // INTEL_MKL |
4228 | |