/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

// For now, this file only includes MKL quantized ops. In the
// future, we will move all other MKL ops from nn_ops.cc to this file.

#ifdef INTEL_MKL

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("_MklNativeConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int) >= 5")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
MKL version of Conv3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D convolution.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeConv3DBackpropInputV2")
    .Input("input_sizes: Tshape")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int) >= 5")
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
MKL version of Convolution3D backward input op that does not depend on layout
propagation. Uses oneDNN APIs to compute the gradients of convolution with
respect to the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeConv3DBackpropFilterV2")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
MKL version of Conv3DBackpropFilter op that does not depend on layout
propagation. Uses oneDNN APIs to compute the gradients of convolution
with respect to the filter.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeFusedConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int) >= 5")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .Attr("fused_ops: list(string) = []")
    .Attr("epsilon: float = 0.0001")
    .Attr("leakyrelu_alpha: float = 0.2")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
MKL version of Conv3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D convolution.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_FusedConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .Attr("fused_ops: list(string) = []")
    .Attr("epsilon: float = 0.0001")
    .Attr("leakyrelu_alpha: float = 0.2")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);

REGISTER_OP("_MklNativeDepthwiseConv2dNativeBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklNativeDepthwiseConv2dNativeBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativeFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativeConv2DWithBias")
    .Input("input: T")
    .Input("filter: T")
    .Input("bias: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
MKL version of Conv2D and BiasAdd operator. Uses oneDNN APIs to perform
2D convolution and add Bias to the output of convolution.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");

REGISTER_OP("_MklNativeConv2DBackpropFilterWithBias")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Output("bias_grad: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape)
    .Doc(R"doc(
oneDNN version of Conv2DBackpropFilterWithBias. Uses oneDNN APIs to compute the
fusion of Conv2DBackpropFilter and BiasAddGrad.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this one.
)doc");

REGISTER_OP("_MklFusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

REGISTER_OP("_MklNativeFusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

REGISTER_OP("_MklFusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Input("mkl_a: uint8")
    .Input("mkl_b: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("product: T")
    .Output("mkl_product: uint8")
    .Attr("is_filter_const: bool = false")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
MKL version of FusedMatMul operator. Uses MKL-DNN APIs to implement the
MatMul operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeFusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Output("product: T")
    .Attr("is_filter_const: bool = false")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
oneDNN version of FusedMatMul operator that does not depend
on layout propagation. Uses oneDNN APIs to implement MatMul fusion.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this one.
)doc");

REGISTER_OP("__MklDummyPadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklPadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Input("mkl_paddings: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativePadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativePadWithConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_filter_const: bool = false")
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
MKL version of Pad and Conv2D fusion that does not depend
on layout propagation. Uses oneDNN APIs to perform
the fusion.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPoolShape)
    .Doc(R"doc(
oneDNN version of AvgPool operator that does not depend on layout
propagation. Uses oneDNN APIs to perform average pooling on the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPoolGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPoolGradShape)
    .Doc(R"doc(
oneDNN version of AvgPoolGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of AvgPool operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool3D")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::Pool3DShape)
    .Doc(R"doc(
oneDNN version of AvgPool3D operator that does not depend on layout
propagation. Uses oneDNN APIs to perform 3D average pooling on the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool3DGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPool3DGradShape)
    .Doc(R"doc(
oneDNN version of AvgPool3DGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of AvgPool3D function.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool")
    .Attr("T: {float, half, bfloat16} = DT_FLOAT")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("workspace_enabled: bool = false")
    .Input("input: T")
    .Output("output: T")
    .Output("workspace: uint8")
    .SetShapeFn(shape_inference::MaxPoolShape)
    .Doc(R"doc(
oneDNN version of MaxPool operator that does not depend
on layout propagation. Uses oneDNN APIs to perform max pooling
on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPoolGrad")
    .Attr("T: {float, half, bfloat16} = DT_FLOAT")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("workspace_enabled: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("workspace: uint8")
    .Output("output: T")
    .SetShapeFn(shape_inference::MaxPoolGradShape)
    .Doc(R"doc(
oneDNN version of MaxPoolGrad that does not depend on layout propagation.
Uses oneDNN APIs to compute gradients of MaxPool operator.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool3D")
    .Input("input: T")
    .Output("output: T")
    .Output("workspace: uint8")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float}")
    .Attr("workspace_enabled: bool = false")
    .SetShapeFn(shape_inference::Pool3DShape)
    .Doc(R"doc(
oneDNN version of MaxPool3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D max pooling on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool3DGrad")
    .Input("orig_input: TInput")
    .Input("orig_output: TInput")
    .Input("grad: T")
    .Input("workspace: uint8")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
    .Attr("workspace_enabled: bool = false")
    .SetShapeFn(shape_inference::MaxPool3DGradShape)
    .Doc(R"doc(
oneDNN version of MaxPool3DGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of MaxPool3D function.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklQuantizedMaxPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::MaxPoolShape)
    .Doc(R"doc(
MKL version of QuantizedMaxPool operator. Uses MKL DNN APIs to perform max pooling
on the quantized input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklQuantizedAvgPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::QuantizedAvgPoolShape)
    .Doc(R"doc(
MKL version of QuantizedAvgPool operator. Uses MKL DNN APIs to perform average pooling
on the quantized input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_FusedQuantizedConv2D")
    .Input("device_inputs: Tdevice_inputs")
    .Input("host_inputs: Thost_inputs")
    .Output("device_outputs: Tdevice_outputs")
    .Output("host_outputs: Thost_outputs")
    .Attr("Tinput: quantizedtype = DT_QUINT8")
    .Attr("Tfilter: quantizedtype = DT_QINT8")
    .Attr("Tbias: {float, qint32} = DT_QINT32")
    .Attr("Tsummand: {float, quint8, qint8, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("Tdevice_inputs: list(type) >= 0 = []")
    .Attr("Thost_inputs: list(type) >= 0")
    .Attr("Tdevice_outputs: list(type) >= 0 = []")
    .Attr("Thost_outputs: list(type) >= 0")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("alpha: float = 0.0")
    .SetShapeFn(shape_inference::FusedQuantizedConv2DShape);

REGISTER_OP("_FusedQuantizedDepthwiseConv2D")
    .Input("device_inputs: Tdevice_inputs")
    .Input("host_inputs: Thost_inputs")
    .Output("device_outputs: Tdevice_outputs")
    .Output("host_outputs: Thost_outputs")
    .Attr("Tinput: quantizedtype = DT_QUINT8")
    .Attr("Tfilter: quantizedtype = DT_QINT8")
    .Attr("Tbias: {float, qint32} = DT_QINT32")
    .Attr("Tsummand: {float, quint8, qint8, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("Tdevice_inputs: list(type) >= 0 = []")
    .Attr("Thost_inputs: list(type) >= 0")
    .Attr("Tdevice_outputs: list(type) >= 0 = []")
    .Attr("Thost_outputs: list(type) >= 0")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("alpha: float = 0.0")
    .SetShapeFn(shape_inference::FusedQuantizedDepthwiseConv2D);

REGISTER_OP("_MklQuantizedConv2D")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn(shape_inference::QuantizedConv2DShape);

// TODO(nammbash): Most of the TF_RETURN_IF_ERROR(c->WithRank) checks
// seem to be similar and hence can be moved into a single function
// with appropriate arguments for a cleaner design; one possible helper is
// sketched below.
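// A minimal sketch of such a helper, assuming the pattern shared by the
// _MklQuantizedConv2D* ops below: run Conv2DShape, constrain the min/max
// inputs starting at `min_input_index`, then emit the min/max outputs. The
// name, signature, and placement are illustrative only; this helper does not
// exist in TensorFlow today, and the rank-1 bias check of the *WithBias*
// variants would stay with each caller.
namespace {
Status QuantizedConv2DMinMaxShape(InferenceContext* c, int min_input_index,
                                  bool has_freezed_output_range,
                                  bool per_channel_output) {
  TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
  ShapeHandle unused;
  ShapeHandle channel = c->Scalar();
  // min_input and max_input are always scalars.
  TF_RETURN_IF_ERROR(c->WithRank(c->input(min_input_index), 0, &unused));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(min_input_index + 1), 0, &unused));
  // min_filter and max_filter may be scalars or per-channel vectors.
  TF_RETURN_IF_ERROR(
      c->WithRankAtMost(c->input(min_input_index + 2), 1, &channel));
  TF_RETURN_IF_ERROR(
      c->WithRankAtMost(c->input(min_input_index + 3), 1, &channel));
  if (has_freezed_output_range) {
    // min_freezed_output and max_freezed_output are scalars.
    TF_RETURN_IF_ERROR(
        c->WithRank(c->input(min_input_index + 4), 0, &unused));
    TF_RETURN_IF_ERROR(
        c->WithRank(c->input(min_input_index + 5), 0, &unused));
  }
  c->set_output(1, per_channel_output ? channel : c->Scalar());
  c->set_output(2, per_channel_output ? channel : c->Scalar());
  return Status::OK();
}
}  // namespace
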
REGISTER_OP("_MklQuantizedConv2DAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBias")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Input("summand: Tsummand")
    .Input("min_summand: float")
    .Input("max_summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Tsummand: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Input("summand: Tsummand")
    .Input("min_summand: float")
    .Input("max_summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Tsummand: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DPerChannel")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    })
    .Doc(R"doc(
MKL-DNN implementation of QuantizedConv2D op.
)doc");

REGISTER_OP("_MklDepthwiseConv2dNativeBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_out_backprop: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklEinsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::EinsumShape);

REGISTER_OP("_MklDepthwiseConv2dNativeBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_out_backprop: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBias")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndRelu")
    .Input("a: T1")
    .Input("b: T2")
    // TODO(intel-tf): Modify bias type as Tbias and add relevant attribute.
    .Input("bias: float")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndReluAndRequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QUINT8")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndDequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: {float}")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));

      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndRequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QUINT8")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
1395 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
1396 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
1397 | |
1398 | c->set_output(1, c->Scalar()); |
1399 | c->set_output(2, c->Scalar()); |
1400 | return Status::OK(); |
1401 | }); |
1402 | |
1403 | REGISTER_OP("_MklQuantizedDepthwiseConv2D" ) |
1404 | .Input("input: Tinput" ) |
1405 | .Input("filter: Tfilter" ) |
1406 | .Input("min_input: float" ) |
1407 | .Input("max_input: float" ) |
1408 | .Input("min_filter: float" ) |
1409 | .Input("max_filter: float" ) |
1410 | .Output("output: out_type" ) |
1411 | .Output("min_output: float" ) |
1412 | .Output("max_output: float" ) |
1413 | .Attr("Tinput: quantizedtype" ) |
1414 | .Attr("Tfilter: quantizedtype" ) |
1415 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
1416 | .Attr("data_format: string = 'NHWC'" ) |
1417 | .Attr("strides: list(int)" ) |
1418 | .Attr("is_filter_const: bool = true" ) |
1419 | .Attr(GetPaddingAttrString()) |
1420 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1421 | .SetShapeFn([](InferenceContext* c) { |
1422 | // TODO(bhavanis): Print an error message during the return. |
1423 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
1424 | ShapeHandle unused, channel; |
1425 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1426 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1427 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused)); |
1428 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
1429 | c->set_output(1, channel); |
1430 | c->set_output(2, channel); |
1431 | return Status::OK(); |
1432 | }) |
1433 | .Doc(R"doc( |
1434 | MKL-DNN implementation of quantized depthwise Conv2D. |
1435 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1436 | expected to invoke this operator. |
1437 | )doc" ); |
1438 | |
1439 | REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBias" ) |
1440 | .Input("input: Tinput" ) |
1441 | .Input("filter: Tfilter" ) |
1442 | .Input("bias: float" ) |
1443 | .Input("min_input: float" ) |
1444 | .Input("max_input: float" ) |
1445 | .Input("min_filter: float" ) |
1446 | .Input("max_filter: float" ) |
1447 | .Output("output: out_type" ) |
1448 | .Output("min_output: float" ) |
1449 | .Output("max_output: float" ) |
1450 | .Attr("Tinput: quantizedtype" ) |
1451 | .Attr("Tfilter: quantizedtype" ) |
1452 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
1453 | .Attr("data_format: string = 'NHWC'" ) |
1454 | .Attr("strides: list(int)" ) |
1455 | .Attr("is_filter_const: bool = true" ) |
1456 | .Attr("is_bias_const: bool = true" ) |
1457 | .Attr(GetPaddingAttrString()) |
1458 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1459 | .SetShapeFn([](InferenceContext* c) { |
1460 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
1461 | ShapeHandle unused, channel; |
1462 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
1463 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1464 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1465 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
1466 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
1467 | c->set_output(1, channel); |
1468 | c->set_output(2, channel); |
1469 | return Status::OK(); |
1470 | }) |
1471 | .Doc(R"doc( |
1472 | MKL-DNN implementation of quantized depthwise Conv2D with Bias. |
1473 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1474 | expected to invoke this operator. |
1475 | )doc" ); |
1476 | |
1477 | REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndRelu" ) |
1478 | .Input("input: Tinput" ) |
1479 | .Input("filter: Tfilter" ) |
1480 | .Input("bias: float" ) |
1481 | .Input("min_input: float" ) |
1482 | .Input("max_input: float" ) |
1483 | .Input("min_filter: float" ) |
1484 | .Input("max_filter: float" ) |
1485 | .Output("output: out_type" ) |
1486 | .Output("min_output: float" ) |
1487 | .Output("max_output: float" ) |
1488 | .Attr("Tinput: quantizedtype" ) |
1489 | .Attr("Tfilter: quantizedtype" ) |
1490 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
1491 | .Attr("data_format: string = 'NHWC'" ) |
1492 | .Attr("strides: list(int)" ) |
1493 | .Attr("is_filter_const: bool = true" ) |
1494 | .Attr("is_bias_const: bool = true" ) |
1495 | .Attr(GetPaddingAttrString()) |
1496 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1497 | .Attr("padding_list: list(int) = []" ) |
1498 | .SetShapeFn([](InferenceContext* c) { |
1499 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
1500 | ShapeHandle unused, channel; |
1501 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
1502 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1503 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1504 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
1505 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
1506 | c->set_output(1, channel); |
1507 | c->set_output(2, channel); |
1508 | return Status::OK(); |
1509 | }) |
1510 | .Doc(R"doc( |
1511 | MKL-DNN implementation of quantized depthwise Conv2D with Bias and Relu. |
1512 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1513 | expected to invoke this operator. |
1514 | )doc" ); |
1515 | |
1516 | REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" ) |
1517 | .Input("input: Tinput" ) |
1518 | .Input("filter: Tfilter" ) |
1519 | .Input("bias: Tbias" ) |
1520 | .Input("min_input: float" ) |
1521 | .Input("max_input: float" ) |
1522 | .Input("min_filter: float" ) |
1523 | .Input("max_filter: float" ) |
1524 | .Input("min_freezed_output: float" ) |
1525 | .Input("max_freezed_output: float" ) |
1526 | .Output("output: out_type" ) |
1527 | .Output("min_output: float" ) |
1528 | .Output("max_output: float" ) |
1529 | .Attr("Tinput: quantizedtype" ) |
1530 | .Attr("Tfilter: quantizedtype" ) |
1531 | .Attr("Tbias: {float, qint32}" ) |
1532 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
1533 | .Attr("data_format: string = 'NHWC'" ) |
1534 | .Attr("strides: list(int)" ) |
1535 | .Attr("is_filter_const: bool = true" ) |
1536 | .Attr("is_bias_const: bool = true" ) |
1537 | .Attr(GetPaddingAttrString()) |
1538 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1539 | .Attr("padding_list: list(int) = []" ) |
1540 | .SetShapeFn([](InferenceContext* c) { |
1541 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
1542 | ShapeHandle unused; |
1543 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
1544 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1545 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1546 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused)); |
1547 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused)); |
1548 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
1549 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
1550 | c->set_output(1, c->Scalar()); |
1551 | c->set_output(2, c->Scalar()); |
1552 | return Status::OK(); |
1553 | }) |
1554 | .Doc(R"doc( |
1555 | MKL-DNN implementation of quantized depthwise Conv2D with Bias, Relu and Requantize. |
1556 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1557 | expected to invoke this operator. |
1558 | )doc" ); |
1559 | |
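// FusedBatchNorm computes y = scale * (x - mean) / sqrt(variance + epsilon) + offset,
// using the batch statistics for mean/variance when is_training is true and the
// supplied moving statistics otherwise.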
1560 | REGISTER_OP("_MklFusedBatchNormV3" ) |
1561 | .Input("x: T" ) |
1562 | .Input("scale: U" ) |
1563 | .Input("offset: U" ) |
1564 | .Input("mean: U" ) |
1565 | .Input("variance: U" ) |
1566 | .Input("mkl_x: uint8" ) |
1567 | .Input("mkl_scale: uint8" ) |
1568 | .Input("mkl_offset: uint8" ) |
1569 | .Input("mkl_mean: uint8" ) |
1570 | .Input("mkl_variance: uint8" ) |
1571 | .Output("y: T" ) |
1572 | .Output("batch_mean: U" ) |
1573 | .Output("batch_variance: U" ) |
1574 | .Output("reserve_space_1: U" ) |
1575 | .Output("reserve_space_2: U" ) |
1576 | .Output("reserve_space_3: U" ) |
1577 | .Output("mkl_y: uint8" ) |
1578 | .Output("mkl_batch_mean: uint8" ) |
1579 | .Output("mkl_batch_variance: uint8" ) |
1580 | .Output("mkl_reserve_space_1: uint8" ) |
1581 | .Output("mkl_reserve_space_2: uint8" ) |
1582 | .Output("mkl_reserve_space_3: uint8" ) |
1583 | .Attr("T: {half, bfloat16, float}" ) |
1584 | .Attr("U: {float}" ) |
1585 | .Attr("epsilon: float = 0.0001" ) |
1586 | .Attr(GetConvnetDataFormatAttrString()) |
1587 | .Attr("exponential_avg_factor: float = 1.0" ) |
1588 | .Attr("is_training: bool = true" ) |
1589 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
1590 | .Doc( |
1591 | R"doc(MKL-DNN implementation of FusedBatchNormV3: Do not invoke this operator directly in Python. |
1592 | Graph rewrite pass is expected to invoke this operator.)doc" ); |
1593 | |
1594 | REGISTER_OP("_MklFusedBatchNormGradV3" ) |
1595 | .Input("y_backprop: T" ) |
1596 | .Input("x: T" ) |
1597 | .Input("scale: float" ) |
1598 | .Input("reserve_space_1: U" ) |
1599 | .Input("reserve_space_2: U" ) |
1600 | .Input("reserve_space_3: U" ) |
1601 | .Input("mkl_y_backprop: uint8" ) |
1602 | .Input("mkl_x: uint8" ) |
1603 | .Input("mkl_scale: uint8" ) |
1604 | .Input("mkl_reserve_space_1: uint8" ) |
1605 | .Input("mkl_reserve_space_2: uint8" ) |
1606 | .Input("mkl_reserve_space_3: uint8" ) |
1607 | .Output("x_backprop: T" ) |
1608 | .Output("scale_backprop: U" ) |
1609 | .Output("offset_backprop: U" ) |
1610 | .Output("reserve_space_4: U" ) |
1611 | .Output("reserve_space_5: U" ) |
1612 | .Output("mkl_x_backprop: uint8" ) |
1613 | .Output("mkl_scale_backprop: uint8" ) |
1614 | .Output("mkl_offset_backprop: uint8" ) |
1615 | .Output("mkl_reserve_space_4: uint8" ) |
1616 | .Output("mkl_reserve_space_5: uint8" ) |
1617 | .Attr("T: {half, bfloat16, float}" ) |
1618 | .Attr("U: {float}" ) |
1619 | .Attr("epsilon: float = 0.0001" ) |
1620 | .Attr(GetConvnetDataFormatAttrString()) |
1621 | .Attr("is_training: bool = true" ) |
1622 | .SetShapeFn(shape_inference::FusedBatchNormGradShape) |
1623 | .Doc( |
1624 | R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python. |
1625 | Graph rewrite pass is expected to invoke this operator.)doc" ); |
1626 | |
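// Here and in _MklNativeFusedBatchNormEx below, "side_input: num_side_inputs * T"
// declares an input list whose length equals the integer attr num_side_inputs:
// with the default of 0 the op takes no side input, and a rewrite that fuses an
// elementwise Add into the batch norm is expected to set it to 1.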
1627 | REGISTER_OP("_MklFusedBatchNormEx" ) |
1628 | .Input("x: T" ) |
1629 | .Input("scale: U" ) |
1630 | .Input("offset: U" ) |
1631 | .Input("mean: U" ) |
1632 | .Input("variance: U" ) |
1633 | .Input("side_input: num_side_inputs * T" ) |
1634 | .Input("mkl_x: uint8" ) |
1635 | .Input("mkl_scale: uint8" ) |
1636 | .Input("mkl_offset: uint8" ) |
1637 | .Input("mkl_mean: uint8" ) |
1638 | .Input("mkl_variance: uint8" ) |
1639 | .Input("mkl_side_input: num_side_inputs * uint8" ) |
1640 | .Output("y: T" ) |
1641 | .Output("batch_mean: U" ) |
1642 | .Output("batch_variance: U" ) |
1643 | .Output("reserve_space_1: U" ) |
1644 | .Output("reserve_space_2: U" ) |
1645 | .Output("reserve_space_3: U" ) |
1646 | .Output("mkl_y: uint8" ) |
1647 | .Output("mkl_batch_mean: uint8" ) |
1648 | .Output("mkl_batch_variance: uint8" ) |
1649 | .Output("mkl_reserve_space_1: uint8" ) |
1650 | .Output("mkl_reserve_space_2: uint8" ) |
1651 | .Output("mkl_reserve_space_3: uint8" ) |
1652 | .Attr("T: {bfloat16, float}" ) |
1653 | .Attr("U: {float}" ) |
1654 | .Attr("epsilon: float = 0.0001" ) |
1655 | .Attr("exponential_avg_factor: float = 1.0" ) |
1656 | .Attr(GetConvnetDataFormatAttrString()) |
1657 | .Attr("num_side_inputs: int >= 0 = 0" ) |
1658 | .Attr("activation_mode: string = \"Identity\"" ) |
1659 | .Attr("is_training: bool = true" ) |
1660 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
1661 | .Doc(R"doc( |
1662 | MKL version of FusedBatchNormEx operator. Uses MKL DNN APIs to perform fused |
1663 | batch normalization and relu. |
1664 | |
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
1666 | expected to invoke these operators. |
1667 | )doc" ); |
1668 | |
1669 | REGISTER_OP("_MklNativeFusedBatchNorm" ) |
1670 | .Input("x: T" ) |
1671 | .Input("scale: T" ) |
1672 | .Input("offset: T" ) |
1673 | .Input("mean: T" ) |
1674 | .Input("variance: T" ) |
1675 | .Output("y: T" ) |
1676 | .Output("batch_mean: T" ) |
1677 | .Output("batch_variance: T" ) |
1678 | .Output("reserve_space_1: T" ) |
1679 | .Output("reserve_space_2: T" ) |
1680 | .Attr("T: numbertype" ) |
1681 | .Attr("epsilon: float = 0.0001" ) |
1682 | .Attr("data_format: string = 'NHWC'" ) |
1683 | .Attr("exponential_avg_factor: float = 1.0" ) |
1684 | .Attr("is_training: bool = true" ) |
1685 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
1686 | .Doc(R"doc( |
1687 | oneDNN version of FusedBatchNorm operator that does not depend on layout |
1688 | propagation. Uses oneDNN APIs to perform fused batch normalization. |
1689 | |
1690 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1691 | expected to invoke these operators. |
1692 | )doc" ); |
1693 | |
1694 | REGISTER_OP("_MklNativeFusedBatchNormGrad" ) |
1695 | .Input("y_backprop: T" ) |
1696 | .Input("x: T" ) |
1697 | .Input("scale: T" ) |
1698 | .Input("reserve_space_1: T" ) |
1699 | .Input("reserve_space_2: T" ) |
1700 | .Output("x_backprop: T" ) |
1701 | .Output("scale_backprop: T" ) |
1702 | .Output("offset_backprop: T" ) |
1703 | .Output("reserve_space_3: T" ) |
1704 | .Output("reserve_space_4: T" ) |
1705 | .Attr("T: numbertype" ) |
1706 | .Attr("epsilon: float = 0.0001" ) |
1707 | .Attr("data_format: string = 'NHWC'" ) |
1708 | .Attr("is_training: bool = true" ) |
1709 | .SetShapeFn(shape_inference::FusedBatchNormGradShape) |
1710 | .Doc(R"doc( |
1711 | oneDNN version of FusedBatchNormGrad operator that does not depend |
1712 | on layout propagation. Uses oneDNN APIs to compute gradients for fused |
1713 | batch normalization. |
1714 | |
1715 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1716 | expected to invoke these operators. |
1717 | )doc" ); |
1718 | |
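// The two registrations below are the layout-propagation-free ("native")
// counterparts of FusedBatchNormV2 and FusedBatchNormGradV2; like the other
// _MklNative* ops in this file, they are created only by the graph rewrite pass.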
1719 | REGISTER_OP("_MklNativeFusedBatchNormV2" ) |
1720 | .Input("x: T" ) |
1721 | .Input("scale: U" ) |
1722 | .Input("offset: U" ) |
1723 | .Input("mean: U" ) |
1724 | .Input("variance: U" ) |
1725 | .Output("y: T" ) |
1726 | .Output("batch_mean: U" ) |
1727 | .Output("batch_variance: U" ) |
1728 | .Output("reserve_space_1: U" ) |
1729 | .Output("reserve_space_2: U" ) |
1730 | .Attr("T: {bfloat16, float}" ) |
1731 | .Attr("U: {float}" ) |
1732 | .Attr("epsilon: float = 0.0001" ) |
1733 | .Attr(GetConvnetDataFormatAttrString()) |
1734 | .Attr("exponential_avg_factor: float = 1.0" ) |
1735 | .Attr("is_training: bool = true" ) |
1736 | .SetShapeFn(shape_inference::FusedBatchNormShape); |
1737 | |
1738 | REGISTER_OP("_MklNativeFusedBatchNormGradV2" ) |
1739 | .Input("y_backprop: T" ) |
1740 | .Input("x: T" ) |
1741 | .Input("scale: float" ) |
1742 | .Input("reserve_space_1: U" ) |
1743 | .Input("reserve_space_2: U" ) |
1744 | .Output("x_backprop: T" ) |
1745 | .Output("scale_backprop: U" ) |
1746 | .Output("offset_backprop: U" ) |
1747 | .Output("reserve_space_3: U" ) |
1748 | .Output("reserve_space_4: U" ) |
1749 | .Attr("T: {bfloat16, float}" ) |
1750 | .Attr("U: {float}" ) |
1751 | .Attr("epsilon: float = 0.0001" ) |
1752 | .Attr(GetConvnetDataFormatAttrString()) |
1753 | .Attr("is_training: bool = true" ) |
1754 | .SetShapeFn(shape_inference::FusedBatchNormGradShape); |
1755 | |
1756 | REGISTER_OP("_MklNativeFusedBatchNormV3" ) |
1757 | .Input("x: T" ) |
1758 | .Input("scale: U" ) |
1759 | .Input("offset: U" ) |
1760 | .Input("mean: U" ) |
1761 | .Input("variance: U" ) |
1762 | .Output("y: T" ) |
1763 | .Output("batch_mean: U" ) |
1764 | .Output("batch_variance: U" ) |
1765 | .Output("reserve_space_1: U" ) |
1766 | .Output("reserve_space_2: U" ) |
1767 | .Output("reserve_space_3: U" ) |
1768 | .Attr("T: {half, bfloat16, float}" ) |
1769 | .Attr("U: {float}" ) |
1770 | .Attr("epsilon: float = 0.0001" ) |
1771 | .Attr(GetConvnetDataFormatAttrString()) |
1772 | .Attr("exponential_avg_factor: float = 1.0" ) |
1773 | .Attr("is_training: bool = true" ) |
1774 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
1775 | .Doc( |
1776 | R"doc(oneDNN version of FusedBatchNormV3 operator that does not depend |
1777 | on layout propagation. Do not invoke this operator directly in Python. |
1778 | Graph rewrite pass is expected to invoke this operator.)doc" ); |
1779 | |
1780 | REGISTER_OP("_MklNativeFusedBatchNormGradV3" ) |
1781 | .Input("y_backprop: T" ) |
1782 | .Input("x: T" ) |
1783 | .Input("scale: float" ) |
1784 | .Input("reserve_space_1: U" ) |
1785 | .Input("reserve_space_2: U" ) |
1786 | .Input("reserve_space_3: U" ) |
1787 | .Output("x_backprop: T" ) |
1788 | .Output("scale_backprop: U" ) |
1789 | .Output("offset_backprop: U" ) |
1790 | .Output("reserve_space_4: U" ) |
1791 | .Output("reserve_space_5: U" ) |
1792 | .Attr("T: {half, bfloat16, float}" ) |
1793 | .Attr("U: {float}" ) |
1794 | .Attr("epsilon: float = 0.0001" ) |
1795 | .Attr(GetConvnetDataFormatAttrString()) |
1796 | .Attr("is_training: bool = true" ) |
1797 | .SetShapeFn(shape_inference::FusedBatchNormGradShape) |
1798 | .Doc( |
1799 | R"doc(oneDNN version of FusedBatchNormGradV3 that does not depend |
1800 | on layout propagation. Do not invoke this operator directly in Python. |
1801 | Graph rewrite pass is expected to invoke this operator.)doc" ); |
1802 | |
1803 | REGISTER_OP("_MklNativeFusedBatchNormEx" ) |
1804 | .Input("x: T" ) |
1805 | .Input("scale: U" ) |
1806 | .Input("offset: U" ) |
1807 | .Input("mean: U" ) |
1808 | .Input("variance: U" ) |
1809 | .Input("side_input: num_side_inputs * T" ) |
1810 | .Output("y: T" ) |
1811 | .Output("batch_mean: U" ) |
1812 | .Output("batch_variance: U" ) |
1813 | .Output("reserve_space_1: U" ) |
1814 | .Output("reserve_space_2: U" ) |
1815 | .Output("reserve_space_3: U" ) |
1816 | .Attr("T: {bfloat16, float}" ) |
1817 | .Attr("U: {float}" ) |
1818 | .Attr("epsilon: float = 0.0001" ) |
1819 | .Attr("exponential_avg_factor: float = 1.0" ) |
1820 | .Attr(GetConvnetDataFormatAttrString()) |
1821 | .Attr("num_side_inputs: int >= 0 = 0" ) |
1822 | .Attr("activation_mode: string = \"Identity\"" ) |
1823 | .Attr("is_training: bool = true" ) |
1824 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
1825 | .Doc(R"doc( |
1826 | oneDNN version of FusedBatchNormEx operator that does not depend on layout propagation. |
1827 | Uses oneDNN APIs to perform fused batch normalization and relu. |
1828 | |
1829 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
1830 | expected to invoke these operators. |
1831 | )doc" ); |
1832 | |
1833 | REGISTER_OP("_MklFusedMish" ) |
1834 | .Input("features: T" ) |
1835 | .Output("activations: T" ) |
1836 | .Attr("T: {bfloat16, float}" ) |
1837 | .SetShapeFn(shape_inference::UnchangedShape) |
1838 | .Doc(R"doc( |
oneDNN version of the Mish activation, Mish(x) = x * tanh(softplus(x)).
Uses oneDNN APIs to implement the operator.
1840 | |
1841 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is expected |
1842 | to invoke these operators. |
1843 | )doc" ); |
1844 | |
1845 | REGISTER_OP("_MklFusedBatchMatMulV2" ) |
1846 | .Input("x: T" ) |
1847 | .Input("y: T" ) |
1848 | .Input("args: num_args * T" ) |
1849 | .Output("output: T" ) |
1850 | .Attr("T: {bfloat16, float}" ) |
1851 | .Attr("adj_x: bool = false" ) |
1852 | .Attr("adj_y: bool = false" ) |
1853 | .Attr("num_args: int >= 0" ) |
1854 | .Attr("fused_ops: list(string) = []" ) |
1855 | .SetShapeFn(shape_inference::BatchMatMulV2Shape) |
1856 | .Doc(R"doc( |
oneDNN fused version of BatchMatMulV2; the fusion is described by the num_args
extra inputs and the fused_ops attribute.

*NOTE*: Do not invoke this operator directly in Python. Grappler is
1858 | expected to create these operators. |
1859 | )doc" ); |
1860 | |
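// A rough sketch (node name and dtype are hypothetical) of how a rewrite pass
// replacing the pattern x * sigmoid(x) might construct the _MklSwish node below:
//
//   NodeDef swish;
//   TF_RETURN_IF_ERROR(NodeDefBuilder("fused_swish", "_MklSwish")
//                          .Input("x", 0, DT_FLOAT)
//                          .Attr("T", DT_FLOAT)
//                          .Finalize(&swish));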
1861 | REGISTER_OP("_MklSwish" ) |
1862 | .Input("features: T" ) |
1863 | .Output("activations: T" ) |
1864 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
1865 | .SetShapeFn(shape_inference::UnchangedShape) |
1866 | .Doc(R"doc( |
MKL version of the Swish activation, Swish(x) = x * sigmoid(x). Uses MKL DNN
APIs to implement the operator.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
1869 | expected to invoke these operators. |
1870 | )doc" ); |
1871 | |
1872 | REGISTER_OP("_MklLayerNorm" ) |
1873 | .Input("x: T" ) |
1874 | .Input("scale: T" ) |
1875 | .Input("offset: T" ) |
1876 | .Output("y: T" ) |
1877 | .Attr("T: {float, bfloat16}" ) |
1878 | .Attr("epsilon: float = 0.001" ) |
1879 | .SetShapeFn(shape_inference::UnchangedShape); |
1880 | |
1881 | } // namespace tensorflow |
1882 | |
1883 | #endif // INTEL_MKL |
1884 | |