/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

// For now, this file only includes MKL quantized ops. In the
// future, we will move all other MKL ops from nn_ops.cc to this file.

#ifdef INTEL_MKL

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("_MklNativeConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int) >= 5")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
MKL version of Conv3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D convolution.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeConv3DBackpropInputV2")
    .Input("input_sizes: Tshape")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int) >= 5")
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .SetShapeFn([](InferenceContext* c) {
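      // input_sizes (input 0) holds the shape of the original input as a 1-D
      // tensor; materialize it as the output shape and require it to be 5-D,
      // as expected for 3D convolution. The same pattern is used by the other
      // backprop-input/filter shape functions below.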
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
MKL version of Convolution3D backward input op that does not depend on layout
propagation. Uses oneDNN APIs to compute the gradients of convolution with
respect to the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeConv3DBackpropFilterV2")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    })
    .Doc(R"doc(
MKL version of Conv3DBackpropFilter op that does not depend on layout
propagation. Uses oneDNN APIs to compute the gradients of convolution
with respect to the filter.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeFusedConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int) >= 5")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .Attr("fused_ops: list(string) = []")
    .Attr("epsilon: float = 0.0001")
    .Attr("leakyrelu_alpha: float = 0.2")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
MKL version of Conv3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D convolution.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_FusedConv3D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .Attr("fused_ops: list(string) = []")
    .Attr("epsilon: float = 0.0001")
    .Attr("leakyrelu_alpha: float = 0.2")
    .SetShapeFn(shape_inference::Conv3DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);

REGISTER_OP("_MklNativeDepthwiseConv2dNativeBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklNativeDepthwiseConv2dNativeBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativeFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativeConv2DWithBias")
    .Input("input: T")
    .Input("filter: T")
    .Input("bias: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
    .Doc(R"doc(
MKL version of Conv2D and BiasAdd operator. Uses oneDNN APIs to perform
2D convolution and add Bias to the output of convolution.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");

REGISTER_OP("_MklNativeConv2DBackpropFilterWithBias")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Output("bias_grad: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape)
    .Doc(R"doc(
oneDNN version of Conv2DBackpropFilterWithBias. Uses oneDNN APIs to compute the
fusion of Conv2DBackpropFilter and BiasAddGrad.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this one.
)doc");

REGISTER_OP("_MklFusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

REGISTER_OP("_MklNativeFusedDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

REGISTER_OP("_MklFusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Input("mkl_a: uint8")
    .Input("mkl_b: uint8")
    .Input("mkl_args: num_args * uint8")
    .Output("product: T")
    .Output("mkl_product: uint8")
    .Attr("is_filter_const: bool = false")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
MKL version of FusedMatMul operator. Uses MKL-DNN APIs to implement the fused
MatMul operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeFusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Output("product: T")
    .Attr("is_filter_const: bool = false")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
oneDNN version of FusedMatMul operator that does not depend
on layout propagation. Uses oneDNN APIs to implement MatMul fusion.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this one.
)doc");

REGISTER_OP("__MklDummyPadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklPadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_args: num_args * uint8")
    .Input("mkl_paddings: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. MKL DNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativePadWithFusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ----------------------------------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. oneDNN graph transformer
is expected to create these operators.
)doc");

REGISTER_OP("_MklNativePadWithConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_filter_const: bool = false")
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
MKL version of Pad and Conv2D fusion that does not depend
on layout propagation. Uses oneDNN APIs to perform
the fusion.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPoolShape)
    .Doc(R"doc(
oneDNN version of AvgPool operator that does not depend on layout
propagation. Uses oneDNN APIs to perform average pooling on the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPoolGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPoolGradShape)
    .Doc(R"doc(
oneDNN version of AvgPoolGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of AvgPool operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool3D")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::Pool3DShape)
    .Doc(R"doc(
oneDNN version of AvgPool3D operator that does not depend on layout
propagation. Uses oneDNN APIs to perform 3D average pooling on the input.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeAvgPool3DGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {float, half, double, bfloat16}")
    .SetShapeFn(shape_inference::AvgPool3DGradShape)
    .Doc(R"doc(
oneDNN version of AvgPool3DGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of AvgPool3D function.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool")
    .Attr("T: {float, half, bfloat16} = DT_FLOAT")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("workspace_enabled: bool = false")
    .Input("input: T")
    .Output("output: T")
    .Output("workspace: uint8")
    .SetShapeFn(shape_inference::MaxPoolShape)
    .Doc(R"doc(
oneDNN version of MaxPool operator that does not depend
on layout propagation. Uses oneDNN APIs to perform max pooling
on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPoolGrad")
    .Attr("T: {float, half, bfloat16} = DT_FLOAT")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("workspace_enabled: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("workspace: uint8")
    .Output("output: T")
    .SetShapeFn(shape_inference::MaxPoolGradShape)
    .Doc(R"doc(
oneDNN version of MaxPoolGrad that does not depend on layout propagation.
Uses oneDNN APIs to compute gradients of MaxPool operator.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool3D")
    .Input("input: T")
    .Output("output: T")
    .Output("workspace: uint8")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float}")
    .Attr("workspace_enabled: bool = false")
    .SetShapeFn(shape_inference::Pool3DShape)
    .Doc(R"doc(
oneDNN version of MaxPool3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D max pooling on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeMaxPool3DGrad")
    .Input("orig_input: TInput")
    .Input("orig_output: TInput")
    .Input("grad: T")
    .Input("workspace: uint8")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
    .Attr("workspace_enabled: bool = false")
    .SetShapeFn(shape_inference::MaxPool3DGradShape)
    .Doc(R"doc(
oneDNN version of MaxPool3DGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of MaxPool3D function.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklQuantizedMaxPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::MaxPoolShape)
    .Doc(R"doc(
MKL version of QuantizedMaxPool operator. Uses MKL DNN APIs to perform max pooling
on the quantized input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklQuantizedAvgPool")
    .Input("input: T")
    .Input("min_input: float")
    .Input("max_input: float")
    .Output("output: T")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("T: quantizedtype")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::QuantizedAvgPoolShape)
    .Doc(R"doc(
MKL version of QuantizedAvgPool operator. Uses MKL DNN APIs to perform average pooling
on the quantized input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_FusedQuantizedConv2D")
    .Input("device_inputs: Tdevice_inputs")
    .Input("host_inputs: Thost_inputs")
    .Output("device_outputs: Tdevice_outputs")
    .Output("host_outputs: Thost_outputs")
    .Attr("Tinput: quantizedtype = DT_QUINT8")
    .Attr("Tfilter: quantizedtype = DT_QINT8")
    .Attr("Tbias: {float, qint32} = DT_QINT32")
    .Attr("Tsummand: {float, quint8, qint8, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("Tdevice_inputs: list(type) >= 0 = []")
    .Attr("Thost_inputs: list(type) >= 0")
    .Attr("Tdevice_outputs: list(type) >= 0 = []")
    .Attr("Thost_outputs: list(type) >= 0")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("alpha: float = 0.0")
    .SetShapeFn(shape_inference::FusedQuantizedConv2DShape);

REGISTER_OP("_FusedQuantizedDepthwiseConv2D")
    .Input("device_inputs: Tdevice_inputs")
    .Input("host_inputs: Thost_inputs")
    .Output("device_outputs: Tdevice_outputs")
    .Output("host_outputs: Thost_outputs")
    .Attr("Tinput: quantizedtype = DT_QUINT8")
    .Attr("Tfilter: quantizedtype = DT_QINT8")
    .Attr("Tbias: {float, qint32} = DT_QINT32")
    .Attr("Tsummand: {float, quint8, qint8, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("Tdevice_inputs: list(type) >= 0 = []")
    .Attr("Thost_inputs: list(type) >= 0")
    .Attr("Tdevice_outputs: list(type) >= 0 = []")
    .Attr("Thost_outputs: list(type) >= 0")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("fused_ops: list(string) = []")
    .Attr("alpha: float = 0.0")
    .SetShapeFn(shape_inference::FusedQuantizedDepthwiseConv2D);

REGISTER_OP("_MklQuantizedConv2D")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn(shape_inference::QuantizedConv2DShape);

// TODO(nammbash): Most of the TF_RETURN_IF_ERROR(c->WithRank) checks
// seem to be similar and hence can be moved into a single function
// with appropriate arguments for a cleaner design.
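// A minimal sketch of one possible consolidation is shown below for
// illustration only; the helper name and its parameterization are
// assumptions, not part of the current code. It runs the Conv2D shape check,
// requires the listed min/max inputs to be scalars, and emits scalar min/max
// outputs, which is the common pattern in the requantized ops below.
//
//   Status QuantizedConvScalarRangeShape(InferenceContext* c,
//                                        std::initializer_list<int> scalars) {
//     TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
//     ShapeHandle unused;
//     for (int idx : scalars) {
//       TF_RETURN_IF_ERROR(c->WithRank(c->input(idx), 0, &unused));
//     }
//     c->set_output(1, c->Scalar());
//     c->set_output(2, c->Scalar());
//     return Status::OK();
//   }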
REGISTER_OP("_MklQuantizedConv2DAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBias")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
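      // bias must be rank-1; min/max of the input are scalars, while min/max
      // of the filter may be scalars or per-channel 1-D tensors. The filter
      // min/max shape is forwarded to the min/max outputs below.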
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Input("summand: Tsummand")
    .Input("min_summand: float")
    .Input("max_summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Tsummand: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Input("summand: Tsummand")
    .Input("min_summand: float")
    .Input("max_summand: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Tsummand: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedConv2DPerChannel")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    })
    .Doc(R"doc(
MKL-DNN implementation of QuantizedConv2D op.
)doc");

REGISTER_OP("_MklDepthwiseConv2dNativeBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_out_backprop: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklEinsum")
    .Input("inputs: N * T")
    .Output("output: T")
    .Attr("equation: string")
    .Attr("N: int >= 1")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::EinsumShape);

REGISTER_OP("_MklDepthwiseConv2dNativeBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_out_backprop: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBias")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndRelu")
    .Input("a: T1")
    .Input("b: T2")
    // TODO(intel-tf): Modify bias type as Tbias and add relevant attribute.
    .Input("bias: float")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndReluAndRequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QUINT8")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndDequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: {float}")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));

      return Status::OK();
    });

REGISTER_OP("_MklQuantizedMatMulWithBiasAndRequantize")
    .Input("a: T1")
    .Input("b: T2")
    .Input("bias: Tbias")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("Toutput: quantizedtype = DT_QUINT8")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
    .Attr("is_weight_const: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("_MklQuantizedDepthwiseConv2D")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      // TODO(bhavanis): Print an error message during the return.
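      // One possible way to do that (an illustrative sketch only; the exact
      // wording and the use of errors::InvalidArgument are assumptions):
      //   Status s = shape_inference::Conv2DShape(c);
      //   if (!s.ok()) {
      //     return errors::InvalidArgument(
      //         "_MklQuantizedDepthwiseConv2D shape check failed: ",
      //         s.error_message());
      //   }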
1423 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
1424 ShapeHandle unused, channel;
1425 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1426 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1427 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &unused));
1428 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
1429 c->set_output(1, channel);
1430 c->set_output(2, channel);
1431 return Status::OK();
1432 })
1433 .Doc(R"doc(
1434MKL-DNN implementation of quantized depthwise Conv2D.
1435*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
1436expected to invoke this operator.
1437)doc");
1438
REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBias")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    })
    .Doc(R"doc(
MKL-DNN implementation of quantized depthwise Conv2D with Bias.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");

REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndRelu")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: float")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("out_type: quantizedtype = DT_QINT32")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused, channel;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
      c->set_output(1, channel);
      c->set_output(2, channel);
      return Status::OK();
    })
    .Doc(R"doc(
MKL-DNN implementation of quantized depthwise Conv2D with Bias and Relu.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");

REGISTER_OP("_MklQuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
    .Input("input: Tinput")
    .Input("filter: Tfilter")
    .Input("bias: Tbias")
    .Input("min_input: float")
    .Input("max_input: float")
    .Input("min_filter: float")
    .Input("max_filter: float")
    .Input("min_freezed_output: float")
    .Input("max_freezed_output: float")
    .Output("output: out_type")
    .Output("min_output: float")
    .Output("max_output: float")
    .Attr("Tinput: quantizedtype")
    .Attr("Tfilter: quantizedtype")
    .Attr("Tbias: {float, qint32}")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .Attr("data_format: string = 'NHWC'")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = true")
    .Attr("is_bias_const: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("padding_list: list(int) = []")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    })
    .Doc(R"doc(
MKL-DNN implementation of quantized depthwise Conv2D with Bias, Relu and Requantize.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");
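
// Note on the shape function above: unlike the non-requantized depthwise
// variants, whose min_output/max_output mirror the (possibly per-channel)
// filter range, the *AndRequantize op rescales its result into the fixed
// range supplied by min_freezed_output/max_freezed_output, so the min/max
// outputs are emitted as scalars.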

REGISTER_OP("_MklFusedBatchNormV3")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Input("mkl_x: uint8")
    .Input("mkl_scale: uint8")
    .Input("mkl_offset: uint8")
    .Input("mkl_mean: uint8")
    .Input("mkl_variance: uint8")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Output("mkl_y: uint8")
    .Output("mkl_batch_mean: uint8")
    .Output("mkl_batch_variance: uint8")
    .Output("mkl_reserve_space_1: uint8")
    .Output("mkl_reserve_space_2: uint8")
    .Output("mkl_reserve_space_3: uint8")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape)
    .Doc(
        R"doc(MKL-DNN implementation of FusedBatchNormV3: Do not invoke this operator directly in Python.
        Graph rewrite pass is expected to invoke this operator.)doc");
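
// Note: the mkl_* uint8 inputs and outputs above carry oneDNN layout metadata
// for the corresponding data tensors; they exist only in this
// layout-propagation flavor of the op. The _MklNative* registrations further
// below omit them because they do not depend on layout propagation.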

REGISTER_OP("_MklFusedBatchNormGradV3")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Input("reserve_space_3: U")
    .Input("mkl_y_backprop: uint8")
    .Input("mkl_x: uint8")
    .Input("mkl_scale: uint8")
    .Input("mkl_reserve_space_1: uint8")
    .Input("mkl_reserve_space_2: uint8")
    .Input("mkl_reserve_space_3: uint8")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_4: U")
    .Output("reserve_space_5: U")
    .Output("mkl_x_backprop: uint8")
    .Output("mkl_scale_backprop: uint8")
    .Output("mkl_offset_backprop: uint8")
    .Output("mkl_reserve_space_4: uint8")
    .Output("mkl_reserve_space_5: uint8")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape)
    .Doc(
        R"doc(MKL-DNN implementation of FusedBatchNormGradV3: Do not invoke this operator directly in Python.
        Graph rewrite pass is expected to invoke this operator.)doc");

REGISTER_OP("_MklFusedBatchNormEx")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Input("side_input: num_side_inputs * T")
    .Input("mkl_x: uint8")
    .Input("mkl_scale: uint8")
    .Input("mkl_offset: uint8")
    .Input("mkl_mean: uint8")
    .Input("mkl_variance: uint8")
    .Input("mkl_side_input: num_side_inputs * uint8")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Output("mkl_y: uint8")
    .Output("mkl_batch_mean: uint8")
    .Output("mkl_batch_variance: uint8")
    .Output("mkl_reserve_space_1: uint8")
    .Output("mkl_reserve_space_2: uint8")
    .Output("mkl_reserve_space_3: uint8")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("num_side_inputs: int >= 0 = 0")
    .Attr("activation_mode: string = \"Identity\"")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape)
    .Doc(R"doc(
MKL version of the FusedBatchNormEx operator. Uses MKL-DNN APIs to perform
fused batch normalization and relu.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");
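
// With the default attrs (num_side_inputs = 0, activation_mode = "Identity")
// this reduces to plain fused batch normalization. When the rewrite pass
// supplies a side_input and an activation such as "Relu", the side input is
// expected to be added to the normalized result before the activation, as in
// the standard _FusedBatchNormEx op.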

REGISTER_OP("_MklNativeFusedBatchNorm")
    .Input("x: T")
    .Input("scale: T")
    .Input("offset: T")
    .Input("mean: T")
    .Input("variance: T")
    .Output("y: T")
    .Output("batch_mean: T")
    .Output("batch_variance: T")
    .Output("reserve_space_1: T")
    .Output("reserve_space_2: T")
    .Attr("T: numbertype")
    .Attr("epsilon: float = 0.0001")
    .Attr("data_format: string = 'NHWC'")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape)
    .Doc(R"doc(
oneDNN version of FusedBatchNorm operator that does not depend on layout
propagation. Uses oneDNN APIs to perform fused batch normalization.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
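
// For reference, inference-mode fused batch normalization computes, per
// channel (a rough sketch of the math, not of the oneDNN kernel itself):
//
//   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
//
// In training mode the batch statistics are used instead and are blended into
// the running mean/variance via exponential_avg_factor; the default of 1.0
// simply replaces the running statistics with the batch statistics.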

REGISTER_OP("_MklNativeFusedBatchNormGrad")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: T")
    .Input("reserve_space_1: T")
    .Input("reserve_space_2: T")
    .Output("x_backprop: T")
    .Output("scale_backprop: T")
    .Output("offset_backprop: T")
    .Output("reserve_space_3: T")
    .Output("reserve_space_4: T")
    .Attr("T: numbertype")
    .Attr("epsilon: float = 0.0001")
    .Attr("data_format: string = 'NHWC'")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape)
    .Doc(R"doc(
oneDNN version of FusedBatchNormGrad operator that does not depend
on layout propagation. Uses oneDNN APIs to compute gradients for fused
batch normalization.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklNativeFusedBatchNormV2")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);

REGISTER_OP("_MklNativeFusedBatchNormGradV2")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_3: U")
    .Output("reserve_space_4: U")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);

REGISTER_OP("_MklNativeFusedBatchNormV3")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape)
    .Doc(
        R"doc(oneDNN version of FusedBatchNormV3 operator that does not depend
        on layout propagation. Do not invoke this operator directly in Python.
        Graph rewrite pass is expected to invoke this operator.)doc");

REGISTER_OP("_MklNativeFusedBatchNormGradV3")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Input("reserve_space_3: U")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_4: U")
    .Output("reserve_space_5: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape)
    .Doc(
        R"doc(oneDNN version of FusedBatchNormGradV3 that does not depend
        on layout propagation. Do not invoke this operator directly in Python.
        Graph rewrite pass is expected to invoke this operator.)doc");

REGISTER_OP("_MklNativeFusedBatchNormEx")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Input("side_input: num_side_inputs * T")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Output("reserve_space_3: U")
    .Attr("T: {bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr("exponential_avg_factor: float = 1.0")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("num_side_inputs: int >= 0 = 0")
    .Attr("activation_mode: string = \"Identity\"")
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape)
    .Doc(R"doc(
oneDNN version of FusedBatchNormEx operator that does not depend on layout propagation.
Uses oneDNN APIs to perform fused batch normalization and relu.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklFusedMish")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
oneDNN version of the Mish operator. Uses oneDNN APIs to implement the Mish
operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke this operator.
)doc");
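
// For reference, Mish is the element-wise activation
//   mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))),
// which is why the shape function is simply UnchangedShape.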

REGISTER_OP("_MklFusedBatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
)doc");
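
// As with BatchMatMulV2, adj_x/adj_y request the adjoint (a transpose for the
// real types allowed here) of the innermost two dimensions, and the leading
// batch dimensions broadcast. The extra `args` tensors feed the post-ops
// named in `fused_ops`; the exact fusions are decided by the Grappler rewrite
// that creates this node.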

REGISTER_OP("_MklSwish")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
MKL version of the Swish operator. Uses MKL-DNN APIs to implement the Swish
operator.

*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
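
// For reference, Swish (also known as SiLU) is the element-wise activation
//   swish(x) = x * sigmoid(x),
// so the output shape matches the input shape (UnchangedShape above).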

REGISTER_OP("_MklLayerNorm")
    .Input("x: T")
    .Input("scale: T")
    .Input("offset: T")
    .Output("y: T")
    .Attr("T: {float, bfloat16}")
    .Attr("epsilon: float = 0.001")
    .SetShapeFn(shape_inference::UnchangedShape);
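
// A rough sketch of the layer-normalization math (the normalization axes are
// chosen by the kernel, typically the last dimension):
//   y = scale * (x - mean(x)) / sqrt(var(x) + epsilon) + offset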

}  // namespace tensorflow

#endif  // INTEL_MKL